Coverage for src/couchers/servicers/threads.py: 93%

122 statements  

coverage.py v7.6.10, created at 2025-09-14 15:31 +0000

import logging

import grpc
import sqlalchemy.exc
from sqlalchemy.sql import func, select

from couchers import errors
from couchers.context import make_background_user_context
from couchers.db import session_scope
from couchers.jobs.enqueue import queue_job
from couchers.models import Comment, Discussion, Event, EventOccurrence, Reply, Thread, User
from couchers.notifications.notify import notify
from couchers.servicers.api import user_model_to_pb
from couchers.servicers.blocking import is_not_visible
from couchers.sql import couchers_select as select
from couchers.utils import Timestamp_from_datetime
from proto import notification_data_pb2, threads_pb2, threads_pb2_grpc
from proto.internal import jobs_pb2

logger = logging.getLogger(__name__)

# Since the API exposes a single ID space regardless of nesting level,
# we construct the API id by appending the nesting level to the
# database ID.


def pack_thread_id(database_id: int, depth: int) -> int:
    return database_id * 10 + depth


def unpack_thread_id(thread_id: int) -> (int, int):
    """Returns (database_id, depth) tuple."""
    return divmod(thread_id, 10)
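
# For example, a thread whose database row id is 15 is exposed as API id 150
# (depth 0), a top-level comment with row id 15 as 151 (depth 1), and a nested
# reply with row id 15 as 152 (depth 2); unpack_thread_id(151) == (15, 1).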

def total_num_responses(session, database_id):
    """Return the total number of comments and replies to the thread with
    database id database_id.
    """
    return (
        session.execute(select(func.count()).select_from(Comment).where(Comment.thread_id == database_id)).scalar_one()
        + session.execute(
            select(func.count())
            .select_from(Reply)
            .join(Comment, Comment.id == Reply.comment_id)
            .where(Comment.thread_id == database_id)
        ).scalar_one()
    )
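
# Note: this issues two COUNT queries, one over the thread's comments and one
# over replies joined through their parent comment; thread_to_pb below uses it
# to populate num_responses.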

def thread_to_pb(session, database_id):
    return threads_pb2.Thread(
        thread_id=pack_thread_id(database_id, 0),
        num_responses=total_num_responses(session, database_id),
    )
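
# Fans out notifications for a newly posted comment or reply. This runs as a
# background job: PostReply below enqueues a GenerateReplyNotificationsPayload
# under job_type="generate_reply_notifications", and the job machinery is
# presumed to call this function with that payload.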

def generate_reply_notifications(payload: jobs_pb2.GenerateReplyNotificationsPayload):
    from couchers.servicers.discussions import discussion_to_pb
    from couchers.servicers.events import event_to_pb

    with session_scope() as session:
        database_id, depth = unpack_thread_id(payload.thread_id)
        if depth == 1:
            # this is a top-level Comment on a Thread attached to event, discussion, etc
            comment = session.execute(select(Comment).where(Comment.id == database_id)).scalar_one()
            thread = session.execute(select(Thread).where(Thread.id == comment.thread_id)).scalar_one()
            author_user = session.execute(select(User).where(User.id == comment.author_user_id)).scalar_one()
            # reply object for notif
            reply = threads_pb2.Reply(
                thread_id=payload.thread_id,
                content=comment.content,
                author_user_id=comment.author_user_id,
                created_time=Timestamp_from_datetime(comment.created),
                num_replies=0,
            )
            # figure out if the thread is related to an event or discussion
            event = session.execute(select(Event).where(Event.thread_id == thread.id)).scalar_one_or_none()
            discussion = session.execute(
                select(Discussion).where(Discussion.thread_id == thread.id)
            ).scalar_one_or_none()
            if event:
                # thread is an event thread
                occurrence = event.occurrences.order_by(EventOccurrence.id.desc()).limit(1).one()
                subscribed_user_ids = [user.id for user in event.subscribers]
                attending_user_ids = [user.user_id for user in occurrence.attendances]

                for user_id in set(subscribed_user_ids + attending_user_ids):
                    if is_not_visible(session, user_id, comment.author_user_id):
                        continue
                    if user_id == comment.author_user_id:
                        continue
                    context = make_background_user_context(user_id=user_id)
                    notify(
                        session,
                        user_id=user_id,
                        topic_action="event:comment",
                        key=occurrence.id,
                        data=notification_data_pb2.EventComment(
                            reply=reply,
                            event=event_to_pb(session, occurrence, context),
                            author=user_model_to_pb(author_user, session, context),
                        ),
                    )
            elif discussion:
                # community discussion thread
                cluster = discussion.owner_cluster

                if not cluster.is_official_cluster:
                    raise NotImplementedError("Shouldn't have discussions under groups, only communities")

                for user_id in [discussion.creator_user_id]:
                    if is_not_visible(session, user_id, comment.author_user_id):
                        continue
                    if user_id == comment.author_user_id:
                        continue

                    context = make_background_user_context(user_id=user_id)
                    notify(
                        session,
                        user_id=user_id,
                        topic_action="discussion:comment",
                        key=discussion.id,
                        data=notification_data_pb2.DiscussionComment(
                            reply=reply,
                            discussion=discussion_to_pb(session, discussion, context),
                            author=user_model_to_pb(author_user, session, context),
                        ),
                    )
            else:
                raise NotImplementedError("I can only do event and discussion threads for now")
        elif depth == 2:
            # this is a second-level reply to a comment
            reply = session.execute(select(Reply).where(Reply.id == database_id)).scalar_one()
            # the comment we're replying to
            parent_comment = session.execute(select(Comment).where(Comment.id == reply.comment_id)).scalar_one()
            context = make_background_user_context(user_id=reply.author_user_id)
            thread_replies_author_user_ids = (
                session.execute(
                    select(Reply.author_user_id)
                    .where_users_column_visible(context, Reply.author_user_id)
                    .where(Reply.comment_id == parent_comment.id)
                )
                .scalars()
                .all()
            )
            thread_user_ids = set(thread_replies_author_user_ids)
            if not is_not_visible(session, parent_comment.author_user_id, reply.author_user_id):
                thread_user_ids.add(parent_comment.author_user_id)

            author_user = session.execute(select(User).where(User.id == reply.author_user_id)).scalar_one()

            user_ids_to_notify = set(thread_user_ids) - {reply.author_user_id}

            reply = threads_pb2.Reply(
                thread_id=payload.thread_id,
                content=reply.content,
                author_user_id=reply.author_user_id,
                created_time=Timestamp_from_datetime(reply.created),
                num_replies=0,
            )

            event = session.execute(
                select(Event).where(Event.thread_id == parent_comment.thread_id)
            ).scalar_one_or_none()
            discussion = session.execute(
                select(Discussion).where(Discussion.thread_id == parent_comment.thread_id)
            ).scalar_one_or_none()
            if event:
                # thread is an event thread
                occurrence = event.occurrences.order_by(EventOccurrence.id.desc()).limit(1).one()
                for user_id in user_ids_to_notify:
                    context = make_background_user_context(user_id=user_id)
                    notify(
                        session,
                        user_id=user_id,
                        topic_action="thread:reply",
                        key=occurrence.id,
                        data=notification_data_pb2.ThreadReply(
                            reply=reply,
                            event=event_to_pb(session, occurrence, context),
                            author=user_model_to_pb(author_user, session, context),
                        ),
                    )
            elif discussion:
                # community discussion thread
                for user_id in user_ids_to_notify:
                    context = make_background_user_context(user_id=user_id)
                    notify(
                        session,
                        user_id=user_id,
                        topic_action="thread:reply",
                        key=discussion.id,
                        data=notification_data_pb2.ThreadReply(
                            reply=reply,
                            discussion=discussion_to_pb(session, discussion, context),
                            author=user_model_to_pb(author_user, session, context),
                        ),
                    )
            else:
                raise NotImplementedError("I can only do event and discussion threads for now")
        else:
            raise Exception("Unknown depth")

class Threads(threads_pb2_grpc.ThreadsServicer):
    def GetThread(self, request, context, session):
        database_id, depth = unpack_thread_id(request.thread_id)
        page_size = request.page_size if 0 < request.page_size < 100000 else 1000
        page_start = unpack_thread_id(int(request.page_token))[0] if request.page_token else 2**50

        if depth == 0:
            if not session.execute(select(Thread).where(Thread.id == database_id)).scalar_one_or_none():
                context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)

            res = session.execute(
                select(Comment, func.count(Reply.id))
                .outerjoin(Reply, Reply.comment_id == Comment.id)
                .where(Comment.thread_id == database_id)
                .where(Comment.id < page_start)
                .group_by(Comment.id)
                .order_by(Comment.created.desc())
                .limit(page_size + 1)
            ).all()
            replies = [
                threads_pb2.Reply(
                    thread_id=pack_thread_id(r.id, 1),
                    content=r.content,
                    author_user_id=r.author_user_id,
                    created_time=Timestamp_from_datetime(r.created),
                    num_replies=n,
                )
                for r, n in res[:page_size]
            ]

        elif depth == 1:
            if not session.execute(select(Comment).where(Comment.id == database_id)).scalar_one_or_none():
                context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)

            res = (
                session.execute(
                    select(Reply)
                    .where(Reply.comment_id == database_id)
                    .where(Reply.id < page_start)
                    .order_by(Reply.created.desc())
                    .limit(page_size + 1)
                )
                .scalars()
                .all()
            )
            replies = [
                threads_pb2.Reply(
                    thread_id=pack_thread_id(r.id, 2),
                    content=r.content,
                    author_user_id=r.author_user_id,
                    created_time=Timestamp_from_datetime(r.created),
                    num_replies=0,
                )
                for r in res[:page_size]
            ]

        else:
            context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)

        if len(res) > page_size:
            # There's more!
            next_page_token = str(replies[-1].thread_id)
        else:
            next_page_token = ""

        return threads_pb2.GetThreadRes(replies=replies, next_page_token=next_page_token)
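
    # Pagination note for GetThread above: next_page_token is the packed
    # thread_id of the last reply returned; on the next call it is unpacked
    # into page_start, and the `< page_start` filters skip everything already
    # returned, with results ordered newest-first by creation time.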

    def PostReply(self, request, context, session):
        content = request.content.strip()

        if content == "":
            context.abort(grpc.StatusCode.INVALID_ARGUMENT, errors.INVALID_COMMENT)

        database_id, depth = unpack_thread_id(request.thread_id)
        if depth == 0:
            object_to_add = Comment(thread_id=database_id, author_user_id=context.user_id, content=content)
        elif depth == 1:
            object_to_add = Reply(comment_id=database_id, author_user_id=context.user_id, content=content)
        else:
            context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)
        session.add(object_to_add)
        try:
            session.flush()
        except sqlalchemy.exc.IntegrityError:
            context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)

        thread_id = pack_thread_id(object_to_add.id, depth + 1)

        queue_job(
            session,
            job_type="generate_reply_notifications",
            payload=jobs_pb2.GenerateReplyNotificationsPayload(
                thread_id=thread_id,
            ),
        )

        return threads_pb2.PostReplyRes(thread_id=thread_id)
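
# Rough client-side usage sketch (illustrative only; the request message names
# GetThreadReq and PostReplyReq are assumed, since threads.proto is not shown
# in this file):
#
#   stub = threads_pb2_grpc.ThreadsStub(channel)  # given an open gRPC channel
#   # post a top-level comment on the thread with API id 150 (database id 15, depth 0)
#   res = stub.PostReply(threads_pb2.PostReplyReq(thread_id=150, content="hello"))
#   # page through that thread's comments, 20 at a time
#   page = stub.GetThread(threads_pb2.GetThreadReq(thread_id=150, page_size=20))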