Coverage for src/couchers/servicers/search.py: 83%
229 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-11-21 04:21 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-11-21 04:21 +0000
1"""
2See //docs/search.md for overview.
3"""
5from datetime import timedelta
7import grpc
8from sqlalchemy.sql import and_, func, or_
10from couchers import errors
11from couchers.crypto import decrypt_page_token, encrypt_page_token
12from couchers.models import (
13 Cluster,
14 ClusterSubscription,
15 Event,
16 EventOccurrence,
17 EventOccurrenceAttendee,
18 EventOrganizer,
19 EventSubscription,
20 Node,
21 Page,
22 PageType,
23 PageVersion,
24 Reference,
25 User,
26)
27from couchers.servicers.account import has_strong_verification
28from couchers.servicers.api import (
29 hostingstatus2sql,
30 meetupstatus2sql,
31 parkingdetails2sql,
32 sleepingarrangement2sql,
33 smokinglocation2sql,
34 user_model_to_pb,
35)
36from couchers.servicers.communities import community_to_pb
37from couchers.servicers.events import event_to_pb
38from couchers.servicers.groups import group_to_pb
39from couchers.servicers.pages import page_to_pb
40from couchers.sql import couchers_select as select
41from couchers.utils import (
42 create_coordinate,
43 dt_from_millis,
44 last_active_coarsen,
45 millis_from_dt,
46 now,
47 to_aware_datetime,
48)
49from proto import search_pb2, search_pb2_grpc
# searches are a bit expensive, we'd rather send back a bunch of results at once than lots of small pages
# (this is both the default and the upper bound for the page size of every search endpoint in this module)
MAX_PAGINATION_LENGTH = 100

# text search configuration passed to to_tsvector, websearch_to_tsquery, and ts_headline
REGCONFIG = "english"
# a title only counts as a trigram match if its word similarity to the search string is above this
TRI_SIMILARITY_THRESHOLD = 0.6
# factor the trigram similarity is multiplied by before it is added to the ts_rank_cd score
TRI_SIMILARITY_WEIGHT = 5
59def _join_with_space(coalesces):
60 # the objects in coalesces are not strings, so we can't do " ".join(coalesces). They're SQLAlchemy magic.
61 if not coalesces:
62 return ""
63 out = coalesces[0]
64 for coalesce in coalesces[1:]:
65 out += " " + coalesce
66 return out
def _build_tsv(A, B=None, C=None, D=None):
    """
    Builds a weighted tsvector expression from the field lists A, B, C, and D.

    Every list is collapsed into one space separated text (NULL fields count as ""), converted with to_tsvector, and
    tagged with the weight of the same letter. A is always part of the result, B, C, and D only when non-empty.
    """

    def weighted(fields, weight):
        text = _join_with_space([func.coalesce(field, "") for field in fields])
        return func.setweight(func.to_tsvector(REGCONFIG, text), weight)

    tsv = weighted(A, "A")
    for fields, weight in ((B, "B"), (C, "C"), (D, "D")):
        # None and [] both mean "no fields with this weight"
        if fields:
            tsv = tsv.concat(weighted(fields, weight))
    return tsv
def _build_doc(A, B=None, C=None, D=None):
    """
    Builds the plain text document (no to_tsvector, no weights) that snippets are extracted from.

    The fields of A come first, followed by those of B, C, and D (where given), all separated by single spaces and
    with NULL fields replaced by "".
    """
    doc = _join_with_space([func.coalesce(field, "") for field in A])
    for fields in (B, C, D):
        # None and [] are both skipped
        if fields:
            doc += " " + _join_with_space([func.coalesce(field, "") for field in fields])
    return doc
def _similarity(statement, text):
    """
    Returns the word_similarity expression of statement against text, both passed through unaccent first.
    """
    unaccented_statement = func.unaccent(statement)
    unaccented_text = func.unaccent(text)
    return func.word_similarity(unaccented_statement, unaccented_text)
def _gen_search_elements(statement, title_only, next_rank, page_size, A, B=None, C=None, D=None):
    """
    Generates the postgres expressions for a full text search of the string `statement` over four sets of fields.

    The sets (A, B, C, D) are in decreasing order of "importance" for ranking. A should be the "title", the others
    can be anything.

    With title_only=True only the trigram similarity against A decides what matches and how it is ranked; otherwise
    a row matches if the tsquery matches the weighted tsvector or the title is similar enough, and the rank combines
    both scores.

    Returns a tuple (rank, snippet, execute_search_statement): two labelled column expressions to put into the select
    and a function (session, orig_statement) that applies the match filter, the rank <= next_rank pagination filter
    (if next_rank is not None), the ordering by descending rank, and a limit of page_size + 1, and returns all rows.
    """
    B = B or []
    C = C or []
    D = D or []

    # trigram based text similarity between the title and the search string
    sim = _similarity(statement, _build_doc(A))

    # both modes highlight matches in the whole document, so both need the tsquery and the raw document
    tsq = func.websearch_to_tsquery(REGCONFIG, statement)
    full_doc = _build_doc(A, B, C, D)
    snippet = func.ts_headline(REGCONFIG, full_doc, tsq, "StartSel=**,StopSel=**").label("snippet")

    if title_only:
        # the title similarity is the only thing that counts
        rank = sim.label("rank")
        matches = sim > TRI_SIMILARITY_THRESHOLD
    else:
        # the tsvector object that we want to search against with our tsquery
        tsv = _build_tsv(A, B, C, D)
        # ranking algo, weigh the similarity a lot, the text-based ranking less
        rank = (TRI_SIMILARITY_WEIGHT * sim + func.ts_rank_cd(tsv, tsq)).label("rank")
        matches = or_(tsv.op("@@")(tsq), sim > TRI_SIMILARITY_THRESHOLD)

    def execute_search_statement(session, orig_statement):
        """
        Does the right search filtering, limiting, and ordering for the initial statement
        """
        within_page = rank <= next_rank if next_rank is not None else True
        query = orig_statement.where(matches).where(within_page)
        # one more row than requested so the caller can tell whether there is another page
        query = query.order_by(rank.desc()).limit(page_size + 1)
        return session.execute(query).all()

    return rank, snippet, execute_search_statement
188def _search_users(session, search_statement, title_only, next_rank, page_size, context, include_users):
189 if not include_users:
190 return []
191 rank, snippet, execute_search_statement = _gen_search_elements(
192 search_statement,
193 title_only,
194 next_rank,
195 page_size,
196 [User.username, User.name],
197 [User.city],
198 [User.about_me],
199 [User.my_travels, User.things_i_like, User.about_place, User.additional_information],
200 )
202 users = execute_search_statement(session, select(User, rank, snippet).where_users_visible(context))
204 return [
205 search_pb2.Result(
206 rank=rank,
207 user=user_model_to_pb(page, session, context),
208 snippet=snippet,
209 )
210 for page, rank, snippet in users
211 ]
214def _search_pages(session, search_statement, title_only, next_rank, page_size, context, include_places, include_guides):
215 rank, snippet, execute_search_statement = _gen_search_elements(
216 search_statement,
217 title_only,
218 next_rank,
219 page_size,
220 [PageVersion.title],
221 [PageVersion.address],
222 [],
223 [PageVersion.content],
224 )
225 if not include_places and not include_guides:
226 return []
228 latest_pages = (
229 select(func.max(PageVersion.id).label("id"))
230 .join(Page, Page.id == PageVersion.page_id)
231 .where(
232 or_(
233 (Page.type == PageType.place) if include_places else False,
234 (Page.type == PageType.guide) if include_guides else False,
235 )
236 )
237 .group_by(PageVersion.page_id)
238 .subquery()
239 )
241 pages = execute_search_statement(
242 session,
243 select(Page, rank, snippet)
244 .join(PageVersion, PageVersion.page_id == Page.id)
245 .join(latest_pages, latest_pages.c.id == PageVersion.id),
246 )
248 return [
249 search_pb2.Result(
250 rank=rank,
251 place=page_to_pb(session, page, context) if page.type == PageType.place else None,
252 guide=page_to_pb(session, page, context) if page.type == PageType.guide else None,
253 snippet=snippet,
254 )
255 for page, rank, snippet in pages
256 ]
def _search_events(session, search_statement, title_only, next_rank, page_size, context):
    """
    Searches event occurrences that have not ended yet and returns the matches as search results.

    Occurrences flagged as deleted are left out, so that deleted events don't resurface through the global search.
    """
    rank, snippet, execute_search_statement = _gen_search_elements(
        search_statement,
        title_only,
        next_rank,
        page_size,
        [Event.title],
        [EventOccurrence.address, EventOccurrence.link],
        [],
        [EventOccurrence.content],
    )

    occurrences = execute_search_statement(
        session,
        select(EventOccurrence, rank, snippet)
        .join(Event, Event.id == EventOccurrence.event_id)
        # deleted occurrences must never be returned to users
        .where(~EventOccurrence.is_deleted)
        # only current and upcoming occurrences
        .where(EventOccurrence.end_time >= func.now()),
    )

    return [
        search_pb2.Result(
            rank=rank,
            event=event_to_pb(session, occurrence, context),
            snippet=snippet,
        )
        for occurrence, rank, snippet in occurrences
    ]
288def _search_clusters(
289 session, search_statement, title_only, next_rank, page_size, context, include_communities, include_groups
290):
291 if not include_communities and not include_groups:
292 return []
294 rank, snippet, execute_search_statement = _gen_search_elements(
295 search_statement,
296 title_only,
297 next_rank,
298 page_size,
299 [Cluster.name],
300 [PageVersion.address, PageVersion.title],
301 [Cluster.description],
302 [PageVersion.content],
303 )
305 latest_pages = (
306 select(func.max(PageVersion.id).label("id"))
307 .join(Page, Page.id == PageVersion.page_id)
308 .where(Page.type == PageType.main_page)
309 .group_by(PageVersion.page_id)
310 .subquery()
311 )
313 clusters = execute_search_statement(
314 session,
315 select(Cluster, rank, snippet)
316 .join(Page, Page.owner_cluster_id == Cluster.id)
317 .join(PageVersion, PageVersion.page_id == Page.id)
318 .join(latest_pages, latest_pages.c.id == PageVersion.id)
319 .where(Cluster.is_official_cluster if include_communities and not include_groups else True)
320 .where(~Cluster.is_official_cluster if not include_communities and include_groups else True),
321 )
323 return [
324 search_pb2.Result(
325 rank=rank,
326 community=(
327 community_to_pb(session, cluster.official_cluster_for_node, context)
328 if cluster.is_official_cluster
329 else None
330 ),
331 group=group_to_pb(session, cluster, context) if not cluster.is_official_cluster else None,
332 snippet=snippet,
333 )
334 for cluster, rank, snippet in clusters
335 ]
class Search(search_pb2_grpc.SearchServicer):
    """
    Servicer implementing the Search, UserSearch, and EventSearch RPCs.
    """

    def Search(self, request, context, session):
        """
        Runs the full text search over users, pages, events, and clusters and merges the results into one ranked page.

        Every per-type search returns at most page_size + 1 results. They are merged and sorted by descending rank;
        if more than page_size results came back, the rank of the first result past the page is the next page token.
        """
        page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
        # this is not an ideal page token, some results have equal rank (unlikely)
        next_rank = float(request.page_token) if request.page_token else None

        all_results = (
            _search_users(
                session,
                request.query,
                request.title_only,
                next_rank,
                page_size,
                context,
                request.include_users,
            )
            + _search_pages(
                session,
                request.query,
                request.title_only,
                next_rank,
                page_size,
                context,
                request.include_places,
                request.include_guides,
            )
            + _search_events(
                session,
                request.query,
                request.title_only,
                next_rank,
                page_size,
                context,
            )
            + _search_clusters(
                session,
                request.query,
                request.title_only,
                next_rank,
                page_size,
                context,
                request.include_communities,
                request.include_groups,
            )
        )
        all_results.sort(key=lambda result: result.rank, reverse=True)
        return search_pb2.SearchRes(
            results=all_results[:page_size],
            next_page_token=str(all_results[page_size].rank) if len(all_results) > page_size else None,
        )

    def UserSearch(self, request, context, session):
        """
        Returns the visible users matching all of the filters set in the request, best recommendation score first.

        Aborts with FAILED_PRECONDITION if a gender filter is used by a user without strong verification or if it
        doesn't include the user's own gender, and with NOT_FOUND if search_in_community_id matches no node.
        Pagination is based on the recommendation score, which is handed out as an encrypted page token.
        """
        user = session.execute(select(User).where(User.id == context.user_id)).scalar_one()

        statement = select(User).where_users_visible(context)

        if request.HasField("query"):
            # plain substring match, case insensitive
            pattern = f"%{request.query.value}%"
            searched_columns = [User.name, User.username]
            if not request.query_name_only:
                searched_columns += [
                    User.city,
                    User.hometown,
                    User.about_me,
                    User.my_travels,
                    User.things_i_like,
                    User.about_place,
                    User.additional_information,
                ]
            statement = statement.where(or_(*(column.ilike(pattern) for column in searched_columns)))

        if request.HasField("last_active"):
            raw_dt = to_aware_datetime(request.last_active)
            statement = statement.where(User.last_active >= last_active_coarsen(raw_dt))

        if request.gender:
            if not has_strong_verification(session, user):
                context.abort(grpc.StatusCode.FAILED_PRECONDITION, errors.NEED_STRONG_VERIFICATION)
            elif user.gender not in request.gender:
                context.abort(grpc.StatusCode.FAILED_PRECONDITION, errors.MUST_INCLUDE_OWN_GENDER)
            else:
                statement = statement.where(User.gender.in_(request.gender))

        # repeated enum filters: the user's value has to be one of the requested ones
        enum_filters = (
            (request.hosting_status_filter, User.hosting_status, hostingstatus2sql),
            (request.meetup_status_filter, User.meetup_status, meetupstatus2sql),
            (request.smoking_location_filter, User.smoking_allowed, smokinglocation2sql),
            (request.sleeping_arrangement_filter, User.sleeping_arrangement, sleepingarrangement2sql),
            (request.parking_details_filter, User.parking_details, parkingdetails2sql),
        )
        for requested, column, to_sql in enum_filters:
            if requested:
                statement = statement.where(column.in_([to_sql[value] for value in requested]))

        # optional wrapper fields that have to equal the user's value exactly when they are set
        exact_match_filters = (
            ("profile_completed", User.has_completed_profile),
            ("last_minute", User.last_minute),
            ("has_pets", User.has_pets),
            ("accepts_pets", User.accepts_pets),
            ("has_kids", User.has_kids),
            ("accepts_kids", User.accepts_kids),
            ("has_housemates", User.has_housemates),
            ("wheelchair_accessible", User.wheelchair_accessible),
            ("smokes_at_home", User.smokes_at_home),
            ("drinking_allowed", User.drinking_allowed),
            ("drinks_at_home", User.drinks_at_home),
            ("parking", User.parking),
            ("camping_ok", User.camping_ok),
        )
        for field_name, column in exact_match_filters:
            if request.HasField(field_name):
                statement = statement.where(column == getattr(request, field_name).value)

        # the host needs room for at least the requested number of guests
        if request.HasField("guests"):
            statement = statement.where(User.max_guests >= request.guests.value)

        if request.HasField("search_in_area"):
            # EPSG4326 measures distance in decimal degrees
            # we want to check whether two circles overlap, so check if the distance between their centers is less
            # than the sum of their radii, divided by 111111 m ~= 1 degree (at the equator)
            search_point = create_coordinate(request.search_in_area.lat, request.search_in_area.lng)
            statement = statement.where(
                func.ST_DWithin(
                    # old:
                    # User.geom, search_point, (User.geom_radius + request.search_in_area.radius) / 111111
                    # this is an optimization that speeds up the db queries since it doesn't need to look up the
                    # user's geom radius
                    User.geom,
                    search_point,
                    (1000 + request.search_in_area.radius) / 111111,
                )
            )
        if request.HasField("search_in_rectangle"):
            statement = statement.where(
                func.ST_Within(
                    User.geom,
                    func.ST_MakeEnvelope(
                        request.search_in_rectangle.lng_min,
                        request.search_in_rectangle.lat_min,
                        request.search_in_rectangle.lng_max,
                        request.search_in_rectangle.lat_max,
                        4326,
                    ),
                )
            )
        if request.HasField("search_in_community_id"):
            # could do a join here as well, but this is just simpler
            node = session.execute(select(Node).where(Node.id == request.search_in_community_id)).scalar_one_or_none()
            if not node:
                context.abort(grpc.StatusCode.NOT_FOUND, errors.COMMUNITY_NOT_FOUND)
            statement = statement.where(func.ST_Contains(node.geom, User.geom))

        if request.only_with_references:
            # join against the distinct ids of users who received a reference: joining the references themselves
            # produces one row per reference, which returns users with several references multiple times and lets
            # the duplicates use up the page limit
            referenced_users = select(Reference.to_user_id.label("user_id")).distinct().subquery()
            statement = statement.join(referenced_users, referenced_users.c.user_id == User.id)

        # TODO:
        # google.protobuf.StringValue language = 11;
        # bool friends_only = 13;
        # google.protobuf.UInt32Value age_min = 14;
        # google.protobuf.UInt32Value age_max = 15;

        page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
        # without a page token start from a score that is effectively "no upper bound"
        next_recommendation_score = float(decrypt_page_token(request.page_token)) if request.page_token else 1e10

        statement = (
            statement.where(User.recommendation_score <= next_recommendation_score)
            .order_by(User.recommendation_score.desc())
            # one extra row tells us whether there is a next page and where it starts
            .limit(page_size + 1)
        )
        users = session.execute(statement).scalars().all()

        return search_pb2.UserSearchRes(
            results=[
                search_pb2.Result(
                    rank=1,
                    user=user_model_to_pb(user, session, context),
                )
                for user in users[:page_size]
            ],
            next_page_token=(
                encrypt_page_token(str(users[-1].recommendation_score)) if len(users) > page_size else None
            ),
        )

    def EventSearch(self, request, context, session):
        """
        Returns the non-deleted event occurrences matching all of the filters set in the request.

        The subscribed/attending/organizing/my_communities flags are alternatives: an occurrence passes if any of the
        requested ones applies. Upcoming events are sorted by ascending start time, past events by descending start
        time. Two pagination modes are supported: if page_number is set, the page token is ignored and offset based
        pagination is used (no next page token is returned); otherwise the page token is the end time, as a unix
        timestamp in milliseconds, of where the previous page left off. total_items counts all matches of the query
        before the page limit is applied.

        Aborts with NOT_FOUND if search_in_community_id matches no node.
        """
        statement = (
            select(EventOccurrence).join(Event, Event.id == EventOccurrence.event_id).where(~EventOccurrence.is_deleted)
        )

        if request.HasField("query"):
            # plain substring match, case insensitive
            pattern = f"%{request.query.value}%"
            searched_columns = [Event.title]
            if not request.query_title_only:
                searched_columns += [EventOccurrence.content, EventOccurrence.address]
            statement = statement.where(or_(*(column.ilike(pattern) for column in searched_columns)))

        # the filters treat an occurrence without a geom as online
        if request.only_online:
            statement = statement.where(EventOccurrence.geom == None)  # noqa: E711
        elif request.only_offline:
            statement = statement.where(EventOccurrence.geom != None)  # noqa: E711

        if request.subscribed or request.attending or request.organizing or request.my_communities:
            # each requested relation is outer joined for the current user only and contributes an "is there a
            # matching row" condition; the conditions are or'ed together in the end
            where_ = []

            if request.subscribed:
                statement = statement.outerjoin(
                    EventSubscription,
                    and_(EventSubscription.event_id == Event.id, EventSubscription.user_id == context.user_id),
                )
                where_.append(EventSubscription.user_id != None)  # noqa: E711
            if request.organizing:
                statement = statement.outerjoin(
                    EventOrganizer,
                    and_(EventOrganizer.event_id == Event.id, EventOrganizer.user_id == context.user_id),
                )
                where_.append(EventOrganizer.user_id != None)  # noqa: E711
            if request.attending:
                statement = statement.outerjoin(
                    EventOccurrenceAttendee,
                    and_(
                        EventOccurrenceAttendee.occurrence_id == EventOccurrence.id,
                        EventOccurrenceAttendee.user_id == context.user_id,
                    ),
                )
                where_.append(EventOccurrenceAttendee.user_id != None)  # noqa: E711
            if request.my_communities:
                # ids of the nodes whose official cluster the user is subscribed to
                my_communities = (
                    session.execute(
                        select(Node.id)
                        .join(Cluster, Cluster.parent_node_id == Node.id)
                        .join(ClusterSubscription, ClusterSubscription.cluster_id == Cluster.id)
                        .where(ClusterSubscription.user_id == context.user_id)
                        .where(Cluster.is_official_cluster)
                        .order_by(Node.id)
                        .limit(100000)
                    )
                    .scalars()
                    .all()
                )
                where_.append(Event.parent_node_id.in_(my_communities))

            statement = statement.where(or_(*where_))

        if not request.include_cancelled:
            statement = statement.where(~EventOccurrence.is_cancelled)

        if request.HasField("search_in_area"):
            # EPSG4326 measures distance in decimal degrees, so the radius in meters is converted with
            # 111111 m ~= 1 degree (at the equator); a fixed 1000 m is added on top of the requested radius
            search_point = create_coordinate(request.search_in_area.lat, request.search_in_area.lng)
            statement = statement.where(
                func.ST_DWithin(
                    EventOccurrence.geom,
                    search_point,
                    (1000 + request.search_in_area.radius) / 111111,
                )
            )
        if request.HasField("search_in_rectangle"):
            statement = statement.where(
                func.ST_Within(
                    EventOccurrence.geom,
                    func.ST_MakeEnvelope(
                        request.search_in_rectangle.lng_min,
                        request.search_in_rectangle.lat_min,
                        request.search_in_rectangle.lng_max,
                        request.search_in_rectangle.lat_max,
                        4326,
                    ),
                )
            )
        if request.HasField("search_in_community_id"):
            # could do a join here as well, but this is just simpler
            node = session.execute(select(Node).where(Node.id == request.search_in_community_id)).scalar_one_or_none()
            if not node:
                context.abort(grpc.StatusCode.NOT_FOUND, errors.COMMUNITY_NOT_FOUND)
            statement = statement.where(func.ST_Contains(node.geom, EventOccurrence.geom))

        if request.HasField("after"):
            statement = statement.where(EventOccurrence.start_time > to_aware_datetime(request.after))
        if request.HasField("before"):
            statement = statement.where(EventOccurrence.end_time < to_aware_datetime(request.before))

        page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
        # the page token is a unix timestamp of where we left off; it is ignored when paginating by page number
        page_token = (
            dt_from_millis(int(request.page_token)) if request.page_token and not request.page_number else now()
        )
        page_number = request.page_number or 1
        # Calculate the offset for pagination
        offset = (page_number - 1) * page_size

        if not request.past:
            statement = statement.where(EventOccurrence.end_time > page_token - timedelta(seconds=1)).order_by(
                EventOccurrence.start_time.asc()
            )
        else:
            statement = statement.where(EventOccurrence.end_time < page_token + timedelta(seconds=1)).order_by(
                EventOccurrence.start_time.desc()
            )

        # counted before offset/limit are applied, so this is the size of the whole result set
        total_items = session.execute(select(func.count()).select_from(statement.subquery())).scalar()
        # Apply pagination by page number, or fetch one extra row to detect the next page in token mode
        statement = statement.offset(offset).limit(page_size) if request.page_number else statement.limit(page_size + 1)
        occurrences = session.execute(statement).scalars().all()

        return search_pb2.EventSearchRes(
            events=[event_to_pb(session, occurrence, context) for occurrence in occurrences[:page_size]],
            next_page_token=(str(millis_from_dt(occurrences[-1].end_time)) if len(occurrences) > page_size else None),
            total_items=total_items,
        )