Coverage for src/couchers/servicers/threads.py: 93%

121 statements  

coverage.py v7.6.10, created at 2025-07-03 14:11 +0000

import logging

import grpc
import sqlalchemy.exc
from sqlalchemy.sql import func, select

from couchers import errors
from couchers.db import session_scope
from couchers.jobs.enqueue import queue_job
from couchers.models import Comment, Discussion, Event, EventOccurrence, Reply, Thread, User
from couchers.notifications.notify import notify
from couchers.servicers.api import user_model_to_pb
from couchers.servicers.blocking import is_not_visible
from couchers.sql import couchers_select as select
from couchers.utils import Timestamp_from_datetime, make_user_context
from proto import notification_data_pb2, threads_pb2, threads_pb2_grpc
from proto.internal import jobs_pb2

logger = logging.getLogger(__name__)

# Since the API exposes a single ID space regardless of nesting level,
# we construct the API id by appending the nesting level to the
# database ID.


def pack_thread_id(database_id: int, depth: int) -> int:
    return database_id * 10 + depth


def unpack_thread_id(thread_id: int) -> (int, int):
    """Returns (database_id, depth) tuple."""
    return divmod(thread_id, 10)
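# For example, pack_thread_id(123, 1) == 1231 and unpack_thread_id(1231) == (123, 1):
# the last decimal digit is the nesting depth, where 0 refers to the Thread itself,
# 1 to a Comment on it, and 2 to a Reply to a comment.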

def total_num_responses(session, database_id):
    """Return the total number of comments and replies to the thread with
    database id database_id.
    """
    return (
        session.execute(select(func.count()).select_from(Comment).where(Comment.thread_id == database_id)).scalar_one()
        + session.execute(
            select(func.count())
            .select_from(Reply)
            .join(Comment, Comment.id == Reply.comment_id)
            .where(Comment.thread_id == database_id)
        ).scalar_one()
    )

def thread_to_pb(session, database_id):
    return threads_pb2.Thread(
        thread_id=pack_thread_id(database_id, 0),
        num_responses=total_num_responses(session, database_id),
    )
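# The function below consumes the "generate_reply_notifications" job queued by
# PostReply at the bottom of this module; the payload carries the packed thread_id
# of the newly created comment or reply.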

def generate_reply_notifications(payload: jobs_pb2.GenerateReplyNotificationsPayload):
    from couchers.servicers.discussions import discussion_to_pb
    from couchers.servicers.events import event_to_pb

    with session_scope() as session:
        database_id, depth = unpack_thread_id(payload.thread_id)
        if depth == 1:
            # this is a top-level Comment on a Thread attached to event, discussion, etc
            comment = session.execute(select(Comment).where(Comment.id == database_id)).scalar_one()
            thread = session.execute(select(Thread).where(Thread.id == comment.thread_id)).scalar_one()
            author_user = session.execute(select(User).where(User.id == comment.author_user_id)).scalar_one()
            # reply object for notif
            reply = threads_pb2.Reply(
                thread_id=payload.thread_id,
                content=comment.content,
                author_user_id=comment.author_user_id,
                created_time=Timestamp_from_datetime(comment.created),
                num_replies=0,
            )
            # figure out if the thread is related to an event or discussion
            event = session.execute(select(Event).where(Event.thread_id == thread.id)).scalar_one_or_none()
            discussion = session.execute(
                select(Discussion).where(Discussion.thread_id == thread.id)
            ).scalar_one_or_none()
            if event:
                # thread is an event thread
                occurrence = event.occurrences.order_by(EventOccurrence.id.desc()).limit(1).one()
                subscribed_user_ids = [user.id for user in event.subscribers]
                attending_user_ids = [user.user_id for user in occurrence.attendances]

                for user_id in set(subscribed_user_ids + attending_user_ids):
                    if is_not_visible(session, user_id, comment.author_user_id):
                        continue
                    if user_id == comment.author_user_id:
                        continue
                    context = make_user_context(user_id=user_id)
                    notify(
                        session,
                        user_id=user_id,
                        topic_action="event:comment",
                        key=occurrence.id,
                        data=notification_data_pb2.EventComment(
                            reply=reply,
                            event=event_to_pb(session, occurrence, context),
                            author=user_model_to_pb(author_user, session, context),
                        ),
                    )
            elif discussion:
                # community discussion thread
                cluster = discussion.owner_cluster

                if not cluster.is_official_cluster:
                    raise NotImplementedError("Shouldn't have discussions under groups, only communities")

                for user_id in [discussion.creator_user_id]:
                    if is_not_visible(session, user_id, comment.author_user_id):
                        continue
                    if user_id == comment.author_user_id:
                        continue

                    context = make_user_context(user_id=user_id)
                    notify(
                        session,
                        user_id=user_id,
                        topic_action="discussion:comment",
                        key=discussion.id,
                        data=notification_data_pb2.DiscussionComment(
                            reply=reply,
                            discussion=discussion_to_pb(session, discussion, context),
                            author=user_model_to_pb(author_user, session, context),
                        ),
                    )
            else:
                raise NotImplementedError("I can only do event and discussion threads for now")
        elif depth == 2:
            # this is a second-level reply to a comment
            reply = session.execute(select(Reply).where(Reply.id == database_id)).scalar_one()
            # the comment we're replying to
            parent_comment = session.execute(select(Comment).where(Comment.id == reply.comment_id)).scalar_one()
            context = make_user_context(user_id=reply.author_user_id)
            thread_replies_author_user_ids = (
                session.execute(
                    select(Reply.author_user_id)
                    .where_users_column_visible(context, Reply.author_user_id)
                    .where(Reply.comment_id == parent_comment.id)
                )
                .scalars()
                .all()
            )
            thread_user_ids = set(thread_replies_author_user_ids)
            if not is_not_visible(session, parent_comment.author_user_id, reply.author_user_id):
                thread_user_ids.add(parent_comment.author_user_id)

            author_user = session.execute(select(User).where(User.id == reply.author_user_id)).scalar_one()

            user_ids_to_notify = set(thread_user_ids) - {reply.author_user_id}

            reply = threads_pb2.Reply(
                thread_id=payload.thread_id,
                content=reply.content,
                author_user_id=reply.author_user_id,
                created_time=Timestamp_from_datetime(reply.created),
                num_replies=0,
            )

            event = session.execute(
                select(Event).where(Event.thread_id == parent_comment.thread_id)
            ).scalar_one_or_none()
            discussion = session.execute(
                select(Discussion).where(Discussion.thread_id == parent_comment.thread_id)
            ).scalar_one_or_none()
            if event:
                # thread is an event thread
                occurrence = event.occurrences.order_by(EventOccurrence.id.desc()).limit(1).one()
                for user_id in user_ids_to_notify:
                    context = make_user_context(user_id=user_id)
                    notify(
                        session,
                        user_id=user_id,
                        topic_action="thread:reply",
                        key=occurrence.id,
                        data=notification_data_pb2.ThreadReply(
                            reply=reply,
                            event=event_to_pb(session, occurrence, context),
                            author=user_model_to_pb(author_user, session, context),
                        ),
                    )
            elif discussion:
                # community discussion thread
                for user_id in user_ids_to_notify:
                    context = make_user_context(user_id=user_id)
                    notify(
                        session,
                        user_id=user_id,
                        topic_action="thread:reply",
                        key=discussion.id,
                        data=notification_data_pb2.ThreadReply(
                            reply=reply,
                            discussion=discussion_to_pb(session, discussion, context),
                            author=user_model_to_pb(author_user, session, context),
                        ),
                    )
            else:
                raise NotImplementedError("I can only do event and discussion threads for now")
        else:
            raise Exception("Unknown depth")

class Threads(threads_pb2_grpc.ThreadsServicer):
    def GetThread(self, request, context, session):
        database_id, depth = unpack_thread_id(request.thread_id)
        page_size = request.page_size if 0 < request.page_size < 100000 else 1000
        page_start = unpack_thread_id(int(request.page_token))[0] if request.page_token else 2**50

        if depth == 0:
            if not session.execute(select(Thread).where(Thread.id == database_id)).scalar_one_or_none():
                context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)

            res = session.execute(
                select(Comment, func.count(Reply.id))
                .outerjoin(Reply, Reply.comment_id == Comment.id)
                .where(Comment.thread_id == database_id)
                .where(Comment.id < page_start)
                .group_by(Comment.id)
                .order_by(Comment.created.desc())
                .limit(page_size + 1)
            ).all()
            replies = [
                threads_pb2.Reply(
                    thread_id=pack_thread_id(r.id, 1),
                    content=r.content,
                    author_user_id=r.author_user_id,
                    created_time=Timestamp_from_datetime(r.created),
                    num_replies=n,
                )
                for r, n in res[:page_size]
            ]

        elif depth == 1:
            if not session.execute(select(Comment).where(Comment.id == database_id)).scalar_one_or_none():
                context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)

            res = (
                session.execute(
                    select(Reply)
                    .where(Reply.comment_id == database_id)
                    .where(Reply.id < page_start)
                    .order_by(Reply.created.desc())
                    .limit(page_size + 1)
                )
                .scalars()
                .all()
            )
            replies = [
                threads_pb2.Reply(
                    thread_id=pack_thread_id(r.id, 2),
                    content=r.content,
                    author_user_id=r.author_user_id,
                    created_time=Timestamp_from_datetime(r.created),
                    num_replies=0,
                )
                for r in res[:page_size]
            ]

        else:
            context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)

        if len(res) > page_size:
            # There's more!
            next_page_token = str(replies[-1].thread_id)
        else:
            next_page_token = ""

        return threads_pb2.GetThreadRes(replies=replies, next_page_token=next_page_token)
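    # Pagination note for GetThread above: results are returned newest-first,
    # next_page_token is the packed thread_id of the last reply on the page, and the
    # next page filters on ids strictly below its database id (page_start defaults to
    # 2**50, i.e. effectively no upper bound on the first page).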

    def PostReply(self, request, context, session):
        content = request.content.strip()

        if content == "":
            context.abort(grpc.StatusCode.INVALID_ARGUMENT, errors.INVALID_COMMENT)

        database_id, depth = unpack_thread_id(request.thread_id)
        if depth == 0:
            object_to_add = Comment(thread_id=database_id, author_user_id=context.user_id, content=content)
        elif depth == 1:
            object_to_add = Reply(comment_id=database_id, author_user_id=context.user_id, content=content)
        else:
            context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)
        session.add(object_to_add)
        try:
            session.flush()
        except sqlalchemy.exc.IntegrityError:
            context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)

        thread_id = pack_thread_id(object_to_add.id, depth + 1)

        queue_job(
            session,
            job_type="generate_reply_notifications",
            payload=jobs_pb2.GenerateReplyNotificationsPayload(
                thread_id=thread_id,
            ),
        )

        return threads_pb2.PostReplyRes(thread_id=thread_id)
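# Illustrative client-side sketch (assumptions: the request messages are named
# GetThreadReq and PostReplyReq, and `channel` is an authenticated gRPC channel;
# neither is defined in this file): post a comment on a thread, then page through
# the thread's comments newest-first using the token scheme described above.
#
#     stub = threads_pb2_grpc.ThreadsStub(channel)
#     stub.PostReply(threads_pb2.PostReplyReq(thread_id=1230, content="Hello!"))
#     page_token = ""
#     while True:
#         res = stub.GetThread(threads_pb2.GetThreadReq(thread_id=1230, page_token=page_token))
#         for reply in res.replies:
#             print(reply.author_user_id, reply.content)
#         if not res.next_page_token:
#             break
#         page_token = res.next_page_token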