Coverage for src/couchers/servicers/groups.py: 80%

121 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-07-22 17:19 +0000

1import logging 

2from datetime import timedelta 

3 

4import grpc 

5from google.protobuf import empty_pb2 

6from sqlalchemy.sql import delete, func 

7 

8from couchers import errors 

9from couchers.db import can_moderate_node, get_node_parents_recursively, session_scope 

10from couchers.models import ( 

11 Cluster, 

12 ClusterRole, 

13 ClusterSubscription, 

14 Discussion, 

15 Event, 

16 EventOccurrence, 

17 Page, 

18 PageType, 

19 User, 

20) 

21from couchers.servicers.discussions import discussion_to_pb 

22from couchers.servicers.events import event_to_pb 

23from couchers.servicers.pages import page_to_pb 

24from couchers.sql import couchers_select as select 

25from couchers.utils import Timestamp_from_datetime, dt_from_millis, millis_from_dt, now 

26from proto import groups_pb2, groups_pb2_grpc 

27 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Upper bound on the page size of the paginated List* RPCs below; it is also the
# page size used when a request leaves page_size unset/zero.
MAX_PAGINATION_LENGTH = 25

31 

32 

def _parents_to_pb(cluster: Cluster):
    """Build the list of ``groups_pb2.Parent`` messages for a group.

    The list holds one ``CommunityParent`` entry per row returned by
    ``get_node_parents_recursively`` for the group's parent node (in the order
    returned), followed by a single ``GroupParent`` entry describing ``cluster``
    itself.
    """
    with session_scope() as session:
        ancestor_rows = get_node_parents_recursively(session, cluster.parent_node_id)

        chain = []
        # Rows unpack as (node_id, parent_node_id, level, cluster); only the node
        # id and the row's cluster are needed here.
        for node_id, _parent_node_id, _level, node_cluster in ancestor_rows:
            community = groups_pb2.CommunityParent(
                community_id=node_id,
                name=node_cluster.name,
                slug=node_cluster.slug,
                description=node_cluster.description,
            )
            chain.append(groups_pb2.Parent(community=community))

        # The group itself always comes last.
        own_entry = groups_pb2.GroupParent(
            group_id=cluster.id,
            name=cluster.name,
            slug=cluster.slug,
            description=cluster.description,
        )
        chain.append(groups_pb2.Parent(group=own_entry))
        return chain

56 

57 

def group_to_pb(cluster: Cluster, context):
    """Convert a group's ``Cluster`` row into a ``groups_pb2.Group`` message.

    Args:
        cluster: the cluster to serialise.
        context: request context; ``context.user_id`` is the calling user and the
            context is also passed to the user-visibility filters.

    Returns:
        A ``groups_pb2.Group`` including the visible member/admin counts, whether
        the caller is a member/admin, and whether the caller can moderate the
        group's parent node.
    """
    with session_scope() as session:
        can_moderate = can_moderate_node(session, context.user_id, cluster.parent_node_id)

        # Subscriptions to this group whose user passes the visibility filter;
        # shared by both the member count and the admin count.
        visible_subscriptions = (
            select(func.count())
            .select_from(ClusterSubscription)
            .where_users_column_visible(context, ClusterSubscription.user_id)
            .where(ClusterSubscription.cluster_id == cluster.id)
        )
        member_count = session.execute(visible_subscriptions).scalar_one()
        admin_count = session.execute(
            visible_subscriptions.where(ClusterSubscription.role == ClusterRole.admin)
        ).scalar_one()

        # Fetch the caller's own subscription once and derive both flags from it,
        # rather than issuing one query for membership and another for adminship.
        own_subscription = session.execute(
            select(ClusterSubscription)
            .where(ClusterSubscription.user_id == context.user_id)
            .where(ClusterSubscription.cluster_id == cluster.id)
        ).scalar_one_or_none()
        is_member = own_subscription is not None
        is_admin = is_member and own_subscription.role == ClusterRole.admin

    return groups_pb2.Group(
        group_id=cluster.id,
        name=cluster.name,
        slug=cluster.slug,
        description=cluster.description,
        created=Timestamp_from_datetime(cluster.created),
        parents=_parents_to_pb(cluster),
        main_page=page_to_pb(cluster.main_page, context),
        member=is_member,
        admin=is_admin,
        member_count=member_count,
        admin_count=admin_count,
        can_moderate=can_moderate,
    )

108 

109 

def _get_group_or_abort(session, context, group_id):
    """Return the non-official cluster with id ``group_id``.

    Aborts the RPC with ``NOT_FOUND`` / ``errors.GROUP_NOT_FOUND`` when no such
    cluster exists. This relies on ``context.abort`` raising to terminate the RPC
    (as grpc's ``ServicerContext.abort`` does), exactly like the inline checks it
    replaces.
    """
    cluster = session.execute(
        select(Cluster)
        .where(~Cluster.is_official_cluster)  # not an official group
        .where(Cluster.id == group_id)
    ).scalar_one_or_none()
    if not cluster:
        context.abort(grpc.StatusCode.NOT_FOUND, errors.GROUP_NOT_FOUND)
    return cluster


class Groups(groups_pb2_grpc.GroupsServicer):
    """Servicer for the Groups API; it only operates on non-official clusters.

    The id-based List* RPCs share one pagination scheme: rows are ordered by id,
    ``page_size + 1`` rows are fetched starting at the id in ``page_token``, and
    the id of the extra row (if any) is returned as ``next_page_token``.
    """

    def GetGroup(self, request, context):
        """Return the group ``request.group_id``; aborts with NOT_FOUND if missing."""
        with session_scope() as session:
            cluster = _get_group_or_abort(session, context, request.group_id)

            return group_to_pb(cluster, context)

    def ListAdmins(self, request, context):
        """List ids of the visible users subscribed to the group with the admin role."""
        with session_scope() as session:
            page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
            next_admin_id = int(request.page_token) if request.page_token else 0
            cluster = _get_group_or_abort(session, context, request.group_id)

            admins = (
                session.execute(
                    select(User)
                    .where_users_visible(context)
                    .join(ClusterSubscription, ClusterSubscription.user_id == User.id)
                    .where(ClusterSubscription.cluster_id == cluster.id)
                    .where(ClusterSubscription.role == ClusterRole.admin)
                    .where(User.id >= next_admin_id)
                    .order_by(User.id)
                    .limit(page_size + 1)  # one extra row to learn whether another page exists
                )
                .scalars()
                .all()
            )
            return groups_pb2.ListAdminsRes(
                admin_user_ids=[admin.id for admin in admins[:page_size]],
                next_page_token=str(admins[-1].id) if len(admins) > page_size else None,
            )

    def ListMembers(self, request, context):
        """List ids of the visible users subscribed to the group (any role)."""
        with session_scope() as session:
            page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
            next_member_id = int(request.page_token) if request.page_token else 0
            cluster = _get_group_or_abort(session, context, request.group_id)

            members = (
                session.execute(
                    select(User)
                    .join(ClusterSubscription, ClusterSubscription.user_id == User.id)
                    .where_users_visible(context)
                    .where(ClusterSubscription.cluster_id == cluster.id)
                    .where(User.id >= next_member_id)
                    .order_by(User.id)
                    .limit(page_size + 1)  # one extra row to learn whether another page exists
                )
                .scalars()
                .all()
            )
            return groups_pb2.ListMembersRes(
                member_user_ids=[member.id for member in members[:page_size]],
                next_page_token=str(members[-1].id) if len(members) > page_size else None,
            )

    def ListPlaces(self, request, context):
        """List the group's owned pages of type ``PageType.place``."""
        with session_scope() as session:
            page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
            next_page_id = int(request.page_token) if request.page_token else 0
            cluster = _get_group_or_abort(session, context, request.group_id)
            places = (
                cluster.owned_pages.where(Page.type == PageType.place)
                .where(Page.id >= next_page_id)
                .order_by(Page.id)
                .limit(page_size + 1)
                .all()
            )
            return groups_pb2.ListPlacesRes(
                places=[page_to_pb(page, context) for page in places[:page_size]],
                next_page_token=str(places[-1].id) if len(places) > page_size else None,
            )

    def ListGuides(self, request, context):
        """List the group's owned pages of type ``PageType.guide``."""
        with session_scope() as session:
            page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
            next_page_id = int(request.page_token) if request.page_token else 0
            cluster = _get_group_or_abort(session, context, request.group_id)
            guides = (
                cluster.owned_pages.where(Page.type == PageType.guide)
                .where(Page.id >= next_page_id)
                .order_by(Page.id)
                .limit(page_size + 1)
                .all()
            )
            return groups_pb2.ListGuidesRes(
                guides=[page_to_pb(page, context) for page in guides[:page_size]],
                next_page_token=str(guides[-1].id) if len(guides) > page_size else None,
            )

    def ListEvents(self, request, context):
        """List occurrences of events owned by the group.

        With ``request.past`` unset, returns occurrences ending after the page
        token, earliest start first; with it set, occurrences ending before the
        page token, latest start first. The token defaults to the current time.
        """
        with session_scope() as session:
            page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
            # the page token is a unix timestamp of where we left off
            page_token = dt_from_millis(int(request.page_token)) if request.page_token else now()

            cluster = _get_group_or_abort(session, context, request.group_id)

            occurrences = (
                select(EventOccurrence)
                .join(Event, Event.id == EventOccurrence.event_id)
                .where(Event.owner_cluster == cluster)
            )

            # NOTE(review): rows are ordered by start_time, but the filter and the
            # next page token both use end_time -- confirm that paging cannot skip
            # or repeat occurrences whose start/end orderings differ.
            if not request.past:
                occurrences = occurrences.where(EventOccurrence.end_time > page_token - timedelta(seconds=1)).order_by(
                    EventOccurrence.start_time.asc()
                )
            else:
                occurrences = occurrences.where(EventOccurrence.end_time < page_token + timedelta(seconds=1)).order_by(
                    EventOccurrence.start_time.desc()
                )

            occurrences = occurrences.limit(page_size + 1)
            occurrences = session.execute(occurrences).scalars().all()

            return groups_pb2.ListEventsRes(
                events=[event_to_pb(session, occurrence, context) for occurrence in occurrences[:page_size]],
                next_page_token=str(millis_from_dt(occurrences[-1].end_time)) if len(occurrences) > page_size else None,
            )

    def ListDiscussions(self, request, context):
        """List the discussions owned by the group.

        A missing group is reported as ``errors.GROUP_NOT_FOUND``, consistent with
        every other RPC of this servicer.
        """
        with session_scope() as session:
            page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
            next_page_id = int(request.page_token) if request.page_token else 0
            cluster = _get_group_or_abort(session, context, request.group_id)
            discussions = (
                cluster.owned_discussions.where(Discussion.id >= next_page_id)
                .order_by(Discussion.id)
                .limit(page_size + 1)
                .all()
            )
            return groups_pb2.ListDiscussionsRes(
                discussions=[discussion_to_pb(discussion, context) for discussion in discussions[:page_size]],
                next_page_token=str(discussions[-1].id) if len(discussions) > page_size else None,
            )

    def JoinGroup(self, request, context):
        """Subscribe the caller to the group with the ``member`` role.

        Aborts with FAILED_PRECONDITION / ``errors.ALREADY_IN_GROUP`` if the caller
        is already in the group.
        """
        with session_scope() as session:
            cluster = _get_group_or_abort(session, context, request.group_id)

            user_in_group = cluster.members.where(User.id == context.user_id).one_or_none()
            if user_in_group:
                context.abort(grpc.StatusCode.FAILED_PRECONDITION, errors.ALREADY_IN_GROUP)

            cluster.cluster_subscriptions.append(
                ClusterSubscription(
                    user_id=context.user_id,
                    role=ClusterRole.member,
                )
            )

        return empty_pb2.Empty()

    def LeaveGroup(self, request, context):
        """Delete the caller's subscription to the group.

        Aborts with FAILED_PRECONDITION / ``errors.NOT_IN_GROUP`` if the caller is
        not in the group.
        """
        with session_scope() as session:
            cluster = _get_group_or_abort(session, context, request.group_id)

            user_in_group = cluster.members.where(User.id == context.user_id).one_or_none()
            if not user_in_group:
                context.abort(grpc.StatusCode.FAILED_PRECONDITION, errors.NOT_IN_GROUP)

            session.execute(
                delete(ClusterSubscription)
                .where(ClusterSubscription.cluster_id == request.group_id)
                .where(ClusterSubscription.user_id == context.user_id)
            )

        return empty_pb2.Empty()

    def ListUserGroups(self, request, context):
        """List the groups that ``request.user_id`` (or the caller, if unset) is subscribed to."""
        with session_scope() as session:
            page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
            next_cluster_id = int(request.page_token) if request.page_token else 0
            user_id = request.user_id or context.user_id
            clusters = (
                session.execute(
                    select(Cluster)
                    .join(ClusterSubscription, ClusterSubscription.cluster_id == Cluster.id)
                    .where(ClusterSubscription.user_id == user_id)
                    .where(~Cluster.is_official_cluster)  # not an official group
                    .where(Cluster.id >= next_cluster_id)
                    .order_by(Cluster.id)
                    .limit(page_size + 1)
                )
                .scalars()
                .all()
            )
            return groups_pb2.ListUserGroupsRes(
                groups=[group_to_pb(cluster, context) for cluster in clusters[:page_size]],
                next_page_token=str(clusters[-1].id) if len(clusters) > page_size else None,
            )