Coverage for src/couchers/servicers/threads.py: 98%

51 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-10-15 13:03 +0000

1import logging 

2 

3import grpc 

4import sqlalchemy.exc 

5from sqlalchemy.sql import func 

6 

7from couchers import errors 

8from couchers.models import Comment, Reply, Thread 

9from couchers.sql import couchers_select as select 

10from couchers.utils import Timestamp_from_datetime 

11from proto import threads_pb2, threads_pb2_grpc 

12 

logger = logging.getLogger(name=__name__)

14 

# The public API uses one flat ID space for threads, comments and replies,
# whatever their nesting level. An API id is therefore formed by appending the
# nesting depth to the database ID as one extra trailing decimal digit.


def pack_thread_id(database_id: int, depth: int) -> int:
    """Encode a database id and nesting depth as a single API thread id.

    The depth becomes the last decimal digit, e.g. ``pack_thread_id(42, 1) == 421``.
    Round-tripping through ``unpack_thread_id`` (which uses ``divmod(..., 10)``)
    only works when ``depth`` is a single digit (0-9).
    """
    shifted_id = database_id * 10
    return shifted_id + depth

22 

23 

def unpack_thread_id(thread_id: int) -> "tuple[int, int]":
    """Split an API thread id into its ``(database_id, depth)`` parts.

    This is the inverse of ``pack_thread_id``: the last decimal digit is the
    nesting depth and the remaining leading digits are the database id.

    Returns:
        A ``(database_id, depth)`` tuple.
    """
    # The annotation is quoted so it stays valid on interpreters without
    # builtin generic subscripting; ``(int, int)`` is not a valid type hint.
    database_id, depth = divmod(thread_id, 10)
    return database_id, depth

27 

28 

def total_num_responses(session, database_id):
    """Count all responses in a thread: its comments plus every reply to them.

    Args:
        session: SQLAlchemy session used to run the two count queries.
        database_id: database id of the Thread (not the packed API id).

    Returns:
        The number of Comment rows in the thread plus the number of Reply rows
        attached to those comments.
    """
    comments_in_thread = select(func.count()).select_from(Comment).where(Comment.thread_id == database_id)
    replies_in_thread = (
        select(func.count())
        .select_from(Reply)
        .join(Comment, Comment.id == Reply.comment_id)
        .where(Comment.thread_id == database_id)
    )
    num_comments = session.execute(comments_in_thread).scalar_one()
    num_replies = session.execute(replies_in_thread).scalar_one()
    return num_comments + num_replies

42 

43 

def thread_to_pb(session, database_id):
    """Build the ``threads_pb2.Thread`` message for the thread with this database id.

    The message carries the thread's packed depth-0 API id and its total number
    of comments and replies.
    """
    api_thread_id = pack_thread_id(database_id, 0)
    response_count = total_num_responses(session, database_id)
    return threads_pb2.Thread(thread_id=api_thread_id, num_responses=response_count)

49 

50 

class Threads(threads_pb2_grpc.ThreadsServicer):
    """gRPC servicer for reading and posting to comment threads.

    A Thread (depth 0) holds Comments; a Comment (depth 1) holds Replies
    (depth 2). All of them share one API id space built by ``pack_thread_id``.
    """

    def GetThread(self, request, context, session):
        """Return one page of the direct children of ``request.thread_id``.

        For a depth-0 id the children are the thread's comments, each with its
        reply count. For a depth-1 id they are the comment's replies. Items are
        ordered by descending database id, which presumably tracks creation
        order. ``request.page_token`` is the packed id of the last item on the
        previous page; an empty token starts from the beginning.

        Aborts with NOT_FOUND if the thread/comment does not exist or the id has
        any other depth.
        """
        database_id, depth = unpack_thread_id(request.thread_id)
        # Non-positive or very large page sizes fall back to 1000.
        page_size = request.page_size if 0 < request.page_size < 100000 else 1000
        # The token is a packed API id; only its database-id part is used as the
        # keyset cursor. 2**50 serves as an upper bound meaning "start at the top".
        # NOTE(review): a non-numeric token raises ValueError here instead of a
        # clean INVALID_ARGUMENT abort.
        page_start = unpack_thread_id(int(request.page_token))[0] if request.page_token else 2**50

        if depth == 0:
            if not session.execute(select(Thread).where(Thread.id == database_id)).scalar_one_or_none():
                context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)

            res = session.execute(
                select(Comment, func.count(Reply.id))
                .outerjoin(Reply, Reply.comment_id == Comment.id)
                .where(Comment.thread_id == database_id)
                .where(Comment.id < page_start)
                .group_by(Comment.id)
                # Sort on the same column the cursor filters on. Sorting by
                # `created` while paging on `id` can skip or repeat rows at page
                # boundaries whenever the two orders disagree (e.g. rows written
                # by concurrent transactions).
                .order_by(Comment.id.desc())
                # One extra row tells us whether another page exists.
                .limit(page_size + 1)
            ).all()
            replies = [
                threads_pb2.Reply(
                    thread_id=pack_thread_id(r.id, 1),
                    content=r.content,
                    author_user_id=r.author_user_id,
                    created_time=Timestamp_from_datetime(r.created),
                    num_replies=n,
                )
                for r, n in res[:page_size]
            ]

        elif depth == 1:
            if not session.execute(select(Comment).where(Comment.id == database_id)).scalar_one_or_none():
                context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)

            res = (
                session.execute(
                    select(Reply)
                    .where(Reply.comment_id == database_id)
                    .where(Reply.id < page_start)
                    # Keep the sort key aligned with the keyset cursor (see above).
                    .order_by(Reply.id.desc())
                    .limit(page_size + 1)
                )
                .scalars()
                .all()
            )
            replies = [
                threads_pb2.Reply(
                    thread_id=pack_thread_id(r.id, 2),
                    content=r.content,
                    author_user_id=r.author_user_id,
                    created_time=Timestamp_from_datetime(r.created),
                    # Depth-2 replies cannot themselves be replied to.
                    num_replies=0,
                )
                for r in res[:page_size]
            ]

        else:
            context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)

        if len(res) > page_size:
            # There's more! The last returned item's id is the next cursor.
            next_page_token = str(replies[-1].thread_id)
        else:
            next_page_token = ""

        return threads_pb2.GetThreadRes(replies=replies, next_page_token=next_page_token)

    def PostReply(self, request, context, session):
        """Post ``request.content`` as the calling user under ``request.thread_id``.

        A depth-0 id creates a Comment on the thread. A depth-1 id creates a
        Reply on the comment. Any other depth aborts with NOT_FOUND. Returns the
        packed API id of the new object, one level deeper than the parent.
        """
        database_id, depth = unpack_thread_id(request.thread_id)
        if depth == 0:
            object_to_add = Comment(thread_id=database_id, author_user_id=context.user_id, content=request.content)
        elif depth == 1:
            object_to_add = Reply(comment_id=database_id, author_user_id=context.user_id, content=request.content)
        else:
            context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)
        session.add(object_to_add)
        try:
            # Flush now so the new row gets its id, and so a constraint violation
            # (presumably a missing parent row) surfaces here.
            session.flush()
        except sqlalchemy.exc.IntegrityError:
            context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)

        return threads_pb2.PostReplyRes(thread_id=pack_thread_id(object_to_add.id, depth + 1))