Coverage for src/couchers/servicers/threads.py: 98%

55 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-07-22 17:19 +0000

1import logging 

2 

3import grpc 

4import sqlalchemy.exc 

5from sqlalchemy.sql import func 

6 

7from couchers import errors 

8from couchers.db import session_scope 

9from couchers.models import Comment, Reply, Thread 

10from couchers.sql import couchers_select as select 

11from couchers.utils import Timestamp_from_datetime 

12from proto import threads_pb2, threads_pb2_grpc 

13 

14logger = logging.getLogger(__name__) 

15 

16# Since the API exposes a single ID space regardless of nesting level, 

17# we construct the API id by appending the nesting level to the 

18# database ID. 

19 

20 

def pack_thread_id(database_id: int, depth: int) -> int:
    """Combine a database id and a nesting depth into a single API id.

    The depth occupies the last decimal digit of the result and the database
    id occupies the remaining digits.
    """
    shifted_id = 10 * database_id
    return shifted_id + depth

23 

24 

def unpack_thread_id(thread_id: int) -> tuple[int, int]:
    """Split an API thread id into its ``(database_id, depth)`` parts.

    The depth is the last decimal digit of ``thread_id`` and the database id
    is everything before it.
    """
    return divmod(thread_id, 10)

28 

29 

def total_num_responses(database_id):
    """Count all responses (comments plus replies) under one thread.

    ``database_id`` is the database id of the thread. The result is the number
    of comments whose ``thread_id`` matches it, plus the number of replies
    joined to those comments.
    """
    with session_scope() as session:
        # Comments attached directly to the thread.
        comment_count_query = select(func.count()).select_from(Comment).where(Comment.thread_id == database_id)
        num_comments = session.execute(comment_count_query).scalar_one()

        # Replies, reached through the comment they belong to.
        reply_count_query = (
            select(func.count())
            .select_from(Reply)
            .join(Comment, Comment.id == Reply.comment_id)
            .where(Comment.thread_id == database_id)
        )
        num_replies = session.execute(reply_count_query).scalar_one()

        return num_comments + num_replies

46 

47 

def thread_to_pb(database_id):
    """Build the ``threads_pb2.Thread`` message for a thread's database id.

    The message carries the API id of the thread (packed at depth 0) and the
    total number of comments and replies it has received.
    """
    api_thread_id = pack_thread_id(database_id, 0)
    response_count = total_num_responses(database_id)
    return threads_pb2.Thread(thread_id=api_thread_id, num_responses=response_count)

53 

54 

class Threads(threads_pb2_grpc.ThreadsServicer):
    """gRPC servicer for reading and posting to discussion threads.

    API thread ids combine a database id with a nesting depth (see
    ``pack_thread_id`` / ``unpack_thread_id``). As used in this class, depth 0
    refers to a ``Thread``, depth 1 to a ``Comment`` and depth 2 to a
    ``Reply``.
    """

    def GetThread(self, request, context):
        """Return one page of the direct responses to a thread or a comment.

        ``request.thread_id`` selects the parent: depth 0 lists the comments
        of a thread (each with its reply count), depth 1 lists the replies of
        a comment. Rows are ordered newest first by creation time.

        ``request.page_size`` is honoured when it lies strictly between 0 and
        100000; otherwise 1000 is used. ``request.page_token`` should be empty
        or the ``next_page_token`` of a previous response.

        Aborts with NOT_FOUND when the parent does not exist or the depth is
        not 0 or 1, and with INVALID_ARGUMENT when the page token is not an
        integer.
        """
        database_id, depth = unpack_thread_id(request.thread_id)
        page_size = request.page_size if 0 < request.page_size < 100000 else 1000

        if request.page_token:
            # The token is client-supplied; reject garbage explicitly instead
            # of letting int() raise an unhandled ValueError.
            try:
                token_id = int(request.page_token)
            except ValueError:
                # NOTE(review): literal message because no matching constant
                # is visible here; consider moving it to couchers.errors.
                context.abort(grpc.StatusCode.INVALID_ARGUMENT, "Invalid page token.")
            page_start = unpack_thread_id(token_id)[0]
        else:
            # No token: start from the top. 2**50 is presumably larger than
            # any real database id, so the "id < page_start" filter is a no-op.
            page_start = 2**50

        with session_scope() as session:
            if depth == 0:
                # Parent is a thread: list its comments.
                if not session.execute(select(Thread).where(Thread.id == database_id)).scalar_one_or_none():
                    # context.abort raises, so nothing below runs after it.
                    context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)

                # Fetch one row more than requested to learn whether another
                # page exists. The outer join keeps comments with no replies.
                res = session.execute(
                    select(Comment, func.count(Reply.id))
                    .outerjoin(Reply, Reply.comment_id == Comment.id)
                    .where(Comment.thread_id == database_id)
                    .where(Comment.id < page_start)
                    .group_by(Comment.id)
                    .order_by(Comment.created.desc())
                    .limit(page_size + 1)
                ).all()
                replies = [
                    threads_pb2.Reply(
                        thread_id=pack_thread_id(comment.id, 1),
                        content=comment.content,
                        author_user_id=comment.author_user_id,
                        created_time=Timestamp_from_datetime(comment.created),
                        num_replies=reply_count,
                    )
                    for comment, reply_count in res[:page_size]
                ]

            elif depth == 1:
                # Parent is a comment: list its replies.
                if not session.execute(select(Comment).where(Comment.id == database_id)).scalar_one_or_none():
                    context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)

                res = (
                    session.execute(
                        select(Reply)
                        .where(Reply.comment_id == database_id)
                        .where(Reply.id < page_start)
                        .order_by(Reply.created.desc())
                        .limit(page_size + 1)
                    )
                    .scalars()
                    .all()
                )
                replies = [
                    threads_pb2.Reply(
                        thread_id=pack_thread_id(reply.id, 2),
                        content=reply.content,
                        author_user_id=reply.author_user_id,
                        created_time=Timestamp_from_datetime(reply.created),
                        # Replies are the deepest level handled here.
                        num_replies=0,
                    )
                    for reply in res[:page_size]
                ]

            else:
                # Only depths 0 and 1 can have responses listed.
                context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)

            if len(res) > page_size:
                # The extra row came back, so there is more: the token is the
                # API id of the last item actually returned.
                next_page_token = str(replies[-1].thread_id)
            else:
                next_page_token = ""

        return threads_pb2.GetThreadRes(replies=replies, next_page_token=next_page_token)

    def PostReply(self, request, context):
        """Add a comment to a thread, or a reply to a comment.

        The depth encoded in ``request.thread_id`` selects what is created:
        depth 0 creates a ``Comment`` on that thread, depth 1 creates a
        ``Reply`` on that comment. The author is ``context.user_id`` and the
        text is ``request.content``.

        Returns a ``PostReplyRes`` whose ``thread_id`` is the API id of the
        new object, one level deeper than the parent. Aborts with NOT_FOUND
        for any other depth, or when flushing the new row raises an
        ``IntegrityError``.
        """
        with session_scope() as session:
            database_id, depth = unpack_thread_id(request.thread_id)
            if depth == 0:
                object_to_add = Comment(thread_id=database_id, author_user_id=context.user_id, content=request.content)
            elif depth == 1:
                object_to_add = Reply(comment_id=database_id, author_user_id=context.user_id, content=request.content)
            else:
                # context.abort raises, so object_to_add is always bound below.
                context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)
            session.add(object_to_add)
            try:
                # Flush now so the new row's id is available for the response.
                session.flush()
            except sqlalchemy.exc.IntegrityError:
                # Presumably a foreign-key violation because the parent does
                # not exist - TODO confirm against the model constraints.
                context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)

            return threads_pb2.PostReplyRes(thread_id=pack_thread_id(object_to_add.id, depth + 1))