Coverage for src/couchers/servicers/search.py: 83%

229 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-10-15 13:03 +0000

1""" 

2See //docs/search.md for overview. 

3""" 

4 

5from datetime import timedelta 

6 

7import grpc 

8from sqlalchemy.sql import and_, func, or_ 

9 

10from couchers import errors 

11from couchers.crypto import decrypt_page_token, encrypt_page_token 

12from couchers.models import ( 

13 Cluster, 

14 ClusterSubscription, 

15 Event, 

16 EventOccurrence, 

17 EventOccurrenceAttendee, 

18 EventOrganizer, 

19 EventSubscription, 

20 Node, 

21 Page, 

22 PageType, 

23 PageVersion, 

24 Reference, 

25 User, 

26) 

27from couchers.servicers.account import has_strong_verification 

28from couchers.servicers.api import ( 

29 hostingstatus2sql, 

30 meetupstatus2sql, 

31 parkingdetails2sql, 

32 sleepingarrangement2sql, 

33 smokinglocation2sql, 

34 user_model_to_pb, 

35) 

36from couchers.servicers.communities import community_to_pb 

37from couchers.servicers.events import event_to_pb 

38from couchers.servicers.groups import group_to_pb 

39from couchers.servicers.pages import page_to_pb 

40from couchers.sql import couchers_select as select 

41from couchers.utils import ( 

42 create_coordinate, 

43 dt_from_millis, 

44 last_active_coarsen, 

45 millis_from_dt, 

46 now, 

47 to_aware_datetime, 

48) 

49from proto import search_pb2, search_pb2_grpc 

50 

# Upper bound on the page size a client may ask for. Searches are relatively expensive, so we would rather return
# one large batch of results than make clients fetch many small pages.
MAX_PAGINATION_LENGTH = 100

# PostgreSQL text-search configuration passed to to_tsvector / websearch_to_tsquery / ts_headline
REGCONFIG = "english"
# a row matches on its title when trigram word similarity with the query exceeds this value
TRI_SIMILARITY_THRESHOLD = 0.6
# multiplier on the trigram similarity when it is combined with ts_rank_cd into the final rank
TRI_SIMILARITY_WEIGHT = 5

57 

58 

59def _join_with_space(coalesces): 

60 # the objects in coalesces are not strings, so we can't do " ".join(coalesces). They're SQLAlchemy magic. 

61 if not coalesces: 

62 return "" 

63 out = coalesces[0] 

64 for coalesce in coalesces[1:]: 

65 out += " " + coalesce 

66 return out 

67 

68 

def _build_tsv(A, B=None, C=None, D=None):
    """
    Build a weighted PostgreSQL tsvector expression from four groups of columns.

    Each group is coalesced (NULL -> ""), space-joined, passed through to_tsvector(REGCONFIG, ...) and tagged with
    its weight letter via setweight. The resulting vectors are concatenated in A, B, C, D order. Group A is always
    included; B, C and D are skipped when empty or None.
    """

    def weighted_vector(columns, weight):
        text = _join_with_space([func.coalesce(bit, "") for bit in columns])
        return func.setweight(func.to_tsvector(REGCONFIG, text), weight)

    tsv = weighted_vector(A, "A")
    for columns, weight in ((B, "B"), (C, "C"), (D, "D")):
        if columns:
            tsv = tsv.concat(weighted_vector(columns, weight))
    return tsv

90 

91 

def _build_doc(A, B=None, C=None, D=None):
    """
    Build the raw searchable text (no to_tsvector, no weights) that snippets are extracted from.

    Every column is coalesced to "" when NULL, columns within a group are space-joined, and the groups are appended in
    A, B, C, D order separated by a space. Empty or None groups among B, C and D are left out.
    """

    def joined(columns):
        return _join_with_space([func.coalesce(bit, "") for bit in columns])

    doc = joined(A)
    for extra in (B, C, D):
        if extra:
            doc = doc + (" " + joined(extra))
    return doc

107 

108 

def _similarity(statement, text):
    """Accent-insensitive trigram word similarity (pg_trgm word_similarity) between the query string and ``text``."""
    query_unaccented = func.unaccent(statement)
    text_unaccented = func.unaccent(text)
    return func.word_similarity(query_unaccented, text_unaccented)

111 

112 

def _gen_search_elements(statement, title_only, next_rank, page_size, A, B=None, C=None, D=None):
    """
    Build the PostgreSQL expressions needed to run a ranked search over four groups of fields.

    A, B, C and D are lists of columns in decreasing order of importance; A should be the "title".

    Matching and ranking:
      - title_only=False: a row matches if the full-text query (websearch_to_tsquery) matches the weighted tsvector of
        all groups, OR the trigram similarity between the query and the title exceeds TRI_SIMILARITY_THRESHOLD.
        rank = TRI_SIMILARITY_WEIGHT * similarity + ts_rank_cd(tsvector, tsquery).
      - title_only=True: a row matches only on title trigram similarity, and rank is that similarity alone.

    In both modes the snippet is a ts_headline over all groups with matches wrapped in "**".

    Returns (rank, snippet, execute_search_statement):
      - rank: expression labelled "rank"
      - snippet: expression labelled "snippet"
      - execute_search_statement(session, orig_statement): adds the match filter to orig_statement, keeps rows with
        rank <= next_rank when next_rank is not None, orders by rank descending, limits to page_size + 1 rows and
        returns the fetched rows.
    """
    B = B or []
    C = C or []
    D = D or []

    # parts shared by both modes: the full-text query (used at least for highlighting), the raw document for the
    # snippet, and the trigram similarity between the query string and the title
    tsq = func.websearch_to_tsquery(REGCONFIG, statement)
    doc = _build_doc(A, B, C, D)
    sim = _similarity(statement, _build_doc(A))
    snippet = func.ts_headline(REGCONFIG, doc, tsq, "StartSel=**,StopSel=**").label("snippet")

    if title_only:
        # title-only search: trigram similarity against the title is both the filter and the rank
        rank = sim.label("rank")
        match_condition = sim > TRI_SIMILARITY_THRESHOLD
    else:
        # full search: weigh the title similarity heavily and add the (smaller) full-text rank on top
        tsv = _build_tsv(A, B, C, D)
        rank = (TRI_SIMILARITY_WEIGHT * sim + func.ts_rank_cd(tsv, tsq)).label("rank")
        match_condition = or_(tsv.op("@@")(tsq), sim > TRI_SIMILARITY_THRESHOLD)

    def execute_search_statement(session, orig_statement):
        """
        Apply the match filter, rank-based pagination, ordering and limit to orig_statement and execute it.
        """
        return session.execute(
            orig_statement.where(match_condition)
            .where(rank <= next_rank if next_rank is not None else True)
            .order_by(rank.desc())
            .limit(page_size + 1)
        ).all()

    return rank, snippet, execute_search_statement

186 

187 

188def _search_users(session, search_statement, title_only, next_rank, page_size, context, include_users): 

189 if not include_users: 

190 return [] 

191 rank, snippet, execute_search_statement = _gen_search_elements( 

192 search_statement, 

193 title_only, 

194 next_rank, 

195 page_size, 

196 [User.username, User.name], 

197 [User.city], 

198 [User.about_me], 

199 [User.my_travels, User.things_i_like, User.about_place, User.additional_information], 

200 ) 

201 

202 users = execute_search_statement(session, select(User, rank, snippet).where_users_visible(context)) 

203 

204 return [ 

205 search_pb2.Result( 

206 rank=rank, 

207 user=user_model_to_pb(page, session, context), 

208 snippet=snippet, 

209 ) 

210 for page, rank, snippet in users 

211 ] 

212 

213 

214def _search_pages(session, search_statement, title_only, next_rank, page_size, context, include_places, include_guides): 

215 rank, snippet, execute_search_statement = _gen_search_elements( 

216 search_statement, 

217 title_only, 

218 next_rank, 

219 page_size, 

220 [PageVersion.title], 

221 [PageVersion.address], 

222 [], 

223 [PageVersion.content], 

224 ) 

225 if not include_places and not include_guides: 

226 return [] 

227 

228 latest_pages = ( 

229 select(func.max(PageVersion.id).label("id")) 

230 .join(Page, Page.id == PageVersion.page_id) 

231 .where( 

232 or_( 

233 (Page.type == PageType.place) if include_places else False, 

234 (Page.type == PageType.guide) if include_guides else False, 

235 ) 

236 ) 

237 .group_by(PageVersion.page_id) 

238 .subquery() 

239 ) 

240 

241 pages = execute_search_statement( 

242 session, 

243 select(Page, rank, snippet) 

244 .join(PageVersion, PageVersion.page_id == Page.id) 

245 .join(latest_pages, latest_pages.c.id == PageVersion.id), 

246 ) 

247 

248 return [ 

249 search_pb2.Result( 

250 rank=rank, 

251 place=page_to_pb(session, page, context) if page.type == PageType.place else None, 

252 guide=page_to_pb(session, page, context) if page.type == PageType.guide else None, 

253 snippet=snippet, 

254 ) 

255 for page, rank, snippet in pages 

256 ] 

257 

258 

def _search_events(session, search_statement, title_only, next_rank, page_size, context):
    """
    Ranked search over event occurrences that have not ended yet (end_time >= now()).

    The event title is the title (A) group, the occurrence address and link are B, and the occurrence content is D.
    Deleted occurrences are excluded, consistent with EventSearch.

    Returns a list of search_pb2.Result with ``event`` set.
    """
    rank, snippet, execute_search_statement = _gen_search_elements(
        search_statement,
        title_only,
        next_rank,
        page_size,
        [Event.title],
        [EventOccurrence.address, EventOccurrence.link],
        [],
        [EventOccurrence.content],
    )

    occurrences = execute_search_statement(
        session,
        select(EventOccurrence, rank, snippet)
        .join(Event, Event.id == EventOccurrence.event_id)
        # deleted occurrences must never surface in search results
        .where(~EventOccurrence.is_deleted)
        .where(EventOccurrence.end_time >= func.now()),
    )

    return [
        search_pb2.Result(
            rank=rank,
            event=event_to_pb(session, occurrence, context),
            snippet=snippet,
        )
        for occurrence, rank, snippet in occurrences
    ]

286 

287 

288def _search_clusters( 

289 session, search_statement, title_only, next_rank, page_size, context, include_communities, include_groups 

290): 

291 if not include_communities and not include_groups: 

292 return [] 

293 

294 rank, snippet, execute_search_statement = _gen_search_elements( 

295 search_statement, 

296 title_only, 

297 next_rank, 

298 page_size, 

299 [Cluster.name], 

300 [PageVersion.address, PageVersion.title], 

301 [Cluster.description], 

302 [PageVersion.content], 

303 ) 

304 

305 latest_pages = ( 

306 select(func.max(PageVersion.id).label("id")) 

307 .join(Page, Page.id == PageVersion.page_id) 

308 .where(Page.type == PageType.main_page) 

309 .group_by(PageVersion.page_id) 

310 .subquery() 

311 ) 

312 

313 clusters = execute_search_statement( 

314 session, 

315 select(Cluster, rank, snippet) 

316 .join(Page, Page.owner_cluster_id == Cluster.id) 

317 .join(PageVersion, PageVersion.page_id == Page.id) 

318 .join(latest_pages, latest_pages.c.id == PageVersion.id) 

319 .where(Cluster.is_official_cluster if include_communities and not include_groups else True) 

320 .where(~Cluster.is_official_cluster if not include_communities and include_groups else True), 

321 ) 

322 

323 return [ 

324 search_pb2.Result( 

325 rank=rank, 

326 community=( 

327 community_to_pb(session, cluster.official_cluster_for_node, context) 

328 if cluster.is_official_cluster 

329 else None 

330 ), 

331 group=group_to_pb(session, cluster, context) if not cluster.is_official_cluster else None, 

332 snippet=snippet, 

333 ) 

334 for cluster, rank, snippet in clusters 

335 ] 

336 

337 

class Search(search_pb2_grpc.SearchServicer):
    def Search(self, request, context, session):
        """
        Combined ranked search over users, pages (places/guides), events that have not ended yet, and clusters
        (communities/groups).

        Each category returns up to page_size + 1 matches with rank <= the rank in the page token. These are merged,
        sorted by rank (highest first) and truncated to page_size. The next page token is the rank of the first
        result that did not fit.
        """
        page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
        # this is not an ideal page token, some results have equal rank (unlikely)
        next_rank = float(request.page_token) if request.page_token else None

        all_results = (
            _search_users(
                session,
                request.query,
                request.title_only,
                next_rank,
                page_size,
                context,
                request.include_users,
            )
            + _search_pages(
                session,
                request.query,
                request.title_only,
                next_rank,
                page_size,
                context,
                request.include_places,
                request.include_guides,
            )
            + _search_events(
                session,
                request.query,
                request.title_only,
                next_rank,
                page_size,
                context,
            )
            + _search_clusters(
                session,
                request.query,
                request.title_only,
                next_rank,
                page_size,
                context,
                request.include_communities,
                request.include_groups,
            )
        )
        all_results.sort(key=lambda result: result.rank, reverse=True)
        return search_pb2.SearchRes(
            results=all_results[:page_size],
            next_page_token=str(all_results[page_size].rank) if len(all_results) > page_size else None,
        )

    def UserSearch(self, request, context, session):
        """
        Filtered search over users visible to the caller, ordered by recommendation score (highest first).

        The page token is an encrypted recommendation score. Each page contains users with a score <= that value.

        Aborts with FAILED_PRECONDITION when a gender filter is requested by a user without strong verification,
        or when the filter does not include the requester's own gender. Aborts with NOT_FOUND when
        search_in_community_id does not exist.
        """
        user = session.execute(select(User).where(User.id == context.user_id)).scalar_one()

        statement = select(User).where_users_visible(context)
        if request.HasField("query"):
            pattern = f"%{request.query.value}%"
            if request.query_name_only:
                statement = statement.where(or_(User.name.ilike(pattern), User.username.ilike(pattern)))
            else:
                statement = statement.where(
                    or_(
                        User.name.ilike(pattern),
                        User.username.ilike(pattern),
                        User.city.ilike(pattern),
                        User.hometown.ilike(pattern),
                        User.about_me.ilike(pattern),
                        User.my_travels.ilike(pattern),
                        User.things_i_like.ilike(pattern),
                        User.about_place.ilike(pattern),
                        User.additional_information.ilike(pattern),
                    )
                )

        if request.HasField("last_active"):
            raw_dt = to_aware_datetime(request.last_active)
            statement = statement.where(User.last_active >= last_active_coarsen(raw_dt))

        if len(request.gender) > 0:
            if not has_strong_verification(session, user):
                context.abort(grpc.StatusCode.FAILED_PRECONDITION, errors.NEED_STRONG_VERIFICATION)
            elif user.gender not in request.gender:
                context.abort(grpc.StatusCode.FAILED_PRECONDITION, errors.MUST_INCLUDE_OWN_GENDER)
            else:
                statement = statement.where(User.gender.in_(request.gender))

        if len(request.hosting_status_filter) > 0:
            statement = statement.where(
                User.hosting_status.in_([hostingstatus2sql[status] for status in request.hosting_status_filter])
            )
        if len(request.meetup_status_filter) > 0:
            statement = statement.where(
                User.meetup_status.in_([meetupstatus2sql[status] for status in request.meetup_status_filter])
            )
        if len(request.smoking_location_filter) > 0:
            statement = statement.where(
                User.smoking_allowed.in_([smokinglocation2sql[loc] for loc in request.smoking_location_filter])
            )
        if len(request.sleeping_arrangement_filter) > 0:
            statement = statement.where(
                User.sleeping_arrangement.in_(
                    [sleepingarrangement2sql[arr] for arr in request.sleeping_arrangement_filter]
                )
            )
        if len(request.parking_details_filter) > 0:
            statement = statement.where(
                User.parking_details.in_([parkingdetails2sql[det] for det in request.parking_details_filter])
            )
        if request.HasField("profile_completed"):
            statement = statement.where(User.has_completed_profile == request.profile_completed.value)
        if request.HasField("guests"):
            statement = statement.where(User.max_guests >= request.guests.value)
        if request.HasField("last_minute"):
            statement = statement.where(User.last_minute == request.last_minute.value)
        if request.HasField("has_pets"):
            statement = statement.where(User.has_pets == request.has_pets.value)
        if request.HasField("accepts_pets"):
            statement = statement.where(User.accepts_pets == request.accepts_pets.value)
        if request.HasField("has_kids"):
            statement = statement.where(User.has_kids == request.has_kids.value)
        if request.HasField("accepts_kids"):
            statement = statement.where(User.accepts_kids == request.accepts_kids.value)
        if request.HasField("has_housemates"):
            statement = statement.where(User.has_housemates == request.has_housemates.value)
        if request.HasField("wheelchair_accessible"):
            statement = statement.where(User.wheelchair_accessible == request.wheelchair_accessible.value)
        if request.HasField("smokes_at_home"):
            statement = statement.where(User.smokes_at_home == request.smokes_at_home.value)
        if request.HasField("drinking_allowed"):
            statement = statement.where(User.drinking_allowed == request.drinking_allowed.value)
        if request.HasField("drinks_at_home"):
            statement = statement.where(User.drinks_at_home == request.drinks_at_home.value)
        if request.HasField("parking"):
            statement = statement.where(User.parking == request.parking.value)
        if request.HasField("camping_ok"):
            statement = statement.where(User.camping_ok == request.camping_ok.value)

        if request.HasField("search_in_area"):
            # EPSG4326 measures distance in decimal degress
            # we want to check whether two circles overlap, so check if the distance between their centers is less
            # than the sum of their radii, divided by 111111 m ~= 1 degree (at the equator)
            search_point = create_coordinate(request.search_in_area.lat, request.search_in_area.lng)
            statement = statement.where(
                func.ST_DWithin(
                    # old:
                    # User.geom, search_point, (User.geom_radius + request.search_in_area.radius) / 111111
                    # this is an optimization that speeds up the db queries since it doesn't need to look up the user's geom radius
                    User.geom,
                    search_point,
                    (1000 + request.search_in_area.radius) / 111111,
                )
            )
        if request.HasField("search_in_rectangle"):
            statement = statement.where(
                func.ST_Within(
                    User.geom,
                    func.ST_MakeEnvelope(
                        request.search_in_rectangle.lng_min,
                        request.search_in_rectangle.lat_min,
                        request.search_in_rectangle.lng_max,
                        request.search_in_rectangle.lat_max,
                        4326,
                    ),
                )
            )
        if request.HasField("search_in_community_id"):
            # could do a join here as well, but this is just simpler
            node = session.execute(select(Node).where(Node.id == request.search_in_community_id)).scalar_one_or_none()
            if not node:
                context.abort(grpc.StatusCode.NOT_FOUND, errors.COMMUNITY_NOT_FOUND)
            statement = statement.where(func.ST_Contains(node.geom, User.geom))

        if request.only_with_references:
            # Filter with a subquery rather than joining Reference: a join produces one row per reference, so a user
            # with several references would be returned repeatedly and take up several slots of the page.
            statement = statement.where(User.id.in_(select(Reference.to_user_id)))

        # TODO:
        # google.protobuf.StringValue language = 11;
        # bool friends_only = 13;
        # google.protobuf.UInt32Value age_min = 14;
        # google.protobuf.UInt32Value age_max = 15;

        page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
        next_recommendation_score = float(decrypt_page_token(request.page_token)) if request.page_token else 1e10

        statement = (
            statement.where(User.recommendation_score <= next_recommendation_score)
            .order_by(User.recommendation_score.desc())
            .limit(page_size + 1)
        )
        users = session.execute(statement).scalars().all()

        return search_pb2.UserSearchRes(
            results=[
                search_pb2.Result(
                    rank=1,
                    user=user_model_to_pb(user, session, context),
                )
                for user in users[:page_size]
            ],
            next_page_token=(
                encrypt_page_token(str(users[-1].recommendation_score)) if len(users) > page_size else None
            ),
        )

    def EventSearch(self, request, context, session):
        """
        Filtered search over non-deleted event occurrences.

        Supports a text query, online/offline, subscription/attendance/organizer/community filters, location filters
        and time bounds. Upcoming/ongoing occurrences are returned in start_time ascending order, or past ones in
        start_time descending order when request.past is set.

        Pagination is either page-number based (offset/limit) or token based. The token is the millisecond timestamp
        of an occurrence's end_time.

        Aborts with NOT_FOUND when search_in_community_id does not exist.
        """
        statement = (
            select(EventOccurrence).join(Event, Event.id == EventOccurrence.event_id).where(~EventOccurrence.is_deleted)
        )

        if request.HasField("query"):
            pattern = f"%{request.query.value}%"
            if request.query_title_only:
                statement = statement.where(Event.title.ilike(pattern))
            else:
                statement = statement.where(
                    or_(
                        Event.title.ilike(pattern),
                        EventOccurrence.content.ilike(pattern),
                        EventOccurrence.address.ilike(pattern),
                    )
                )

        if request.only_online:
            statement = statement.where(EventOccurrence.geom == None)
        elif request.only_offline:
            statement = statement.where(EventOccurrence.geom != None)

        if request.subscribed or request.attending or request.organizing or request.my_communities:
            # these filters are OR'ed together: an occurrence matches if any of the requested relations holds
            where_ = []

            if request.subscribed:
                statement = statement.outerjoin(
                    EventSubscription,
                    and_(EventSubscription.event_id == Event.id, EventSubscription.user_id == context.user_id),
                )
                where_.append(EventSubscription.user_id != None)
            if request.organizing:
                statement = statement.outerjoin(
                    EventOrganizer,
                    and_(EventOrganizer.event_id == Event.id, EventOrganizer.user_id == context.user_id),
                )
                where_.append(EventOrganizer.user_id != None)
            if request.attending:
                statement = statement.outerjoin(
                    EventOccurrenceAttendee,
                    and_(
                        EventOccurrenceAttendee.occurrence_id == EventOccurrence.id,
                        EventOccurrenceAttendee.user_id == context.user_id,
                    ),
                )
                where_.append(EventOccurrenceAttendee.user_id != None)
            if request.my_communities:
                my_communities = (
                    session.execute(
                        select(Node.id)
                        .join(Cluster, Cluster.parent_node_id == Node.id)
                        .join(ClusterSubscription, ClusterSubscription.cluster_id == Cluster.id)
                        .where(ClusterSubscription.user_id == context.user_id)
                        .where(Cluster.is_official_cluster)
                        .order_by(Node.id)
                        .limit(100000)
                    )
                    .scalars()
                    .all()
                )
                where_.append(Event.parent_node_id.in_(my_communities))

            statement = statement.where(or_(*where_))

        if not request.include_cancelled:
            statement = statement.where(~EventOccurrence.is_cancelled)

        if request.HasField("search_in_area"):
            # EPSG4326 measures distance in decimal degrees; metres are converted using 111111 m ~= 1 degree (at the
            # equator). As in UserSearch, the requested radius is padded by a fixed 1000 m.
            search_point = create_coordinate(request.search_in_area.lat, request.search_in_area.lng)
            statement = statement.where(
                func.ST_DWithin(
                    EventOccurrence.geom,
                    search_point,
                    (1000 + request.search_in_area.radius) / 111111,
                )
            )
        if request.HasField("search_in_rectangle"):
            statement = statement.where(
                func.ST_Within(
                    EventOccurrence.geom,
                    func.ST_MakeEnvelope(
                        request.search_in_rectangle.lng_min,
                        request.search_in_rectangle.lat_min,
                        request.search_in_rectangle.lng_max,
                        request.search_in_rectangle.lat_max,
                        4326,
                    ),
                )
            )
        if request.HasField("search_in_community_id"):
            # could do a join here as well, but this is just simpler
            node = session.execute(select(Node).where(Node.id == request.search_in_community_id)).scalar_one_or_none()
            if not node:
                context.abort(grpc.StatusCode.NOT_FOUND, errors.COMMUNITY_NOT_FOUND)
            statement = statement.where(func.ST_Contains(node.geom, EventOccurrence.geom))

        if request.HasField("after"):
            statement = statement.where(EventOccurrence.start_time > to_aware_datetime(request.after))
        if request.HasField("before"):
            statement = statement.where(EventOccurrence.end_time < to_aware_datetime(request.before))

        page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
        # the page token is a unix timestamp of where we left off
        page_token = (
            dt_from_millis(int(request.page_token)) if request.page_token and not request.page_number else now()
        )
        page_number = request.page_number or 1
        # Calculate the offset for pagination
        offset = (page_number - 1) * page_size

        # NOTE(review): results are ordered by start_time, but the token cursor filters on end_time. An occurrence
        # that starts after the last one on a page but ends before it can be skipped, and long-running occurrences
        # can reappear on later pages. Consider keying the token on start_time (the sort key) instead.
        if not request.past:
            statement = statement.where(EventOccurrence.end_time > page_token - timedelta(seconds=1)).order_by(
                EventOccurrence.start_time.asc()
            )
        else:
            statement = statement.where(EventOccurrence.end_time < page_token + timedelta(seconds=1)).order_by(
                EventOccurrence.start_time.desc()
            )

        total_items = session.execute(select(func.count()).select_from(statement.subquery())).scalar()
        # Apply pagination by page number
        statement = statement.offset(offset).limit(page_size) if request.page_number else statement.limit(page_size + 1)
        occurrences = session.execute(statement).scalars().all()

        return search_pb2.EventSearchRes(
            events=[event_to_pb(session, occurrence, context) for occurrence in occurrences[:page_size]],
            next_page_token=(str(millis_from_dt(occurrences[-1].end_time)) if len(occurrences) > page_size else None),
            total_items=total_items,
        )

677 )