Coverage for src/couchers/servicers/threads.py: 98%
51 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-11-21 04:21 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-11-21 04:21 +0000
1import logging
3import grpc
4import sqlalchemy.exc
5from sqlalchemy.sql import func
7from couchers import errors
8from couchers.models import Comment, Reply, Thread
9from couchers.sql import couchers_select as select
10from couchers.utils import Timestamp_from_datetime
11from proto import threads_pb2, threads_pb2_grpc
logger = logging.getLogger(__name__)

# The API exposes a single thread-id space regardless of nesting level, so an
# API id is built by appending the nesting depth (one decimal digit) to the
# database id of the row.


def pack_thread_id(database_id: int, depth: int) -> int:
    """Combine a database id and a nesting depth into one API thread id.

    The depth becomes the last decimal digit, e.g. ``pack_thread_id(12, 1) == 121``.

    Raises:
        ValueError: if ``depth`` is not a single decimal digit (0-9). A larger
            depth would carry into the database-id digits, producing an id
            that collides with another row and cannot be unpacked correctly.
    """
    if not 0 <= depth <= 9:
        raise ValueError(f"depth must be between 0 and 9, got {depth}")
    return database_id * 10 + depth
def unpack_thread_id(thread_id: int) -> tuple[int, int]:
    """Split an API thread id into its ``(database_id, depth)`` parts.

    Inverse of ``pack_thread_id``: the last decimal digit is the depth and the
    remaining digits are the database id, e.g. ``121 -> (12, 1)``.
    """
    return divmod(thread_id, 10)
def total_num_responses(session, database_id):
    """Count every response in the thread whose database id is *database_id*.

    The total is the number of comments attached directly to the thread plus
    the number of replies attached to those comments.
    """
    comment_count_query = (
        select(func.count())
        .select_from(Comment)
        .where(Comment.thread_id == database_id)
    )
    reply_count_query = (
        select(func.count())
        .select_from(Reply)
        .join(Comment, Comment.id == Reply.comment_id)
        .where(Comment.thread_id == database_id)
    )
    num_comments = session.execute(comment_count_query).scalar_one()
    num_replies = session.execute(reply_count_query).scalar_one()
    return num_comments + num_replies
def thread_to_pb(session, database_id):
    """Build the API ``Thread`` message for the thread stored under *database_id*.

    A thread is the top of the nesting hierarchy, so its API id is packed with
    depth 0. ``num_responses`` counts its comments and replies together.
    """
    api_thread_id = pack_thread_id(database_id, 0)
    response_count = total_num_responses(session, database_id)
    return threads_pb2.Thread(thread_id=api_thread_id, num_responses=response_count)
class Threads(threads_pb2_grpc.ThreadsServicer):
    """gRPC servicer for reading and posting to comment threads.

    All levels share one API id space (see ``pack_thread_id``):

    * depth 0: a ``Thread`` row, whose children are ``Comment`` rows;
    * depth 1: a ``Comment`` row, whose children are ``Reply`` rows;
    * depth 2: a ``Reply`` row, which cannot have children.
    """

    @staticmethod
    def _reply_to_pb(row, depth, num_replies):
        """Convert a ``Comment`` or ``Reply`` row into an API ``Reply`` message at *depth*."""
        return threads_pb2.Reply(
            thread_id=pack_thread_id(row.id, depth),
            content=row.content,
            author_user_id=row.author_user_id,
            created_time=Timestamp_from_datetime(row.created),
            num_replies=num_replies,
        )

    @staticmethod
    def _page_start_from_token(page_token, context):
        """Return the exclusive upper bound on database ids for the requested page.

        Aborts with ``INVALID_ARGUMENT`` if *page_token* is not an integer string.
        """
        if not page_token:
            # Presumably larger than any real database id, so the first page
            # starts at the newest row.
            return 2**50
        try:
            return unpack_thread_id(int(page_token))[0]
        except ValueError:
            # Without this, a malformed token surfaced as an unhandled server error.
            context.abort(grpc.StatusCode.INVALID_ARGUMENT, "Invalid page token.")

    def GetThread(self, request, context, session):
        """Return one page of the direct children of ``request.thread_id``.

        Children are returned newest first, ordered by descending database id.
        A ``page_size`` outside the open range (0, 100000) falls back to 1000.
        ``page_token`` is the API id of the last child on the previous page; an
        empty token requests the first page.

        Aborts with ``NOT_FOUND`` if the parent row does not exist or the parent
        is a depth-2 reply. Aborts with ``INVALID_ARGUMENT`` if ``page_token``
        is malformed.
        """
        database_id, depth = unpack_thread_id(request.thread_id)
        page_size = request.page_size if 0 < request.page_size < 100000 else 1000
        page_start = self._page_start_from_token(request.page_token, context)

        if depth == 0:
            if not session.execute(select(Thread).where(Thread.id == database_id)).scalar_one_or_none():
                context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)

            res = session.execute(
                select(Comment, func.count(Reply.id))
                .outerjoin(Reply, Reply.comment_id == Comment.id)
                .where(Comment.thread_id == database_id)
                .where(Comment.id < page_start)
                .group_by(Comment.id)
                # Sort on the column the cursor filters on. Sorting on `created`
                # while paging on `id` skips or repeats rows when the two orders
                # disagree (e.g. rows sharing a `created` timestamp).
                .order_by(Comment.id.desc())
                .limit(page_size + 1)
            ).all()
            replies = [
                self._reply_to_pb(comment, 1, num_replies)
                for comment, num_replies in res[:page_size]
            ]

        elif depth == 1:
            if not session.execute(select(Comment).where(Comment.id == database_id)).scalar_one_or_none():
                context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)

            res = (
                session.execute(
                    select(Reply)
                    .where(Reply.comment_id == database_id)
                    .where(Reply.id < page_start)
                    # Same reasoning as above: order by the cursor column.
                    .order_by(Reply.id.desc())
                    .limit(page_size + 1)
                )
                .scalars()
                .all()
            )
            # Replies are leaves, so they never have children of their own.
            replies = [self._reply_to_pb(reply, 2, 0) for reply in res[:page_size]]

        else:
            context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)

        # One row beyond page_size was fetched purely to detect a further page.
        if len(res) > page_size:
            next_page_token = str(replies[-1].thread_id)
        else:
            next_page_token = ""

        return threads_pb2.GetThreadRes(replies=replies, next_page_token=next_page_token)

    def PostReply(self, request, context, session):
        """Add a child under ``request.thread_id`` and return the new child's API id.

        A depth-0 parent receives a ``Comment``; a depth-1 parent receives a
        ``Reply``. Aborts with ``NOT_FOUND`` for a depth-2 parent, or when
        flushing raises ``IntegrityError`` (presumably because the parent row
        does not exist).
        """
        database_id, depth = unpack_thread_id(request.thread_id)
        if depth == 0:
            object_to_add = Comment(thread_id=database_id, author_user_id=context.user_id, content=request.content)
        elif depth == 1:
            object_to_add = Reply(comment_id=database_id, author_user_id=context.user_id, content=request.content)
        else:
            context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)

        session.add(object_to_add)
        try:
            # Flush now so the new row gets its id and constraint violations surface here.
            session.flush()
        except sqlalchemy.exc.IntegrityError:
            context.abort(grpc.StatusCode.NOT_FOUND, errors.THREAD_NOT_FOUND)

        return threads_pb2.PostReplyRes(thread_id=pack_thread_id(object_to_add.id, depth + 1))