Coverage for app / backend / src / couchers / servicers / threads.py: 89%
126 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 09:44 +0000
1import logging
3import grpc
4import sqlalchemy.exc
5from sqlalchemy import select
6from sqlalchemy.orm import Session
7from sqlalchemy.sql import func
9from couchers.context import CouchersContext, make_background_user_context
10from couchers.db import session_scope
11from couchers.jobs.enqueue import queue_job
12from couchers.models import Comment, Discussion, Event, EventOccurrence, Reply, Thread, User
13from couchers.models.notifications import NotificationTopicAction
14from couchers.notifications.notify import notify
15from couchers.proto import notification_data_pb2, threads_pb2, threads_pb2_grpc
16from couchers.proto.internal import jobs_pb2
17from couchers.servicers.api import user_model_to_pb
18from couchers.servicers.blocking import is_not_visible
19from couchers.sql import where_users_column_visible
20from couchers.utils import Timestamp_from_datetime
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# The public API exposes one flat id space for threads, comments and replies,
# whatever their nesting level. To make that work, an API id is the database
# id with the nesting depth appended as its final decimal digit.


def pack_thread_id(database_id: int, depth: int) -> int:
    """Encode a database id and a nesting depth into one API thread id.

    The depth occupies the last decimal digit, so ``pack_thread_id(42, 1)``
    is ``421``. ``depth`` is expected to be in ``0..9``; a larger value would
    carry into the database-id part of the result.
    """
    shifted_id = 10 * database_id
    return shifted_id + depth
def unpack_thread_id(thread_id: int) -> tuple[int, int]:
    """Returns (database_id, depth) tuple.

    Inverse of ``pack_thread_id``: the last decimal digit is the nesting
    depth and the remaining prefix is the database id. Floor division and
    modulo are used, so the depth is always in ``0..9``.
    """
    depth = thread_id % 10
    database_id = thread_id // 10
    return database_id, depth
def total_num_responses(session: Session, database_id: int) -> int:
    """Count every response posted in a thread.

    A response is either a top-level ``Comment`` on the thread whose database
    id is ``database_id``, or a ``Reply`` to one of those comments. Each kind
    is counted with its own ``COUNT(*)`` query (comments first), and the two
    counts are summed.
    """
    count_comments = select(func.count()).select_from(Comment).where(Comment.thread_id == database_id)
    count_replies = (
        select(func.count())
        .select_from(Reply)
        .join(Comment, Comment.id == Reply.comment_id)
        .where(Comment.thread_id == database_id)
    )
    return sum(session.execute(query).scalar_one() for query in (count_comments, count_replies))
def thread_to_pb(session: Session, database_id: int) -> threads_pb2.Thread:
    """Build the ``Thread`` protobuf for the thread with the given database id.

    ``thread_id`` is the packed API id at depth 0. ``num_responses`` counts
    the thread's comments plus all replies to those comments.
    """
    api_thread_id = pack_thread_id(database_id, 0)
    response_count = total_num_responses(session, database_id)
    return threads_pb2.Thread(thread_id=api_thread_id, num_responses=response_count)
def generate_reply_notifications(payload: jobs_pb2.GenerateReplyNotificationsPayload) -> None:
    """Background job: send notifications about a newly posted comment or reply.

    ``payload.thread_id`` is the packed API id of the new item (see
    ``unpack_thread_id``):

    * depth 1 -- a top-level ``Comment`` on a thread. For an event thread,
      the event's subscribers and the attendees of its occurrence with the
      highest id are notified (``event__comment``). For a discussion thread,
      the discussion's creator is notified (``discussion__comment``).
    * depth 2 -- a ``Reply`` to a comment. The parent comment's author and
      the authors of replies to that comment are notified (``thread__reply``).

    The author of the new item is never notified. Other users are filtered
    by visibility relative to that author, via ``is_not_visible`` or
    ``where_users_column_visible``.

    Raises:
        NotImplementedError: if the thread is attached to neither an event nor
            a discussion, or if a depth-1 comment's discussion is owned by a
            cluster that is not an official cluster.
        Exception: if the packed id has a depth other than 1 or 2.
    """
    # Import here to avoid circular dependency
    from couchers.servicers.discussions import discussion_to_pb  # noqa: PLC0415
    from couchers.servicers.events import event_to_pb  # noqa: PLC0415

    with session_scope() as session:
        database_id, depth = unpack_thread_id(payload.thread_id)
        if depth == 1:
            # this is a top-level Comment on a Thread attached to event, discussion, etc
            comment = session.execute(select(Comment).where(Comment.id == database_id)).scalar_one()
            thread = session.execute(select(Thread).where(Thread.id == comment.thread_id)).scalar_one()
            author_user = session.execute(select(User).where(User.id == comment.author_user_id)).scalar_one()
            # reply object for notif
            reply = threads_pb2.Reply(
                thread_id=payload.thread_id,
                content=comment.content,
                author_user_id=comment.author_user_id,
                created_time=Timestamp_from_datetime(comment.created),
                num_replies=0,
            )
            # figure out if the thread is related to an event or discussion
            event = session.execute(select(Event).where(Event.thread_id == thread.id)).scalar_one_or_none()
            discussion = session.execute(
                select(Discussion).where(Discussion.thread_id == thread.id)
            ).scalar_one_or_none()
            if event:
                # thread is an event thread
                # the notification is keyed to the occurrence with the highest id
                occurrence = event.occurrences.order_by(EventOccurrence.id.desc()).limit(1).one()
                subscribed_user_ids = [user.id for user in event.subscribers]
                attending_user_ids = [user.user_id for user in occurrence.attendances]

                # deduplicate: a user who both subscribes and attends is notified once
                for user_id in set(subscribed_user_ids + attending_user_ids):
                    if is_not_visible(session, user_id, comment.author_user_id):
                        continue
                    # don't notify the commenter about their own comment
                    if user_id == comment.author_user_id:
                        continue
                    # build the embedded protos with a context for the recipient
                    context = make_background_user_context(user_id=user_id)
                    notify(
                        session,
                        user_id=user_id,
                        topic_action=NotificationTopicAction.event__comment,
                        key=str(occurrence.id),
                        data=notification_data_pb2.EventComment(
                            reply=reply,
                            event=event_to_pb(session, occurrence, context),
                            author=user_model_to_pb(author_user, session, context),
                        ),
                        moderation_state_id=occurrence.moderation_state_id,
                    )
            elif discussion:
                # community discussion thread
                cluster = discussion.owner_cluster

                if not cluster.is_official_cluster:
                    raise NotImplementedError("Shouldn't have discussions under groups, only communities")

                # only the discussion's creator is notified about top-level comments
                for user_id in [discussion.creator_user_id]:
                    if is_not_visible(session, user_id, comment.author_user_id):
                        continue
                    # don't notify the commenter about their own comment
                    if user_id == comment.author_user_id:
                        continue

                    # build the embedded protos with a context for the recipient
                    context = make_background_user_context(user_id=user_id)
                    notify(
                        session,
                        user_id=user_id,
                        topic_action=NotificationTopicAction.discussion__comment,
                        key=str(discussion.id),
                        data=notification_data_pb2.DiscussionComment(
                            reply=reply,
                            discussion=discussion_to_pb(session, discussion, context),
                            author=user_model_to_pb(author_user, session, context),
                        ),
                    )
            else:
                raise NotImplementedError("I can only do event and discussion threads for now")
        elif depth == 2:
            # this is a second-level reply to a comment
            db_reply = session.execute(select(Reply).where(Reply.id == database_id)).scalar_one()
            # the comment we're replying to
            parent_comment = session.execute(select(Comment).where(Comment.id == db_reply.comment_id)).scalar_one()
            # context of the reply's author; used here to filter the other repliers by visibility
            context = make_background_user_context(user_id=db_reply.author_user_id)
            # authors of replies to the parent comment, restricted by
            # where_users_column_visible relative to the reply author's context
            thread_replies_author_user_ids = (
                session.execute(
                    where_users_column_visible(
                        select(Reply.author_user_id).where(Reply.comment_id == parent_comment.id),
                        context,
                        Reply.author_user_id,
                    )
                )
                .scalars()
                .all()
            )
            thread_user_ids = set(thread_replies_author_user_ids)
            # also include the parent comment's author, unless hidden relative to the reply author
            if not is_not_visible(session, parent_comment.author_user_id, db_reply.author_user_id):
                thread_user_ids.add(parent_comment.author_user_id)

            author_user = session.execute(select(User).where(User.id == db_reply.author_user_id)).scalar_one()

            # never notify the reply's author about their own reply
            user_ids_to_notify = set(thread_user_ids) - {db_reply.author_user_id}

            reply = threads_pb2.Reply(
                thread_id=payload.thread_id,
                content=db_reply.content,
                author_user_id=db_reply.author_user_id,
                created_time=Timestamp_from_datetime(db_reply.created),
                num_replies=0,
            )

            # the parent comment's thread is attached to an event or a discussion
            event = session.execute(
                select(Event).where(Event.thread_id == parent_comment.thread_id)
            ).scalar_one_or_none()
            discussion = session.execute(
                select(Discussion).where(Discussion.thread_id == parent_comment.thread_id)
            ).scalar_one_or_none()
            if event:
                # thread is an event thread
                # the notification is keyed to the occurrence with the highest id
                occurrence = event.occurrences.order_by(EventOccurrence.id.desc()).limit(1).one()
                for user_id in user_ids_to_notify:
                    # build the embedded protos with a context for the recipient
                    context = make_background_user_context(user_id=user_id)
                    notify(
                        session,
                        user_id=user_id,
                        topic_action=NotificationTopicAction.thread__reply,
                        key=str(occurrence.id),
                        data=notification_data_pb2.ThreadReply(
                            reply=reply,
                            event=event_to_pb(session, occurrence, context),
                            author=user_model_to_pb(author_user, session, context),
                        ),
                        moderation_state_id=occurrence.moderation_state_id,
                    )
            elif discussion:
                # community discussion thread
                for user_id in user_ids_to_notify:
                    # build the embedded protos with a context for the recipient
                    context = make_background_user_context(user_id=user_id)
                    notify(
                        session,
                        user_id=user_id,
                        topic_action=NotificationTopicAction.thread__reply,
                        key=str(discussion.id),
                        data=notification_data_pb2.ThreadReply(
                            reply=reply,
                            discussion=discussion_to_pb(session, discussion, context),
                            author=user_model_to_pb(author_user, session, context),
                        ),
                    )
            else:
                raise NotImplementedError("I can only do event and discussion threads for now")
        else:
            raise Exception("Unknown depth")
class Threads(threads_pb2_grpc.ThreadsServicer):
    """gRPC servicer for reading and posting to comment threads.

    Ids on the wire are packed API ids (see ``pack_thread_id``). Depth 0 is a
    ``Thread``, depth 1 is a top-level ``Comment`` and depth 2 is a ``Reply``.
    """

    def GetThread(
        self, request: threads_pb2.GetThreadReq, context: CouchersContext, session: Session
    ) -> threads_pb2.GetThreadRes:
        """Return one page of the direct children of a thread or a comment.

        For a depth-0 id the children are the thread's comments, each with
        its reply count. For a depth-1 id they are the comment's replies.
        Results are ordered newest first (by ``created``, ties broken by id).

        A ``page_size`` outside ``1..99999`` falls back to 1000.
        ``page_token`` is the packed id of the last item on the previous page.
        Only items with a smaller database id than that item are returned.

        Aborts with NOT_FOUND if the thread/comment does not exist, or if the
        id has any other depth.
        """
        database_id, depth = unpack_thread_id(request.thread_id)
        page_size = request.page_size if 0 < request.page_size < 100000 else 1000
        # The page token is the packed id of the last item already returned.
        # 2**50 serves as an effectively unbounded cursor for the first page.
        page_start = unpack_thread_id(int(request.page_token))[0] if request.page_token else 2**50

        if depth == 0:
            if not session.execute(select(Thread).where(Thread.id == database_id)).scalar_one_or_none():
                context.abort_with_error_code(grpc.StatusCode.NOT_FOUND, "thread_not_found")

            res = session.execute(
                select(Comment, func.count(Reply.id))
                .outerjoin(Reply, Reply.comment_id == Comment.id)
                .where(Comment.thread_id == database_id)
                .where(Comment.id < page_start)
                .group_by(Comment.id)
                # The cursor filters by id, so ties in `created` must be broken
                # by id. Otherwise rows sharing a timestamp come back in
                # arbitrary order and one can be skipped across a page boundary.
                .order_by(Comment.created.desc(), Comment.id.desc())
                # fetch one extra row to detect whether another page exists
                .limit(page_size + 1)
            ).all()
            replies = [
                threads_pb2.Reply(
                    thread_id=pack_thread_id(r.id, 1),
                    content=r.content,
                    author_user_id=r.author_user_id,
                    created_time=Timestamp_from_datetime(r.created),
                    num_replies=n,
                )
                for r, n in res[:page_size]
            ]

        elif depth == 1:
            if not session.execute(select(Comment).where(Comment.id == database_id)).scalar_one_or_none():
                context.abort_with_error_code(grpc.StatusCode.NOT_FOUND, "thread_not_found")

            res = (
                session.execute(  # type: ignore[assignment]
                    select(Reply)
                    .where(Reply.comment_id == database_id)
                    .where(Reply.id < page_start)
                    # tie-break by id to stay consistent with the id-based cursor
                    .order_by(Reply.created.desc(), Reply.id.desc())
                    # fetch one extra row to detect whether another page exists
                    .limit(page_size + 1)
                )
                .scalars()
                .all()
            )
            replies = [
                threads_pb2.Reply(
                    thread_id=pack_thread_id(r.id, 2),
                    content=r.content,
                    author_user_id=r.author_user_id,
                    created_time=Timestamp_from_datetime(r.created),
                    num_replies=0,
                )
                for r in res[:page_size]
            ]

        else:
            context.abort_with_error_code(grpc.StatusCode.NOT_FOUND, "thread_not_found")

        if len(res) > page_size:
            # There's more! The last returned item's packed id is the cursor.
            next_page_token = str(replies[-1].thread_id)
        else:
            next_page_token = ""

        return threads_pb2.GetThreadRes(replies=replies, next_page_token=next_page_token)

    def PostReply(
        self, request: threads_pb2.PostReplyReq, context: CouchersContext, session: Session
    ) -> threads_pb2.PostReplyRes:
        """Post a comment on a thread (depth 0) or a reply to a comment (depth 1).

        Content is stripped of surrounding whitespace. If it is then empty,
        the call aborts with INVALID_ARGUMENT. Any other depth aborts with
        NOT_FOUND. So does an ``IntegrityError`` on flush, which presumably
        means the parent row does not exist.

        On success a ``generate_reply_notifications`` job is queued. The
        returned id is the packed id of the new item, one level deeper than
        its parent.
        """
        content = request.content.strip()

        if content == "":
            context.abort_with_error_code(grpc.StatusCode.INVALID_ARGUMENT, "invalid_comment")

        database_id, depth = unpack_thread_id(request.thread_id)
        if depth == 0:
            object_to_add: Comment | Reply = Comment(
                thread_id=database_id, author_user_id=context.user_id, content=content
            )
        elif depth == 1:
            object_to_add = Reply(comment_id=database_id, author_user_id=context.user_id, content=content)
        else:
            context.abort_with_error_code(grpc.StatusCode.NOT_FOUND, "thread_not_found")
        session.add(object_to_add)
        try:
            # flush now so a missing parent surfaces here and the new id is assigned
            session.flush()
        except sqlalchemy.exc.IntegrityError:
            context.abort_with_error_code(grpc.StatusCode.NOT_FOUND, "thread_not_found")

        thread_id = pack_thread_id(object_to_add.id, depth + 1)

        queue_job(
            session,
            job=generate_reply_notifications,
            payload=jobs_pb2.GenerateReplyNotificationsPayload(
                thread_id=thread_id,
            ),
        )

        return threads_pb2.PostReplyRes(thread_id=thread_id)