Coverage for src/couchers/servicers/search.py: 83%

241 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-22 06:42 +0000

1""" 

2See //docs/search.md for overview. 

3""" 

4 

5from datetime import timedelta 

6 

7import grpc 

8from sqlalchemy.sql import and_, func, or_ 

9 

10from couchers import errors 

11from couchers.crypto import decrypt_page_token, encrypt_page_token 

12from couchers.models import ( 

13 Cluster, 

14 ClusterSubscription, 

15 Event, 

16 EventOccurrence, 

17 EventOccurrenceAttendee, 

18 EventOrganizer, 

19 EventSubscription, 

20 LanguageAbility, 

21 Node, 

22 Page, 

23 PageType, 

24 PageVersion, 

25 Reference, 

26 User, 

27) 

28from couchers.servicers.account import has_strong_verification 

29from couchers.servicers.api import ( 

30 fluency2sql, 

31 hostingstatus2sql, 

32 meetupstatus2sql, 

33 parkingdetails2sql, 

34 sleepingarrangement2sql, 

35 smokinglocation2sql, 

36 user_model_to_pb, 

37) 

38from couchers.servicers.communities import community_to_pb 

39from couchers.servicers.events import event_to_pb 

40from couchers.servicers.groups import group_to_pb 

41from couchers.servicers.pages import page_to_pb 

42from couchers.sql import couchers_select as select 

43from couchers.utils import ( 

44 create_coordinate, 

45 dt_from_millis, 

46 last_active_coarsen, 

47 millis_from_dt, 

48 now, 

49 to_aware_datetime, 

50) 

51from proto import search_pb2, search_pb2_grpc 

52 

53# searches are a bit expensive, we'd rather send back a bunch of results at once than lots of small pages 

54MAX_PAGINATION_LENGTH = 100 

55 

56REGCONFIG = "english" 

57TRI_SIMILARITY_THRESHOLD = 0.6 

58TRI_SIMILARITY_WEIGHT = 5 

59 

60 

61def _join_with_space(coalesces): 

62 # the objects in coalesces are not strings, so we can't do " ".join(coalesces). They're SQLAlchemy magic. 

63 if not coalesces: 

64 return "" 

65 out = coalesces[0] 

66 for coalesce in coalesces[1:]: 

67 out += " " + coalesce 

68 return out 

69 

70 

71def _build_tsv(A, B=None, C=None, D=None): 

72 """ 

73 Given lists for A, B, C, and D, builds a tsvector from them. 

74 """ 

75 B = B or [] 

76 C = C or [] 

77 D = D or [] 

78 tsv = func.setweight(func.to_tsvector(REGCONFIG, _join_with_space([func.coalesce(bit, "") for bit in A])), "A") 

79 if B: 

80 tsv = tsv.concat( 

81 func.setweight(func.to_tsvector(REGCONFIG, _join_with_space([func.coalesce(bit, "") for bit in B])), "B") 

82 ) 

83 if C: 

84 tsv = tsv.concat( 

85 func.setweight(func.to_tsvector(REGCONFIG, _join_with_space([func.coalesce(bit, "") for bit in C])), "C") 

86 ) 

87 if D: 

88 tsv = tsv.concat( 

89 func.setweight(func.to_tsvector(REGCONFIG, _join_with_space([func.coalesce(bit, "") for bit in D])), "D") 

90 ) 

91 return tsv 

92 

93 

94def _build_doc(A, B=None, C=None, D=None): 

95 """ 

96 Builds the raw document (without to_tsvector and weighting), used for extracting snippet 

97 """ 

98 B = B or [] 

99 C = C or [] 

100 D = D or [] 

101 doc = _join_with_space([func.coalesce(bit, "") for bit in A]) 

102 if B: 

103 doc += " " + _join_with_space([func.coalesce(bit, "") for bit in B]) 

104 if C: 

105 doc += " " + _join_with_space([func.coalesce(bit, "") for bit in C]) 

106 if D: 

107 doc += " " + _join_with_space([func.coalesce(bit, "") for bit in D]) 

108 return doc 

109 

110 

111def _similarity(statement, text): 

112 return func.word_similarity(func.unaccent(statement), func.unaccent(text)) 

113 

114 

115def _gen_search_elements(statement, title_only, next_rank, page_size, A, B=None, C=None, D=None): 

116 """ 

117 Given an sql statement and four sets of fields, (A, B, C, D), generates a bunch of postgres expressions for full text search. 

118 

119 The four sets are in decreasing order of "importance" for ranking. 

120 

121 A should be the "title", the others can be anything. 

122 

123 If title_only=True, we only perform a trigram search against A only 

124 """ 

125 B = B or [] 

126 C = C or [] 

127 D = D or [] 

128 if not title_only: 

129 # a postgres tsquery object that can be used to match against a tsvector 

130 tsq = func.websearch_to_tsquery(REGCONFIG, statement) 

131 

132 # the tsvector object that we want to search against with our tsquery 

133 tsv = _build_tsv(A, B, C, D) 

134 

135 # document to generate snippet from 

136 doc = _build_doc(A, B, C, D) 

137 

138 title = _build_doc(A) 

139 

140 # trigram based text similarity between title and sql statement string 

141 sim = _similarity(statement, title) 

142 

143 # ranking algo, weigh the similarity a lot, the text-based ranking less 

144 rank = (TRI_SIMILARITY_WEIGHT * sim + func.ts_rank_cd(tsv, tsq)).label("rank") 

145 

146 # the snippet with results highlighted 

147 snippet = func.ts_headline(REGCONFIG, doc, tsq, "StartSel=**,StopSel=**").label("snippet") 

148 

149 def execute_search_statement(session, orig_statement): 

150 """ 

151 Does the right search filtering, limiting, and ordering for the initial statement 

152 """ 

153 return session.execute( 

154 orig_statement.where(or_(tsv.op("@@")(tsq), sim > TRI_SIMILARITY_THRESHOLD)) 

155 .where(rank <= next_rank if next_rank is not None else True) 

156 .order_by(rank.desc()) 

157 .limit(page_size + 1) 

158 ).all() 

159 

160 else: 

161 title = _build_doc(A) 

162 

163 # trigram based text similarity between title and sql statement string 

164 sim = _similarity(statement, title) 

165 

166 # ranking algo, weigh the similarity a lot, the text-based ranking less 

167 rank = sim.label("rank") 

168 

169 # used only for headline 

170 tsq = func.websearch_to_tsquery(REGCONFIG, statement) 

171 doc = _build_doc(A, B, C, D) 

172 

173 # the snippet with results highlighted 

174 snippet = func.ts_headline(REGCONFIG, doc, tsq, "StartSel=**,StopSel=**").label("snippet") 

175 

176 def execute_search_statement(session, orig_statement): 

177 """ 

178 Does the right search filtering, limiting, and ordering for the initial statement 

179 """ 

180 return session.execute( 

181 orig_statement.where(sim > TRI_SIMILARITY_THRESHOLD) 

182 .where(rank <= next_rank if next_rank is not None else True) 

183 .order_by(rank.desc()) 

184 .limit(page_size + 1) 

185 ).all() 

186 

187 return rank, snippet, execute_search_statement 

188 

189 

190def _search_users(session, search_statement, title_only, next_rank, page_size, context, include_users): 

191 if not include_users: 

192 return [] 

193 rank, snippet, execute_search_statement = _gen_search_elements( 

194 search_statement, 

195 title_only, 

196 next_rank, 

197 page_size, 

198 [User.username, User.name], 

199 [User.city], 

200 [User.about_me], 

201 [User.things_i_like, User.about_place, User.additional_information], 

202 ) 

203 

204 users = execute_search_statement(session, select(User, rank, snippet).where_users_visible(context)) 

205 

206 return [ 

207 search_pb2.Result( 

208 rank=rank, 

209 user=user_model_to_pb(page, session, context), 

210 snippet=snippet, 

211 ) 

212 for page, rank, snippet in users 

213 ] 

214 

215 

216def _search_pages(session, search_statement, title_only, next_rank, page_size, context, include_places, include_guides): 

217 rank, snippet, execute_search_statement = _gen_search_elements( 

218 search_statement, 

219 title_only, 

220 next_rank, 

221 page_size, 

222 [PageVersion.title], 

223 [PageVersion.address], 

224 [], 

225 [PageVersion.content], 

226 ) 

227 if not include_places and not include_guides: 

228 return [] 

229 

230 latest_pages = ( 

231 select(func.max(PageVersion.id).label("id")) 

232 .join(Page, Page.id == PageVersion.page_id) 

233 .where( 

234 or_( 

235 (Page.type == PageType.place) if include_places else False, 

236 (Page.type == PageType.guide) if include_guides else False, 

237 ) 

238 ) 

239 .group_by(PageVersion.page_id) 

240 .subquery() 

241 ) 

242 

243 pages = execute_search_statement( 

244 session, 

245 select(Page, rank, snippet) 

246 .join(PageVersion, PageVersion.page_id == Page.id) 

247 .join(latest_pages, latest_pages.c.id == PageVersion.id), 

248 ) 

249 

250 return [ 

251 search_pb2.Result( 

252 rank=rank, 

253 place=page_to_pb(session, page, context) if page.type == PageType.place else None, 

254 guide=page_to_pb(session, page, context) if page.type == PageType.guide else None, 

255 snippet=snippet, 

256 ) 

257 for page, rank, snippet in pages 

258 ] 

259 

260 

261def _search_events(session, search_statement, title_only, next_rank, page_size, context): 

262 rank, snippet, execute_search_statement = _gen_search_elements( 

263 search_statement, 

264 title_only, 

265 next_rank, 

266 page_size, 

267 [Event.title], 

268 [EventOccurrence.address, EventOccurrence.link], 

269 [], 

270 [EventOccurrence.content], 

271 ) 

272 

273 occurrences = execute_search_statement( 

274 session, 

275 select(EventOccurrence, rank, snippet) 

276 .join(Event, Event.id == EventOccurrence.event_id) 

277 .where(EventOccurrence.end_time >= func.now()), 

278 ) 

279 

280 return [ 

281 search_pb2.Result( 

282 rank=rank, 

283 event=event_to_pb(session, occurrence, context), 

284 snippet=snippet, 

285 ) 

286 for occurrence, rank, snippet in occurrences 

287 ] 

288 

289 

290def _search_clusters( 

291 session, search_statement, title_only, next_rank, page_size, context, include_communities, include_groups 

292): 

293 if not include_communities and not include_groups: 

294 return [] 

295 

296 rank, snippet, execute_search_statement = _gen_search_elements( 

297 search_statement, 

298 title_only, 

299 next_rank, 

300 page_size, 

301 [Cluster.name], 

302 [PageVersion.address, PageVersion.title], 

303 [Cluster.description], 

304 [PageVersion.content], 

305 ) 

306 

307 latest_pages = ( 

308 select(func.max(PageVersion.id).label("id")) 

309 .join(Page, Page.id == PageVersion.page_id) 

310 .where(Page.type == PageType.main_page) 

311 .group_by(PageVersion.page_id) 

312 .subquery() 

313 ) 

314 

315 clusters = execute_search_statement( 

316 session, 

317 select(Cluster, rank, snippet) 

318 .join(Page, Page.owner_cluster_id == Cluster.id) 

319 .join(PageVersion, PageVersion.page_id == Page.id) 

320 .join(latest_pages, latest_pages.c.id == PageVersion.id) 

321 .where(Cluster.is_official_cluster if include_communities and not include_groups else True) 

322 .where(~Cluster.is_official_cluster if not include_communities and include_groups else True), 

323 ) 

324 

325 return [ 

326 search_pb2.Result( 

327 rank=rank, 

328 community=( 

329 community_to_pb(session, cluster.official_cluster_for_node, context) 

330 if cluster.is_official_cluster 

331 else None 

332 ), 

333 group=group_to_pb(session, cluster, context) if not cluster.is_official_cluster else None, 

334 snippet=snippet, 

335 ) 

336 for cluster, rank, snippet in clusters 

337 ] 

338 

339 

340class Search(search_pb2_grpc.SearchServicer): 

341 def Search(self, request, context, session): 

342 page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH) 

343 # this is not an ideal page token, some results have equal rank (unlikely) 

344 next_rank = float(request.page_token) if request.page_token else None 

345 

346 all_results = ( 

347 _search_users( 

348 session, 

349 request.query, 

350 request.title_only, 

351 next_rank, 

352 page_size, 

353 context, 

354 request.include_users, 

355 ) 

356 + _search_pages( 

357 session, 

358 request.query, 

359 request.title_only, 

360 next_rank, 

361 page_size, 

362 context, 

363 request.include_places, 

364 request.include_guides, 

365 ) 

366 + _search_events( 

367 session, 

368 request.query, 

369 request.title_only, 

370 next_rank, 

371 page_size, 

372 context, 

373 ) 

374 + _search_clusters( 

375 session, 

376 request.query, 

377 request.title_only, 

378 next_rank, 

379 page_size, 

380 context, 

381 request.include_communities, 

382 request.include_groups, 

383 ) 

384 ) 

385 all_results.sort(key=lambda result: result.rank, reverse=True) 

386 return search_pb2.SearchRes( 

387 results=all_results[:page_size], 

388 next_page_token=str(all_results[page_size].rank) if len(all_results) > page_size else None, 

389 ) 

390 

391 def UserSearch(self, request, context, session): 

392 user = session.execute(select(User).where(User.id == context.user_id)).scalar_one() 

393 

394 statement = select(User).where_users_visible(context) 

395 if request.HasField("query"): 

396 if request.query_name_only: 

397 statement = statement.where( 

398 or_(User.name.ilike(f"%{request.query.value}%"), User.username.ilike(f"%{request.query.value}%")) 

399 ) 

400 else: 

401 statement = statement.where( 

402 or_( 

403 User.name.ilike(f"%{request.query.value}%"), 

404 User.username.ilike(f"%{request.query.value}%"), 

405 User.city.ilike(f"%{request.query.value}%"), 

406 User.hometown.ilike(f"%{request.query.value}%"), 

407 User.about_me.ilike(f"%{request.query.value}%"), 

408 User.things_i_like.ilike(f"%{request.query.value}%"), 

409 User.about_place.ilike(f"%{request.query.value}%"), 

410 User.additional_information.ilike(f"%{request.query.value}%"), 

411 ) 

412 ) 

413 

414 if request.HasField("last_active"): 

415 raw_dt = to_aware_datetime(request.last_active) 

416 statement = statement.where(User.last_active >= last_active_coarsen(raw_dt)) 

417 

418 if len(request.gender) > 0: 

419 if not has_strong_verification(session, user): 

420 context.abort(grpc.StatusCode.FAILED_PRECONDITION, errors.NEED_STRONG_VERIFICATION) 

421 elif user.gender not in request.gender: 

422 context.abort(grpc.StatusCode.FAILED_PRECONDITION, errors.MUST_INCLUDE_OWN_GENDER) 

423 else: 

424 statement = statement.where(User.gender.in_(request.gender)) 

425 

426 if len(request.hosting_status_filter) > 0: 

427 statement = statement.where( 

428 User.hosting_status.in_([hostingstatus2sql[status] for status in request.hosting_status_filter]) 

429 ) 

430 if len(request.meetup_status_filter) > 0: 

431 statement = statement.where( 

432 User.meetup_status.in_([meetupstatus2sql[status] for status in request.meetup_status_filter]) 

433 ) 

434 if len(request.smoking_location_filter) > 0: 

435 statement = statement.where( 

436 User.smoking_allowed.in_([smokinglocation2sql[loc] for loc in request.smoking_location_filter]) 

437 ) 

438 if len(request.sleeping_arrangement_filter) > 0: 

439 statement = statement.where( 

440 User.sleeping_arrangement.in_( 

441 [sleepingarrangement2sql[arr] for arr in request.sleeping_arrangement_filter] 

442 ) 

443 ) 

444 if len(request.parking_details_filter) > 0: 

445 statement = statement.where( 

446 User.parking_details.in_([parkingdetails2sql[det] for det in request.parking_details_filter]) 

447 ) 

448 # limits/default could be handled on the front end as well 

449 min_age = request.age_min.value if request.HasField("age_min") else 18 

450 max_age = request.age_max.value if request.HasField("age_max") else 200 

451 

452 statement = statement.where((User.age >= min_age) & (User.age <= max_age)) 

453 

454 # return results with by language code as only input 

455 # fluency in conversational or fluent 

456 

457 if len(request.language_ability_filter) > 0: 

458 language_options = [] 

459 for ability_filter in request.language_ability_filter: 

460 fluency_sql_value = fluency2sql.get(ability_filter.fluency) 

461 

462 if fluency_sql_value is None: 

463 continue 

464 language_options.append( 

465 and_( 

466 (LanguageAbility.language_code == ability_filter.code), 

467 (LanguageAbility.fluency >= (fluency_sql_value)), 

468 ) 

469 ) 

470 statement = statement.join(LanguageAbility, LanguageAbility.user_id == User.id) 

471 statement = statement.where(or_(*language_options)) 

472 

473 if request.HasField("profile_completed"): 

474 statement = statement.where(User.has_completed_profile == request.profile_completed.value) 

475 if request.HasField("guests"): 

476 statement = statement.where(User.max_guests >= request.guests.value) 

477 if request.HasField("last_minute"): 

478 statement = statement.where(User.last_minute == request.last_minute.value) 

479 if request.HasField("has_pets"): 

480 statement = statement.where(User.has_pets == request.has_pets.value) 

481 if request.HasField("accepts_pets"): 

482 statement = statement.where(User.accepts_pets == request.accepts_pets.value) 

483 if request.HasField("has_kids"): 

484 statement = statement.where(User.has_kids == request.has_kids.value) 

485 if request.HasField("accepts_kids"): 

486 statement = statement.where(User.accepts_kids == request.accepts_kids.value) 

487 if request.HasField("has_housemates"): 

488 statement = statement.where(User.has_housemates == request.has_housemates.value) 

489 if request.HasField("wheelchair_accessible"): 

490 statement = statement.where(User.wheelchair_accessible == request.wheelchair_accessible.value) 

491 if request.HasField("smokes_at_home"): 

492 statement = statement.where(User.smokes_at_home == request.smokes_at_home.value) 

493 if request.HasField("drinking_allowed"): 

494 statement = statement.where(User.drinking_allowed == request.drinking_allowed.value) 

495 if request.HasField("drinks_at_home"): 

496 statement = statement.where(User.drinks_at_home == request.drinks_at_home.value) 

497 if request.HasField("parking"): 

498 statement = statement.where(User.parking == request.parking.value) 

499 if request.HasField("camping_ok"): 

500 statement = statement.where(User.camping_ok == request.camping_ok.value) 

501 

502 if request.HasField("search_in_area"): 

503 # EPSG4326 measures distance in decimal degress 

504 # we want to check whether two circles overlap, so check if the distance between their centers is less 

505 # than the sum of their radii, divided by 111111 m ~= 1 degree (at the equator) 

506 search_point = create_coordinate(request.search_in_area.lat, request.search_in_area.lng) 

507 statement = statement.where( 

508 func.ST_DWithin( 

509 # old: 

510 # User.geom, search_point, (User.geom_radius + request.search_in_area.radius) / 111111 

511 # this is an optimization that speeds up the db queries since it doesn't need to look up the user's geom radius 

512 User.geom, 

513 search_point, 

514 (1000 + request.search_in_area.radius) / 111111, 

515 ) 

516 ) 

517 if request.HasField("search_in_rectangle"): 

518 statement = statement.where( 

519 func.ST_Within( 

520 User.geom, 

521 func.ST_MakeEnvelope( 

522 request.search_in_rectangle.lng_min, 

523 request.search_in_rectangle.lat_min, 

524 request.search_in_rectangle.lng_max, 

525 request.search_in_rectangle.lat_max, 

526 4326, 

527 ), 

528 ) 

529 ) 

530 if request.HasField("search_in_community_id"): 

531 # could do a join here as well, but this is just simpler 

532 node = session.execute(select(Node).where(Node.id == request.search_in_community_id)).scalar_one_or_none() 

533 if not node: 

534 context.abort(grpc.StatusCode.NOT_FOUND, errors.COMMUNITY_NOT_FOUND) 

535 statement = statement.where(func.ST_Contains(node.geom, User.geom)) 

536 

537 if request.only_with_references: 

538 statement = statement.join(Reference, Reference.to_user_id == User.id) 

539 

540 # TODO: 

541 # google.protobuf.StringValue language = 11; 

542 # bool friends_only = 13; 

543 # google.protobuf.UInt32Value age_min = 14; 

544 # google.protobuf.UInt32Value age_max = 15; 

545 

546 page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH) 

547 next_recommendation_score = float(decrypt_page_token(request.page_token)) if request.page_token else 1e10 

548 

549 statement = ( 

550 statement.where(User.recommendation_score <= next_recommendation_score) 

551 .order_by(User.recommendation_score.desc()) 

552 .limit(page_size + 1) 

553 ) 

554 users = session.execute(statement).scalars().all() 

555 

556 return search_pb2.UserSearchRes( 

557 results=[ 

558 search_pb2.Result( 

559 rank=1, 

560 user=user_model_to_pb(user, session, context), 

561 ) 

562 for user in users[:page_size] 

563 ], 

564 next_page_token=( 

565 encrypt_page_token(str(users[-1].recommendation_score)) if len(users) > page_size else None 

566 ), 

567 ) 

568 

569 def EventSearch(self, request, context, session): 

570 statement = ( 

571 select(EventOccurrence).join(Event, Event.id == EventOccurrence.event_id).where(~EventOccurrence.is_deleted) 

572 ) 

573 

574 if request.HasField("query"): 

575 if request.query_title_only: 

576 statement = statement.where(Event.title.ilike(f"%{request.query.value}%")) 

577 else: 

578 statement = statement.where( 

579 or_( 

580 Event.title.ilike(f"%{request.query.value}%"), 

581 EventOccurrence.content.ilike(f"%{request.query.value}%"), 

582 EventOccurrence.address.ilike(f"%{request.query.value}%"), 

583 ) 

584 ) 

585 

586 if request.only_online: 

587 statement = statement.where(EventOccurrence.geom == None) 

588 elif request.only_offline: 

589 statement = statement.where(EventOccurrence.geom != None) 

590 

591 if request.subscribed or request.attending or request.organizing or request.my_communities: 

592 where_ = [] 

593 

594 if request.subscribed: 

595 statement = statement.outerjoin( 

596 EventSubscription, 

597 and_(EventSubscription.event_id == Event.id, EventSubscription.user_id == context.user_id), 

598 ) 

599 where_.append(EventSubscription.user_id != None) 

600 if request.organizing: 

601 statement = statement.outerjoin( 

602 EventOrganizer, 

603 and_(EventOrganizer.event_id == Event.id, EventOrganizer.user_id == context.user_id), 

604 ) 

605 where_.append(EventOrganizer.user_id != None) 

606 if request.attending: 

607 statement = statement.outerjoin( 

608 EventOccurrenceAttendee, 

609 and_( 

610 EventOccurrenceAttendee.occurrence_id == EventOccurrence.id, 

611 EventOccurrenceAttendee.user_id == context.user_id, 

612 ), 

613 ) 

614 where_.append(EventOccurrenceAttendee.user_id != None) 

615 if request.my_communities: 

616 my_communities = ( 

617 session.execute( 

618 select(Node.id) 

619 .join(Cluster, Cluster.parent_node_id == Node.id) 

620 .join(ClusterSubscription, ClusterSubscription.cluster_id == Cluster.id) 

621 .where(ClusterSubscription.user_id == context.user_id) 

622 .where(Cluster.is_official_cluster) 

623 .order_by(Node.id) 

624 .limit(100000) 

625 ) 

626 .scalars() 

627 .all() 

628 ) 

629 where_.append(Event.parent_node_id.in_(my_communities)) 

630 

631 statement = statement.where(or_(*where_)) 

632 

633 if not request.include_cancelled: 

634 statement = statement.where(~EventOccurrence.is_cancelled) 

635 

636 if request.HasField("search_in_area"): 

637 # EPSG4326 measures distance in decimal degress 

638 # we want to check whether two circles overlap, so check if the distance between their centers is less 

639 # than the sum of their radii, divided by 111111 m ~= 1 degree (at the equator) 

640 search_point = create_coordinate(request.search_in_area.lat, request.search_in_area.lng) 

641 statement = statement.where( 

642 func.ST_DWithin( 

643 # old: 

644 # User.geom, search_point, (User.geom_radius + request.search_in_area.radius) / 111111 

645 # this is an optimization that speeds up the db queries since it doesn't need to look up the user's geom radius 

646 EventOccurrence.geom, 

647 search_point, 

648 (1000 + request.search_in_area.radius) / 111111, 

649 ) 

650 ) 

651 if request.HasField("search_in_rectangle"): 

652 statement = statement.where( 

653 func.ST_Within( 

654 EventOccurrence.geom, 

655 func.ST_MakeEnvelope( 

656 request.search_in_rectangle.lng_min, 

657 request.search_in_rectangle.lat_min, 

658 request.search_in_rectangle.lng_max, 

659 request.search_in_rectangle.lat_max, 

660 4326, 

661 ), 

662 ) 

663 ) 

664 if request.HasField("search_in_community_id"): 

665 # could do a join here as well, but this is just simpler 

666 node = session.execute(select(Node).where(Node.id == request.search_in_community_id)).scalar_one_or_none() 

667 if not node: 

668 context.abort(grpc.StatusCode.NOT_FOUND, errors.COMMUNITY_NOT_FOUND) 

669 statement = statement.where(func.ST_Contains(node.geom, EventOccurrence.geom)) 

670 

671 if request.HasField("after"): 

672 statement = statement.where(EventOccurrence.start_time > to_aware_datetime(request.after)) 

673 if request.HasField("before"): 

674 statement = statement.where(EventOccurrence.end_time < to_aware_datetime(request.before)) 

675 

676 page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH) 

677 # the page token is a unix timestamp of where we left off 

678 page_token = ( 

679 dt_from_millis(int(request.page_token)) if request.page_token and not request.page_number else now() 

680 ) 

681 page_number = request.page_number or 1 

682 # Calculate the offset for pagination 

683 offset = (page_number - 1) * page_size 

684 

685 if not request.past: 

686 statement = statement.where(EventOccurrence.end_time > page_token - timedelta(seconds=1)).order_by( 

687 EventOccurrence.start_time.asc() 

688 ) 

689 else: 

690 statement = statement.where(EventOccurrence.end_time < page_token + timedelta(seconds=1)).order_by( 

691 EventOccurrence.start_time.desc() 

692 ) 

693 

694 total_items = session.execute(select(func.count()).select_from(statement.subquery())).scalar() 

695 # Apply pagination by page number 

696 statement = statement.offset(offset).limit(page_size) if request.page_number else statement.limit(page_size + 1) 

697 occurrences = session.execute(statement).scalars().all() 

698 

699 return search_pb2.EventSearchRes( 

700 events=[event_to_pb(session, occurrence, context) for occurrence in occurrences[:page_size]], 

701 next_page_token=(str(millis_from_dt(occurrences[-1].end_time)) if len(occurrences) > page_size else None), 

702 total_items=total_items, 

703 )