Coverage for src/couchers/servicers/groups.py: 80%

109 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-10-15 13:03 +0000

1import logging 

2from datetime import timedelta 

3 

4import grpc 

5from google.protobuf import empty_pb2 

6from sqlalchemy.sql import delete, func 

7 

8from couchers import errors 

9from couchers.db import can_moderate_node, get_node_parents_recursively 

10from couchers.models import ( 

11 Cluster, 

12 ClusterRole, 

13 ClusterSubscription, 

14 Discussion, 

15 Event, 

16 EventOccurrence, 

17 Page, 

18 PageType, 

19 User, 

20) 

21from couchers.servicers.discussions import discussion_to_pb 

22from couchers.servicers.events import event_to_pb 

23from couchers.servicers.pages import page_to_pb 

24from couchers.sql import couchers_select as select 

25from couchers.utils import Timestamp_from_datetime, dt_from_millis, millis_from_dt, now 

26from proto import groups_pb2, groups_pb2_grpc 

27 

# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)

# Upper bound applied to ``page_size`` by the paginated List* RPCs in this
# module; it is also the page size used when a request leaves ``page_size``
# unset (0), via ``request.page_size or MAX_PAGINATION_LENGTH``.
MAX_PAGINATION_LENGTH: int = 25

31 

32 

def _parents_to_pb(session, cluster: Cluster):
    """Return the ``groups_pb2.Parent`` breadcrumb list for the group ``cluster``.

    One ``CommunityParent`` is emitted per row returned by
    ``get_node_parents_recursively(session, cluster.parent_node_id)``, in the
    order the rows come back, using the cluster carried on each row for the
    name/slug/description. The group itself is appended last as a
    ``GroupParent``.
    """
    parent_rows = get_node_parents_recursively(session, cluster.parent_node_id)

    breadcrumbs = []
    # Each row is unpacked as (node_id, parent_node_id, level, cluster); only the
    # node id and its cluster are needed here. A distinct name avoids shadowing
    # the ``cluster`` parameter, which is still needed for the final entry.
    for node_id, _parent_node_id, _level, node_cluster in parent_rows:
        community = groups_pb2.CommunityParent(
            community_id=node_id,
            name=node_cluster.name,
            slug=node_cluster.slug,
            description=node_cluster.description,
        )
        breadcrumbs.append(groups_pb2.Parent(community=community))

    group = groups_pb2.GroupParent(
        group_id=cluster.id,
        name=cluster.name,
        slug=cluster.slug,
        description=cluster.description,
    )
    breadcrumbs.append(groups_pb2.Parent(group=group))
    return breadcrumbs

55 

56 

def group_to_pb(session, cluster: Cluster, context):
    """Serialize ``cluster`` into a ``groups_pb2.Group`` as seen by the caller.

    ``context.user_id`` identifies the caller. ``member_count`` and
    ``admin_count`` only include subscriptions that pass
    ``where_users_column_visible`` for the caller, whereas the caller's own
    ``member``/``admin`` flags are looked up without that filter.
    ``can_moderate`` comes from ``can_moderate_node`` on the group's parent node.
    """
    can_moderate = can_moderate_node(session, context.user_id, cluster.parent_node_id)

    of_this_group = ClusterSubscription.cluster_id == cluster.id
    held_by_caller = ClusterSubscription.user_id == context.user_id
    with_admin_role = ClusterSubscription.role == ClusterRole.admin

    def count_visible(*conditions):
        # Number of this group's subscriptions held by users the caller may see.
        stmt = (
            select(func.count())
            .select_from(ClusterSubscription)
            .where_users_column_visible(context, ClusterSubscription.user_id)
            .where(of_this_group)
        )
        for condition in conditions:
            stmt = stmt.where(condition)
        return session.execute(stmt).scalar_one()

    def caller_has_subscription(*conditions):
        # Whether the caller holds a subscription to this group matching `conditions`.
        stmt = select(ClusterSubscription).where(held_by_caller).where(of_this_group)
        for condition in conditions:
            stmt = stmt.where(condition)
        return session.execute(stmt).scalar_one_or_none() is not None

    member_count = count_visible()
    is_member = caller_has_subscription()
    admin_count = count_visible(with_admin_role)
    is_admin = caller_has_subscription(with_admin_role)

    return groups_pb2.Group(
        group_id=cluster.id,
        name=cluster.name,
        slug=cluster.slug,
        description=cluster.description,
        created=Timestamp_from_datetime(cluster.created),
        parents=_parents_to_pb(session, cluster),
        main_page=page_to_pb(session, cluster.main_page, context),
        member=is_member,
        admin=is_admin,
        member_count=member_count,
        admin_count=admin_count,
        can_moderate=can_moderate,
    )

106 

107 

def _get_group_or_abort(session, context, group_id):
    """Fetch the non-official cluster with id ``group_id``, or abort NOT_FOUND.

    ``context.abort`` is expected to raise (as gRPC's ``ServicerContext.abort``
    does), so callers use the returned cluster unconditionally.
    """
    cluster = session.execute(
        select(Cluster)
        .where(~Cluster.is_official_cluster)  # not an official group
        .where(Cluster.id == group_id)
    ).scalar_one_or_none()
    if not cluster:
        context.abort(grpc.StatusCode.NOT_FOUND, errors.GROUP_NOT_FOUND)
    return cluster


class Groups(groups_pb2_grpc.GroupsServicer):
    """gRPC servicer for groups (clusters that are not official clusters).

    Every RPC takes an open ``session`` in addition to ``request``/``context``.
    The paginated List* RPCs use keyset pagination. They fetch ``page_size + 1``
    rows; if the extra row exists, its key becomes ``next_page_token``, and the
    next page starts at (and includes) that row.
    """

    def GetGroup(self, request, context, session):
        """Return group ``request.group_id``; NOT_FOUND if missing or an official cluster."""
        cluster = _get_group_or_abort(session, context, request.group_id)
        return group_to_pb(session, cluster, context)

    def ListAdmins(self, request, context, session):
        """List ids of the group's admins that are visible to the caller, by ascending user id."""
        page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
        next_admin_id = int(request.page_token) if request.page_token else 0
        cluster = _get_group_or_abort(session, context, request.group_id)

        admins = (
            session.execute(
                select(User)
                .where_users_visible(context)
                .join(ClusterSubscription, ClusterSubscription.user_id == User.id)
                .where(ClusterSubscription.cluster_id == cluster.id)
                .where(ClusterSubscription.role == ClusterRole.admin)
                .where(User.id >= next_admin_id)
                .order_by(User.id)
                .limit(page_size + 1)
            )
            .scalars()
            .all()
        )
        return groups_pb2.ListAdminsRes(
            admin_user_ids=[admin.id for admin in admins[:page_size]],
            # The extra (page_size + 1)-th row, if any, is where the next page begins.
            next_page_token=str(admins[-1].id) if len(admins) > page_size else None,
        )

    def ListMembers(self, request, context, session):
        """List ids of the group's members that are visible to the caller, by ascending user id."""
        page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
        next_member_id = int(request.page_token) if request.page_token else 0
        cluster = _get_group_or_abort(session, context, request.group_id)

        members = (
            session.execute(
                select(User)
                .join(ClusterSubscription, ClusterSubscription.user_id == User.id)
                .where_users_visible(context)
                .where(ClusterSubscription.cluster_id == cluster.id)
                .where(User.id >= next_member_id)
                .order_by(User.id)
                .limit(page_size + 1)
            )
            .scalars()
            .all()
        )
        return groups_pb2.ListMembersRes(
            member_user_ids=[member.id for member in members[:page_size]],
            next_page_token=str(members[-1].id) if len(members) > page_size else None,
        )

    def ListPlaces(self, request, context, session):
        """List the group's owned pages of type ``PageType.place``, by ascending page id."""
        page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
        next_page_id = int(request.page_token) if request.page_token else 0
        cluster = _get_group_or_abort(session, context, request.group_id)

        places = (
            cluster.owned_pages.where(Page.type == PageType.place)
            .where(Page.id >= next_page_id)
            .order_by(Page.id)
            .limit(page_size + 1)
            .all()
        )
        return groups_pb2.ListPlacesRes(
            places=[page_to_pb(session, page, context) for page in places[:page_size]],
            next_page_token=str(places[-1].id) if len(places) > page_size else None,
        )

    def ListGuides(self, request, context, session):
        """List the group's owned pages of type ``PageType.guide``, by ascending page id."""
        page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
        next_page_id = int(request.page_token) if request.page_token else 0
        cluster = _get_group_or_abort(session, context, request.group_id)

        guides = (
            cluster.owned_pages.where(Page.type == PageType.guide)
            .where(Page.id >= next_page_id)
            .order_by(Page.id)
            .limit(page_size + 1)
            .all()
        )
        return groups_pb2.ListGuidesRes(
            guides=[page_to_pb(session, page, context) for page in guides[:page_size]],
            next_page_token=str(guides[-1].id) if len(guides) > page_size else None,
        )

    def ListEvents(self, request, context, session):
        """List occurrences of events owned by the group.

        Upcoming occurrences (``request.past`` false) are those ending after the
        cursor, ordered by ascending start time. Past occurrences end before the
        cursor and are ordered by descending start time. The cursor is the page
        token, a millisecond timestamp, or now() when the token is empty. A
        one-second slack is applied around the cursor.
        """
        page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
        # the page token is a unix timestamp (in milliseconds) of where we left off
        page_token = dt_from_millis(int(request.page_token)) if request.page_token else now()

        cluster = _get_group_or_abort(session, context, request.group_id)

        occurrences = (
            select(EventOccurrence)
            .join(Event, Event.id == EventOccurrence.event_id)
            .where(Event.owner_cluster == cluster)
        )

        # NOTE(review): rows are ordered by start_time but the cursor below is
        # the end_time of the extra row, so an occurrence that starts earlier
        # but ends later than the cursor row can reappear on the next page (and
        # the reverse can be skipped). Confirm whether ordering/cursor should
        # use the same column.
        if not request.past:
            occurrences = occurrences.where(EventOccurrence.end_time > page_token - timedelta(seconds=1)).order_by(
                EventOccurrence.start_time.asc()
            )
        else:
            occurrences = occurrences.where(EventOccurrence.end_time < page_token + timedelta(seconds=1)).order_by(
                EventOccurrence.start_time.desc()
            )

        occurrences = occurrences.limit(page_size + 1)
        occurrences = session.execute(occurrences).scalars().all()

        return groups_pb2.ListEventsRes(
            events=[event_to_pb(session, occurrence, context) for occurrence in occurrences[:page_size]],
            next_page_token=str(millis_from_dt(occurrences[-1].end_time)) if len(occurrences) > page_size else None,
        )

    def ListDiscussions(self, request, context, session):
        """List the group's owned discussions, by ascending discussion id."""
        page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
        next_page_id = int(request.page_token) if request.page_token else 0
        # Aborts with GROUP_NOT_FOUND like every other RPC here (this previously
        # reported COMMUNITY_NOT_FOUND for a missing group).
        cluster = _get_group_or_abort(session, context, request.group_id)

        discussions = (
            cluster.owned_discussions.where(Discussion.id >= next_page_id)
            .order_by(Discussion.id)
            .limit(page_size + 1)
            .all()
        )
        return groups_pb2.ListDiscussionsRes(
            discussions=[discussion_to_pb(session, discussion, context) for discussion in discussions[:page_size]],
            next_page_token=str(discussions[-1].id) if len(discussions) > page_size else None,
        )

    def JoinGroup(self, request, context, session):
        """Subscribe the caller to the group with the ``member`` role.

        Aborts FAILED_PRECONDITION (ALREADY_IN_GROUP) when ``cluster.members``
        already contains the caller.
        """
        cluster = _get_group_or_abort(session, context, request.group_id)

        user_in_group = cluster.members.where(User.id == context.user_id).one_or_none()
        if user_in_group:
            context.abort(grpc.StatusCode.FAILED_PRECONDITION, errors.ALREADY_IN_GROUP)

        cluster.cluster_subscriptions.append(
            ClusterSubscription(
                user_id=context.user_id,
                role=ClusterRole.member,
            )
        )

        return empty_pb2.Empty()

    def LeaveGroup(self, request, context, session):
        """Delete the caller's subscription to the group.

        Aborts FAILED_PRECONDITION (NOT_IN_GROUP) when ``cluster.members`` does
        not contain the caller.
        """
        cluster = _get_group_or_abort(session, context, request.group_id)

        user_in_group = cluster.members.where(User.id == context.user_id).one_or_none()
        if not user_in_group:
            context.abort(grpc.StatusCode.FAILED_PRECONDITION, errors.NOT_IN_GROUP)

        session.execute(
            delete(ClusterSubscription)
            .where(ClusterSubscription.cluster_id == cluster.id)
            .where(ClusterSubscription.user_id == context.user_id)
        )

        return empty_pb2.Empty()

    def ListUserGroups(self, request, context, session):
        """List the groups that ``request.user_id`` (default: the caller) is subscribed to, by ascending id."""
        page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
        next_cluster_id = int(request.page_token) if request.page_token else 0
        user_id = request.user_id or context.user_id
        # NOTE(review): unlike ListMembers/ListAdmins, no user-visibility filter is
        # applied to ``user_id`` here; confirm that listing any user's groups is intended.
        clusters = (
            session.execute(
                select(Cluster)
                .join(ClusterSubscription, ClusterSubscription.cluster_id == Cluster.id)
                .where(ClusterSubscription.user_id == user_id)
                .where(~Cluster.is_official_cluster)  # not an official group
                .where(Cluster.id >= next_cluster_id)
                .order_by(Cluster.id)
                .limit(page_size + 1)
            )
            .scalars()
            .all()
        )
        return groups_pb2.ListUserGroupsRes(
            groups=[group_to_pb(session, cluster, context) for cluster in clusters[:page_size]],
            next_page_token=str(clusters[-1].id) if len(clusters) > page_size else None,
        )