Coverage for app / backend / src / couchers / servicers / threads.py: 89%
126 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-19 14:14 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-19 14:14 +0000
1import logging
3import grpc
4import sqlalchemy.exc
5from sqlalchemy import select
6from sqlalchemy.orm import Session
7from sqlalchemy.sql import func
9from couchers.context import CouchersContext, make_background_user_context
10from couchers.db import session_scope
11from couchers.jobs.enqueue import queue_job
12from couchers.models import Comment, Discussion, Event, EventOccurrence, Reply, Thread, User
13from couchers.models.notifications import NotificationTopicAction
14from couchers.notifications.notify import notify
15from couchers.proto import notification_data_pb2, threads_pb2, threads_pb2_grpc
16from couchers.proto.internal import jobs_pb2
17from couchers.servicers.api import user_model_to_pb
18from couchers.servicers.blocking import is_not_visible
19from couchers.sql import where_users_column_visible
20from couchers.utils import Timestamp_from_datetime
logger = logging.getLogger(__name__)

# The API uses one flat id space for threads, comments and replies, so the
# nesting depth is stored in the final decimal digit of the public id:
#     public_id = database_id * 10 + depth


def pack_thread_id(database_id: int, depth: int) -> int:
    """Encode a database id and a nesting depth into one public thread id.

    The depth becomes the last decimal digit, so it must lie in ``range(10)``
    for ``unpack_thread_id`` to recover it.
    """
    return 10 * database_id + depth
def unpack_thread_id(thread_id: int) -> tuple[int, int]:
    """Split a public thread id into ``(database_id, depth)``.

    This is the inverse of ``pack_thread_id``. The last decimal digit is the
    depth, and the remaining digits are the database id.
    """
    database_id = thread_id // 10
    depth = thread_id % 10
    return database_id, depth
def total_num_responses(session: Session, database_id: int) -> int:
    """Count every response on a thread.

    The count is the number of top-level comments on the thread with database
    id ``database_id``, plus the number of replies to those comments. It runs
    two COUNT queries: comments first, then replies.
    """
    comment_count_query = select(func.count()).select_from(Comment).where(Comment.thread_id == database_id)
    # Replies hang off comments, so join back to Comment to filter by thread.
    reply_count_query = (
        select(func.count())
        .select_from(Reply)
        .join(Comment, Comment.id == Reply.comment_id)
        .where(Comment.thread_id == database_id)
    )
    return sum(session.execute(query).scalar_one() for query in (comment_count_query, reply_count_query))
def thread_to_pb(session: Session, database_id: int) -> threads_pb2.Thread:
    """Build the API ``Thread`` message for the thread with id ``database_id``.

    The public id is the depth-0 packed id. ``num_responses`` counts comments
    plus replies (see ``total_num_responses``).
    """
    public_thread_id = pack_thread_id(database_id, 0)
    response_count = total_num_responses(session, database_id)
    return threads_pb2.Thread(thread_id=public_thread_id, num_responses=response_count)
def _find_thread_owner(session: Session, thread_id: int) -> tuple[Event | None, Discussion | None]:
    """Return the Event and the Discussion attached to ``thread_id``; either may be None."""
    owning_event = session.execute(select(Event).where(Event.thread_id == thread_id)).scalar_one_or_none()
    owning_discussion = session.execute(
        select(Discussion).where(Discussion.thread_id == thread_id)
    ).scalar_one_or_none()
    return owning_event, owning_discussion


def _latest_occurrence(event: Event) -> EventOccurrence:
    """Return the occurrence of ``event`` with the highest id; raise if it has none."""
    return event.occurrences.order_by(EventOccurrence.id.desc()).limit(1).one()


def _notify_for_new_comment(session: Session, packed_id: int, comment_id: int) -> None:
    """Send notifications for a new top-level comment (depth 1).

    On an event thread, each subscriber of the event and each attendee of its
    latest occurrence gets an ``event__comment`` notification. On a discussion
    thread, the discussion creator gets a ``discussion__comment`` notification.
    The comment author is never notified, nor is any recipient for whom
    ``is_not_visible`` holds.

    Raises NotImplementedError for threads attached to neither an event nor a
    discussion, and for discussions whose cluster is not an official cluster.
    """
    # Local imports. NOTE(review): presumably these avoid an import cycle with these servicers; confirm.
    from couchers.servicers.discussions import discussion_to_pb
    from couchers.servicers.events import event_to_pb

    comment = session.execute(select(Comment).where(Comment.id == comment_id)).scalar_one()
    thread = session.execute(select(Thread).where(Thread.id == comment.thread_id)).scalar_one()
    commenter_id = comment.author_user_id
    author_user = session.execute(select(User).where(User.id == commenter_id)).scalar_one()
    # The notification payload carries the comment as a Reply message.
    reply = threads_pb2.Reply(
        thread_id=packed_id,
        content=comment.content,
        author_user_id=commenter_id,
        created_time=Timestamp_from_datetime(comment.created),
        num_replies=0,
    )
    event, discussion = _find_thread_owner(session, thread.id)

    if event:
        occurrence = _latest_occurrence(event)
        subscriber_ids = [subscriber.id for subscriber in event.subscribers]
        attendee_ids = [attendance.user_id for attendance in occurrence.attendances]
        # Deduplicate users who both subscribe and attend.
        for recipient_id in set(subscriber_ids + attendee_ids):
            if is_not_visible(session, recipient_id, commenter_id) or recipient_id == commenter_id:
                continue
            recipient_context = make_background_user_context(user_id=recipient_id)
            notify(
                session,
                user_id=recipient_id,
                topic_action=NotificationTopicAction.event__comment,
                key=str(occurrence.id),
                data=notification_data_pb2.EventComment(
                    reply=reply,
                    event=event_to_pb(session, occurrence, recipient_context),
                    author=user_model_to_pb(author_user, session, recipient_context),
                ),
                moderation_state_id=occurrence.moderation_state_id,
            )
        return

    if not discussion:
        raise NotImplementedError("I can only do event and discussion threads for now")

    # Community discussion thread: only the discussion's creator is notified.
    if not discussion.owner_cluster.is_official_cluster:
        raise NotImplementedError("Shouldn't have discussions under groups, only communities")
    creator_id = discussion.creator_user_id
    if is_not_visible(session, creator_id, commenter_id) or creator_id == commenter_id:
        return
    creator_context = make_background_user_context(user_id=creator_id)
    notify(
        session,
        user_id=creator_id,
        topic_action=NotificationTopicAction.discussion__comment,
        key=str(discussion.id),
        data=notification_data_pb2.DiscussionComment(
            reply=reply,
            discussion=discussion_to_pb(session, discussion, creator_context),
            author=user_model_to_pb(author_user, session, creator_context),
        ),
    )


def _notify_for_new_reply(session: Session, packed_id: int, reply_id: int) -> None:
    """Send ``thread__reply`` notifications for a new second-level reply (depth 2).

    Recipients:
    - the authors of replies under the same parent comment, filtered by
      ``where_users_column_visible`` using the replier's context;
    - the parent comment's author, unless ``is_not_visible`` holds between them;
    - minus the replier themselves.

    Raises NotImplementedError when the thread is attached to neither an event
    nor a discussion.
    """
    # Local imports. NOTE(review): presumably these avoid an import cycle with these servicers; confirm.
    from couchers.servicers.discussions import discussion_to_pb
    from couchers.servicers.events import event_to_pb

    db_reply = session.execute(select(Reply).where(Reply.id == reply_id)).scalar_one()
    parent_comment = session.execute(select(Comment).where(Comment.id == db_reply.comment_id)).scalar_one()
    replier_id = db_reply.author_user_id
    replier_context = make_background_user_context(user_id=replier_id)
    reply_authors_query = where_users_column_visible(
        select(Reply.author_user_id).where(Reply.comment_id == parent_comment.id),
        replier_context,
        Reply.author_user_id,
    )
    participant_ids = set(session.execute(reply_authors_query).scalars().all())
    if not is_not_visible(session, parent_comment.author_user_id, replier_id):
        participant_ids.add(parent_comment.author_user_id)

    author_user = session.execute(select(User).where(User.id == replier_id)).scalar_one()
    recipient_ids = set(participant_ids) - {replier_id}

    reply = threads_pb2.Reply(
        thread_id=packed_id,
        content=db_reply.content,
        author_user_id=replier_id,
        created_time=Timestamp_from_datetime(db_reply.created),
        num_replies=0,
    )
    event, discussion = _find_thread_owner(session, parent_comment.thread_id)

    if event:
        occurrence = _latest_occurrence(event)
        for recipient_id in recipient_ids:
            recipient_context = make_background_user_context(user_id=recipient_id)
            notify(
                session,
                user_id=recipient_id,
                topic_action=NotificationTopicAction.thread__reply,
                key=str(occurrence.id),
                data=notification_data_pb2.ThreadReply(
                    reply=reply,
                    event=event_to_pb(session, occurrence, recipient_context),
                    author=user_model_to_pb(author_user, session, recipient_context),
                ),
                moderation_state_id=occurrence.moderation_state_id,
            )
    elif discussion:
        for recipient_id in recipient_ids:
            recipient_context = make_background_user_context(user_id=recipient_id)
            notify(
                session,
                user_id=recipient_id,
                topic_action=NotificationTopicAction.thread__reply,
                key=str(discussion.id),
                data=notification_data_pb2.ThreadReply(
                    reply=reply,
                    discussion=discussion_to_pb(session, discussion, recipient_context),
                    author=user_model_to_pb(author_user, session, recipient_context),
                ),
            )
    else:
        raise NotImplementedError("I can only do event and discussion threads for now")


def generate_reply_notifications(payload: jobs_pb2.GenerateReplyNotificationsPayload) -> None:
    """Background job (queued by ``Threads.PostReply``) that notifies users about a new response.

    ``payload.thread_id`` is a packed id (see ``unpack_thread_id``). Depth 1
    means a new top-level comment, and depth 2 means a new reply to a comment.
    Any other depth raises ``Exception("Unknown depth")``. All work happens
    inside a single ``session_scope``.
    """
    with session_scope() as session:
        database_id, depth = unpack_thread_id(payload.thread_id)
        if depth == 1:
            _notify_for_new_comment(session, payload.thread_id, database_id)
        elif depth == 2:
            _notify_for_new_reply(session, payload.thread_id, database_id)
        else:
            raise Exception("Unknown depth")
class Threads(threads_pb2_grpc.ThreadsServicer):
    """gRPC servicer for reading and posting to comment threads.

    Public ids are packed with ``pack_thread_id``:
    - depth 0 is a Thread;
    - depth 1 is a top-level Comment on a thread;
    - depth 2 is a Reply to a Comment.
    """

    def GetThread(
        self, request: threads_pb2.GetThreadReq, context: CouchersContext, session: Session
    ) -> threads_pb2.GetThreadRes:
        """Return one page of the direct children of ``request.thread_id``, newest first.

        For a depth-0 id the children are the thread's Comments, each with its
        reply count. For a depth-1 id they are the comment's Replies. Any other
        depth, or a missing thread or comment, aborts with NOT_FOUND.

        ``page_size`` outside ``(0, 100000)`` falls back to 1000. Pagination is
        keyset-based on the database id. ``next_page_token`` is the packed id of
        the last item returned, and the next page contains only items with a
        strictly smaller database id.
        """
        database_id, depth = unpack_thread_id(request.thread_id)
        page_size = request.page_size if 0 < request.page_size < 100000 else 1000
        # 2**50 acts as "no cursor yet". NOTE(review): assumes no database id reaches 2**50.
        page_start = unpack_thread_id(int(request.page_token))[0] if request.page_token else 2**50

        if depth == 0:
            if not session.execute(select(Thread).where(Thread.id == database_id)).scalar_one_or_none():
                context.abort_with_error_code(grpc.StatusCode.NOT_FOUND, "thread_not_found")

            comment_rows = session.execute(
                select(Comment, func.count(Reply.id))
                .outerjoin(Reply, Reply.comment_id == Comment.id)
                .where(Comment.thread_id == database_id)
                .where(Comment.id < page_start)
                .group_by(Comment.id)
                # Sort by the cursor column itself. With ``created``, tied
                # timestamps or timestamps out of id order would make
                # ``id < page_start`` skip or repeat comments across pages.
                .order_by(Comment.id.desc())
                .limit(page_size + 1)  # fetch one extra row to detect a further page
            ).all()
            has_more = len(comment_rows) > page_size
            replies = [
                threads_pb2.Reply(
                    thread_id=pack_thread_id(comment.id, 1),
                    content=comment.content,
                    author_user_id=comment.author_user_id,
                    created_time=Timestamp_from_datetime(comment.created),
                    num_replies=reply_count,
                )
                for comment, reply_count in comment_rows[:page_size]
            ]

        elif depth == 1:
            if not session.execute(select(Comment).where(Comment.id == database_id)).scalar_one_or_none():
                context.abort_with_error_code(grpc.StatusCode.NOT_FOUND, "thread_not_found")

            reply_rows = (
                session.execute(
                    select(Reply)
                    .where(Reply.comment_id == database_id)
                    .where(Reply.id < page_start)
                    # Sort by the cursor column; see the note in the depth-0 branch.
                    .order_by(Reply.id.desc())
                    .limit(page_size + 1)  # fetch one extra row to detect a further page
                )
                .scalars()
                .all()
            )
            has_more = len(reply_rows) > page_size
            replies = [
                threads_pb2.Reply(
                    thread_id=pack_thread_id(db_reply.id, 2),
                    content=db_reply.content,
                    author_user_id=db_reply.author_user_id,
                    created_time=Timestamp_from_datetime(db_reply.created),
                    num_replies=0,
                )
                for db_reply in reply_rows[:page_size]
            ]

        else:
            context.abort_with_error_code(grpc.StatusCode.NOT_FOUND, "thread_not_found")

        next_page_token = str(replies[-1].thread_id) if has_more else ""
        return threads_pb2.GetThreadRes(replies=replies, next_page_token=next_page_token)

    def PostReply(
        self, request: threads_pb2.PostReplyReq, context: CouchersContext, session: Session
    ) -> threads_pb2.PostReplyRes:
        """Post a comment (on a depth-0 thread id) or a reply (on a depth-1 comment id).

        Content is stripped of surrounding whitespace. Empty content aborts with
        INVALID_ARGUMENT. A deeper target, or a flush that fails with an
        IntegrityError (e.g. the parent does not exist), aborts with NOT_FOUND.
        On success, a ``generate_reply_notifications`` job is queued and the
        packed id of the new item is returned.
        """
        content = request.content.strip()

        if content == "":
            context.abort_with_error_code(grpc.StatusCode.INVALID_ARGUMENT, "invalid_comment")

        database_id, depth = unpack_thread_id(request.thread_id)
        if depth == 0:
            object_to_add: Comment | Reply = Comment(
                thread_id=database_id, author_user_id=context.user_id, content=content
            )
        elif depth == 1:
            object_to_add = Reply(comment_id=database_id, author_user_id=context.user_id, content=content)
        else:
            context.abort_with_error_code(grpc.StatusCode.NOT_FOUND, "thread_not_found")
        session.add(object_to_add)
        try:
            # Flush now so a missing parent surfaces here as an IntegrityError and the new id is assigned.
            session.flush()
        except sqlalchemy.exc.IntegrityError:
            context.abort_with_error_code(grpc.StatusCode.NOT_FOUND, "thread_not_found")

        # The new item sits one level deeper than the id it was posted to.
        thread_id = pack_thread_id(object_to_add.id, depth + 1)

        queue_job(
            session,
            job=generate_reply_notifications,
            payload=jobs_pb2.GenerateReplyNotificationsPayload(
                thread_id=thread_id,
            ),
        )

        return threads_pb2.PostReplyRes(thread_id=thread_id)