Coverage for app / backend / src / couchers / servicers / threads.py: 89%

126 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 09:44 +0000

1import logging 

2 

3import grpc 

4import sqlalchemy.exc 

5from sqlalchemy import select 

6from sqlalchemy.orm import Session 

7from sqlalchemy.sql import func 

8 

9from couchers.context import CouchersContext, make_background_user_context 

10from couchers.db import session_scope 

11from couchers.jobs.enqueue import queue_job 

12from couchers.models import Comment, Discussion, Event, EventOccurrence, Reply, Thread, User 

13from couchers.models.notifications import NotificationTopicAction 

14from couchers.notifications.notify import notify 

15from couchers.proto import notification_data_pb2, threads_pb2, threads_pb2_grpc 

16from couchers.proto.internal import jobs_pb2 

17from couchers.servicers.api import user_model_to_pb 

18from couchers.servicers.blocking import is_not_visible 

19from couchers.sql import where_users_column_visible 

20from couchers.utils import Timestamp_from_datetime 

21 

logger: logging.Logger = logging.getLogger(__name__)

23 

# The API exposes a single ID space regardless of nesting level, so an API id is built by
# appending the nesting depth to the database id as one extra decimal digit
# (api_id = database_id * 10 + depth); see pack_thread_id / unpack_thread_id.

27 

28 

def pack_thread_id(database_id: int, depth: int) -> int:
    """Encode a database id and a nesting depth as a single public thread id.

    The depth becomes the last decimal digit, e.g. ``pack_thread_id(12, 1) == 121``.
    """
    shifted_id = database_id * 10
    return shifted_id + depth

31 

32 

def unpack_thread_id(thread_id: int) -> tuple[int, int]:
    """Split a public thread id into its ``(database_id, depth)`` parts.

    Inverse of ``pack_thread_id``: the last decimal digit is the depth and the
    remaining digits form the database id.
    """
    database_id = thread_id // 10
    depth = thread_id % 10
    return database_id, depth

36 

37 

def total_num_responses(session: Session, database_id: int) -> int:
    """Count all responses on a thread: its top-level comments plus every reply to them.

    Args:
        session: open database session.
        database_id: database id of the ``Thread`` (not the packed API id).

    Returns:
        Number of ``Comment`` rows on the thread plus number of ``Reply`` rows whose parent
        comment is on the thread.
    """
    comment_count_query = select(func.count()).select_from(Comment).where(Comment.thread_id == database_id)
    # Replies only reference their comment, so join back to Comment to filter by thread.
    reply_count_query = (
        select(func.count())
        .select_from(Reply)
        .join(Comment, Comment.id == Reply.comment_id)
        .where(Comment.thread_id == database_id)
    )
    comment_count = session.execute(comment_count_query).scalar_one()
    reply_count = session.execute(reply_count_query).scalar_one()
    return comment_count + reply_count

50 

51 

def thread_to_pb(session: Session, database_id: int) -> threads_pb2.Thread:
    """Build the API ``Thread`` message for the thread stored under ``database_id``.

    The public id is the database id packed at depth 0; ``num_responses`` counts top-level
    comments and replies together (via ``total_num_responses``).
    """
    api_thread_id = pack_thread_id(database_id, 0)
    response_count = total_num_responses(session, database_id)
    return threads_pb2.Thread(thread_id=api_thread_id, num_responses=response_count)

57 

58 

def _reply_pb_for_notification(thread_id: int, source: "Comment | Reply") -> threads_pb2.Reply:
    """Build the ``threads_pb2.Reply`` embedded in comment/reply notifications for ``source``."""
    return threads_pb2.Reply(
        thread_id=thread_id,
        content=source.content,
        author_user_id=source.author_user_id,
        created_time=Timestamp_from_datetime(source.created),
        num_replies=0,
    )


def _latest_occurrence(event: Event) -> EventOccurrence:
    """Return the occurrence of ``event`` with the highest id."""
    return event.occurrences.order_by(EventOccurrence.id.desc()).limit(1).one()


def generate_reply_notifications(payload: jobs_pb2.GenerateReplyNotificationsPayload) -> None:
    """Background job: send notifications for a newly posted comment or reply.

    ``payload.thread_id`` is the packed id of the new item.

    Depth 1 (a top-level Comment on a Thread):
      * event thread: each event subscriber and each attendee of the latest occurrence gets
        ``event__comment``;
      * discussion thread: the discussion creator gets ``discussion__comment``.
      Recipients for whom ``is_not_visible`` holds relative to the commenter are skipped, as
      is the commenter.

    Depth 2 (a Reply to a Comment):
      the other repliers on the same comment (filtered by ``where_users_column_visible`` from
      the new replier's context) plus the parent comment's author (unless ``is_not_visible``)
      get ``thread__reply``; the replier is never notified.

    Raises:
        NotImplementedError: the thread is attached to neither an event nor a discussion, or a
            discussion's owner cluster is not an official cluster.
        Exception: the packed id has a depth other than 1 or 2.
    """
    # Imported lazily to avoid a circular import between servicer modules.
    from couchers.servicers.discussions import discussion_to_pb  # noqa: PLC0415
    from couchers.servicers.events import event_to_pb  # noqa: PLC0415

    with session_scope() as session:
        database_id, depth = unpack_thread_id(payload.thread_id)

        if depth == 1:
            comment = session.execute(select(Comment).where(Comment.id == database_id)).scalar_one()
            thread = session.execute(select(Thread).where(Thread.id == comment.thread_id)).scalar_one()
            commenter_id = comment.author_user_id
            commenter = session.execute(select(User).where(User.id == commenter_id)).scalar_one()
            reply_pb = _reply_pb_for_notification(payload.thread_id, comment)

            # Find out what the thread hangs off: an event or a discussion.
            parent_event = session.execute(select(Event).where(Event.thread_id == thread.id)).scalar_one_or_none()
            parent_discussion = session.execute(
                select(Discussion).where(Discussion.thread_id == thread.id)
            ).scalar_one_or_none()

            if parent_event:
                occurrence = _latest_occurrence(parent_event)
                subscriber_ids = [subscriber.id for subscriber in parent_event.subscribers]
                attendee_ids = [attendance.user_id for attendance in occurrence.attendances]

                for recipient_id in set(subscriber_ids + attendee_ids):
                    # Visibility is checked first, then the commenter is excluded.
                    if is_not_visible(session, recipient_id, commenter_id) or recipient_id == commenter_id:
                        continue
                    recipient_context = make_background_user_context(user_id=recipient_id)
                    notify(
                        session,
                        user_id=recipient_id,
                        topic_action=NotificationTopicAction.event__comment,
                        key=str(occurrence.id),
                        data=notification_data_pb2.EventComment(
                            reply=reply_pb,
                            event=event_to_pb(session, occurrence, recipient_context),
                            author=user_model_to_pb(commenter, session, recipient_context),
                        ),
                        moderation_state_id=occurrence.moderation_state_id,
                    )
            elif parent_discussion:
                owner_cluster = parent_discussion.owner_cluster
                if not owner_cluster.is_official_cluster:
                    raise NotImplementedError("Shouldn't have discussions under groups, only communities")

                creator_id = parent_discussion.creator_user_id
                if not is_not_visible(session, creator_id, commenter_id) and creator_id != commenter_id:
                    creator_context = make_background_user_context(user_id=creator_id)
                    notify(
                        session,
                        user_id=creator_id,
                        topic_action=NotificationTopicAction.discussion__comment,
                        key=str(parent_discussion.id),
                        data=notification_data_pb2.DiscussionComment(
                            reply=reply_pb,
                            discussion=discussion_to_pb(session, parent_discussion, creator_context),
                            author=user_model_to_pb(commenter, session, creator_context),
                        ),
                    )
            else:
                raise NotImplementedError("I can only do event and discussion threads for now")

        elif depth == 2:
            db_reply = session.execute(select(Reply).where(Reply.id == database_id)).scalar_one()
            parent_comment = session.execute(select(Comment).where(Comment.id == db_reply.comment_id)).scalar_one()
            replier_id = db_reply.author_user_id
            replier_context = make_background_user_context(user_id=replier_id)

            # Authors of replies on the same comment, restricted to users visible to the replier.
            visible_replier_ids = (
                session.execute(
                    where_users_column_visible(
                        select(Reply.author_user_id).where(Reply.comment_id == parent_comment.id),
                        replier_context,
                        Reply.author_user_id,
                    )
                )
                .scalars()
                .all()
            )
            participant_ids = set(visible_replier_ids)
            if not is_not_visible(session, parent_comment.author_user_id, replier_id):
                participant_ids.add(parent_comment.author_user_id)

            replier = session.execute(select(User).where(User.id == replier_id)).scalar_one()
            recipient_ids = set(participant_ids) - {replier_id}
            reply_pb = _reply_pb_for_notification(payload.thread_id, db_reply)

            parent_event = session.execute(
                select(Event).where(Event.thread_id == parent_comment.thread_id)
            ).scalar_one_or_none()
            parent_discussion = session.execute(
                select(Discussion).where(Discussion.thread_id == parent_comment.thread_id)
            ).scalar_one_or_none()

            if parent_event:
                occurrence = _latest_occurrence(parent_event)
                for recipient_id in recipient_ids:
                    recipient_context = make_background_user_context(user_id=recipient_id)
                    notify(
                        session,
                        user_id=recipient_id,
                        topic_action=NotificationTopicAction.thread__reply,
                        key=str(occurrence.id),
                        data=notification_data_pb2.ThreadReply(
                            reply=reply_pb,
                            event=event_to_pb(session, occurrence, recipient_context),
                            author=user_model_to_pb(replier, session, recipient_context),
                        ),
                        moderation_state_id=occurrence.moderation_state_id,
                    )
            elif parent_discussion:
                for recipient_id in recipient_ids:
                    recipient_context = make_background_user_context(user_id=recipient_id)
                    notify(
                        session,
                        user_id=recipient_id,
                        topic_action=NotificationTopicAction.thread__reply,
                        key=str(parent_discussion.id),
                        data=notification_data_pb2.ThreadReply(
                            reply=reply_pb,
                            discussion=discussion_to_pb(session, parent_discussion, recipient_context),
                            author=user_model_to_pb(replier, session, recipient_context),
                        ),
                    )
            else:
                raise NotImplementedError("I can only do event and discussion threads for now")

        else:
            raise Exception("Unknown depth")

210 

211 

class Threads(threads_pb2_grpc.ThreadsServicer):
    """gRPC servicer for reading and posting to comment threads.

    Every id on this API is packed with ``pack_thread_id``: depth 0 addresses a ``Thread``,
    depth 1 a top-level ``Comment`` on a thread, and depth 2 a ``Reply`` to a comment.
    """

    def GetThread(
        self, request: threads_pb2.GetThreadReq, context: CouchersContext, session: Session
    ) -> threads_pb2.GetThreadRes:
        """Return one page of the direct children of ``request.thread_id``.

        * depth 0 (thread): its top-level comments, each with its reply count;
        * depth 1 (comment): the replies to that comment (``num_replies`` is always 0).
        Any other depth, or a thread/comment that does not exist, aborts with NOT_FOUND
        ``thread_not_found``.

        Items are returned in descending database-id order. Pagination is keyset-based:
        ``next_page_token`` is the packed id of the last item returned (empty when there are no
        more), and the next page contains only items with a smaller database id.
        """
        database_id, depth = unpack_thread_id(request.thread_id)
        # Unset (0) or out-of-range page sizes fall back to the default of 1000.
        page_size = request.page_size if 0 < request.page_size < 100000 else 1000
        # The token is the packed id of the last item already returned; with no token, start
        # above any realistic database id.
        page_start = unpack_thread_id(int(request.page_token))[0] if request.page_token else 2**50

        # NOTE: both queries order by the same column the cursor filters on (`id`). Ordering by
        # `created` while filtering on `id` can skip or repeat items across pages whenever the two
        # orders disagree (equal `created` values, or ids allocated out of `created` order).
        if depth == 0:
            if not session.execute(select(Thread).where(Thread.id == database_id)).scalar_one_or_none():
                context.abort_with_error_code(grpc.StatusCode.NOT_FOUND, "thread_not_found")

            res = session.execute(
                select(Comment, func.count(Reply.id))
                .outerjoin(Reply, Reply.comment_id == Comment.id)
                .where(Comment.thread_id == database_id)
                .where(Comment.id < page_start)
                .group_by(Comment.id)
                .order_by(Comment.id.desc())
                # One extra row tells us whether another page exists.
                .limit(page_size + 1)
            ).all()
            replies = [
                threads_pb2.Reply(
                    thread_id=pack_thread_id(r.id, 1),
                    content=r.content,
                    author_user_id=r.author_user_id,
                    created_time=Timestamp_from_datetime(r.created),
                    num_replies=n,
                )
                for r, n in res[:page_size]
            ]

        elif depth == 1:
            if not session.execute(select(Comment).where(Comment.id == database_id)).scalar_one_or_none():
                context.abort_with_error_code(grpc.StatusCode.NOT_FOUND, "thread_not_found")

            res = (
                session.execute(  # type: ignore[assignment]
                    select(Reply)
                    .where(Reply.comment_id == database_id)
                    .where(Reply.id < page_start)
                    .order_by(Reply.id.desc())
                    # One extra row tells us whether another page exists.
                    .limit(page_size + 1)
                )
                .scalars()
                .all()
            )
            replies = [
                threads_pb2.Reply(
                    thread_id=pack_thread_id(r.id, 2),
                    content=r.content,
                    author_user_id=r.author_user_id,
                    created_time=Timestamp_from_datetime(r.created),
                    num_replies=0,
                )
                for r in res[:page_size]
            ]

        else:
            context.abort_with_error_code(grpc.StatusCode.NOT_FOUND, "thread_not_found")

        if len(res) > page_size:
            # There's more: resume after the last (lowest-id) item on this page.
            next_page_token = str(replies[-1].thread_id)
        else:
            next_page_token = ""

        return threads_pb2.GetThreadRes(replies=replies, next_page_token=next_page_token)

    def PostReply(
        self, request: threads_pb2.PostReplyReq, context: CouchersContext, session: Session
    ) -> threads_pb2.PostReplyRes:
        """Post a comment on a thread (depth-0 id) or a reply to a comment (depth-1 id).

        Content is whitespace-stripped; blank content aborts with INVALID_ARGUMENT
        ``invalid_comment``. Any other depth aborts with NOT_FOUND ``thread_not_found``, as does an
        ``IntegrityError`` on insert (presumably a foreign-key violation because the parent
        thread/comment does not exist — TODO confirm).

        On success, a ``generate_reply_notifications`` job is queued for the new item and its
        packed id (one level deeper than the request's) is returned.
        """
        content = request.content.strip()

        if content == "":
            context.abort_with_error_code(grpc.StatusCode.INVALID_ARGUMENT, "invalid_comment")

        database_id, depth = unpack_thread_id(request.thread_id)
        if depth == 0:
            object_to_add: Comment | Reply = Comment(
                thread_id=database_id, author_user_id=context.user_id, content=content
            )
        elif depth == 1:
            object_to_add = Reply(comment_id=database_id, author_user_id=context.user_id, content=content)
        else:
            context.abort_with_error_code(grpc.StatusCode.NOT_FOUND, "thread_not_found")
        session.add(object_to_add)
        try:
            # Flush now so the new row gets its id and constraint violations surface here.
            session.flush()
        except sqlalchemy.exc.IntegrityError:
            context.abort_with_error_code(grpc.StatusCode.NOT_FOUND, "thread_not_found")

        thread_id = pack_thread_id(object_to_add.id, depth + 1)

        queue_job(
            session,
            job=generate_reply_notifications,
            payload=jobs_pb2.GenerateReplyNotificationsPayload(
                thread_id=thread_id,
            ),
        )

        return threads_pb2.PostReplyRes(thread_id=thread_id)