Coverage for src/couchers/servicers/search.py: 80%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
"""
Search servicer.

See //docs/search.md for overview.

Implements two RPCs:
- Search: ranked search combining Postgres full text search (tsvector/tsquery) with trigram title similarity
  over users, pages (places and guides), event occurrences that have not yet ended, and clusters (communities
  and groups). Results from all sources are merged into one list ordered by rank.
- UserSearch: filtered listing of users visible to the caller, ordered by recommendation score.
"""
5import grpc
6from sqlalchemy.sql import func, or_
8from couchers import errors
9from couchers.crypto import decrypt_page_token, encrypt_page_token
10from couchers.db import session_scope
11from couchers.models import Cluster, Event, EventOccurrence, Node, Page, PageType, PageVersion, Reference, User
12from couchers.servicers.api import (
13 hostingstatus2sql,
14 parkingdetails2sql,
15 sleepingarrangement2sql,
16 smokinglocation2sql,
17 user_model_to_pb,
18)
19from couchers.servicers.communities import community_to_pb
20from couchers.servicers.events import event_to_pb
21from couchers.servicers.groups import group_to_pb
22from couchers.servicers.pages import page_to_pb
23from couchers.sql import couchers_select as select
24from couchers.utils import create_coordinate, last_active_coarsen, to_aware_datetime
25from proto import search_pb2, search_pb2_grpc
# Searches are fairly expensive, so we prefer sending back many results per page over many small pages.
# This is both the cap on and the default for the number of results per page.
MAX_PAGINATION_LENGTH: int = 50

# Postgres text search configuration passed to to_tsvector / websearch_to_tsquery / ts_headline.
REGCONFIG: str = "english"
# Minimum trigram word_similarity between the query and a title for the row to count as a match.
TRI_SIMILARITY_THRESHOLD: float = 0.6
# Multiplier on the trigram similarity when it is combined with ts_rank_cd into the overall rank.
TRI_SIMILARITY_WEIGHT: int = 5
35def _join_with_space(coalesces):
36 # the objects in coalesces are not strings, so we can't do " ".join(coalesces). They're SQLAlchemy magic.
37 if not coalesces:
38 return ""
39 out = coalesces[0]
40 for coalesce in coalesces[1:]:
41 out += " " + coalesce
42 return out
def _build_tsv(A, B=(), C=(), D=()):
    """
    Build a weighted Postgres tsvector expression from up to four groups of columns.

    Each group's columns are coalesced (NULL -> ""), space-joined, passed to to_tsvector with REGCONFIG and
    tagged with the group's weight label ("A" for A, "B" for B, ...). Group A is always included; groups B, C
    and D are only concatenated on when non-empty.

    Args:
        A, B, C, D: sequences of SQL column expressions, in decreasing order of importance.

    Returns:
        A SQLAlchemy expression for the concatenated, weighted tsvector.
    """

    def _weighted(fields, weight):
        # space-joined, NULL-safe text of this group, vectorized and tagged with its weight
        text = _join_with_space([func.coalesce(bit, "") for bit in fields])
        return func.setweight(func.to_tsvector(REGCONFIG, text), weight)

    tsv = _weighted(A, "A")
    for fields, weight in ((B, "B"), (C, "C"), (D, "D")):
        if fields:
            tsv = tsv.concat(_weighted(fields, weight))
    return tsv
def _build_doc(A, B=(), C=(), D=()):
    """
    Build the raw document text (no to_tsvector, no weighting), used for extracting snippets.

    Every column is coalesced (NULL -> "") and all columns are joined with single spaces, group A first,
    followed by B, C and D when they are non-empty.

    Args:
        A, B, C, D: sequences of SQL column expressions.

    Returns:
        A SQLAlchemy string expression (or "" if every group is empty).
    """
    doc = _join_with_space([func.coalesce(bit, "") for bit in A])
    for fields in (B, C, D):
        if fields:
            doc += " " + _join_with_space([func.coalesce(bit, "") for bit in fields])
    return doc
def _similarity(statement, text):
    """
    SQL expression for the pg_trgm word_similarity between the search string and ``text``.

    Both sides are passed through unaccent first so accented and unaccented spellings compare equal.
    """
    unaccented_query = func.unaccent(statement)
    unaccented_text = func.unaccent(text)
    return func.word_similarity(unaccented_query, unaccented_text)
def _gen_search_elements(statement, title_only, next_rank, page_size, A, B=(), C=(), D=()):
    """
    Given a search string and four groups of fields (A, B, C, D), build the Postgres expressions for search.

    The groups are in decreasing order of importance for ranking; A should be the "title".

    Matching and ranking:
      * title_only=False: a row matches if the full text tsquery matches the weighted tsvector of all groups,
        OR the trigram similarity of the search string to the title exceeds TRI_SIMILARITY_THRESHOLD.
        rank = TRI_SIMILARITY_WEIGHT * similarity + ts_rank_cd(tsvector, tsquery).
      * title_only=True: only the trigram similarity against A is used, for both matching and rank.

    In both modes a highlighted snippet (matches wrapped in "**") is generated from all four groups.

    Args:
        statement: the user's search string.
        title_only: whether to restrict the search to trigram matching against A.
        next_rank: if not None, only rows with rank <= next_rank are returned (pagination cursor).
        page_size: number of results per page; one extra row is fetched to detect a further page.
        A, B, C, D: sequences of SQL column expressions.

    Returns:
        (rank, snippet, execute_search_statement): labelled rank and snippet expressions to add to a select,
        and a function ``execute_search_statement(session, orig_statement)`` that applies matching,
        pagination, ordering by rank descending and a limit of page_size + 1, then returns all rows.
    """
    # trigram based text similarity between the title and the search string; used in both modes
    title = _build_doc(A)
    sim = _similarity(statement, title)

    # a postgres tsquery object; used for matching (full mode) and always for highlighting the snippet
    tsq = func.websearch_to_tsquery(REGCONFIG, statement)

    # the snippet with results highlighted, generated from the raw text of all groups
    doc = _build_doc(A, B, C, D)
    snippet = func.ts_headline(REGCONFIG, doc, tsq, "StartSel=**,StopSel=**").label("snippet")

    if title_only:
        # rank purely on title similarity
        rank = sim.label("rank")
        matches = sim > TRI_SIMILARITY_THRESHOLD
    else:
        # the weighted tsvector we search against with the tsquery
        tsv = _build_tsv(A, B, C, D)
        # weigh the trigram similarity a lot, the text-based ranking less
        rank = (TRI_SIMILARITY_WEIGHT * sim + func.ts_rank_cd(tsv, tsq)).label("rank")
        matches = or_(tsv.op("@@")(tsq), sim > TRI_SIMILARITY_THRESHOLD)

    def execute_search_statement(session, orig_statement):
        """
        Does the right search filtering, limiting, and ordering for the initial statement
        """
        query = orig_statement.where(matches)
        if next_rank is not None:
            query = query.where(rank <= next_rank)
        # fetch one extra row so the caller can tell whether another page exists
        return session.execute(query.order_by(rank.desc()).limit(page_size + 1)).all()

    return rank, snippet, execute_search_statement
def _search_users(session, search_statement, title_only, next_rank, page_size, context, include_users):
    """
    Search users visible to the caller, returning one `search_pb2.Result` (with `user` set) per hit.

    Username and display name form the title group, followed by city, then the about-me text, then the
    remaining free-text profile fields. Returns an empty list when `include_users` is falsy.
    """
    if not include_users:
        return []

    title_fields = [User.username, User.name]
    location_fields = [User.city]
    about_fields = [User.about_me]
    other_fields = [User.my_travels, User.things_i_like, User.about_place, User.additional_information]

    rank, snippet, run_search = _gen_search_elements(
        search_statement,
        title_only,
        next_rank,
        page_size,
        title_fields,
        location_fields,
        about_fields,
        other_fields,
    )

    visible_users = select(User, rank, snippet).where_users_visible(context)

    results = []
    for user, user_rank, user_snippet in run_search(session, visible_users):
        results.append(
            search_pb2.Result(
                rank=user_rank,
                user=user_model_to_pb(user, session, context),
                snippet=user_snippet,
            )
        )
    return results
def _search_pages(session, search_statement, title_only, next_rank, page_size, context, include_places, include_guides):
    """
    Search places and/or guides, matching against each page's highest-id PageVersion.

    Title is the most important field, then address, then content. Each hit becomes a `search_pb2.Result`
    with `place` or `guide` set according to the page type. Returns an empty list when neither
    `include_places` nor `include_guides` is set.
    """
    # bail out before building any search expressions when there is nothing to search
    if not include_places and not include_guides:
        return []

    rank, snippet, execute_search_statement = _gen_search_elements(
        search_statement,
        title_only,
        next_rank,
        page_size,
        [PageVersion.title],
        [PageVersion.address],
        [],
        [PageVersion.content],
    )

    # highest PageVersion.id per page of a requested type (presumably its latest version)
    latest_pages = (
        select(func.max(PageVersion.id).label("id"))
        .join(Page, Page.id == PageVersion.page_id)
        .where(
            or_(
                (Page.type == PageType.place) if include_places else False,
                (Page.type == PageType.guide) if include_guides else False,
            )
        )
        .group_by(PageVersion.page_id)
        .subquery()
    )

    pages = execute_search_statement(
        session,
        select(Page, rank, snippet)
        .join(PageVersion, PageVersion.page_id == Page.id)
        .join(latest_pages, latest_pages.c.id == PageVersion.id),
    )

    return [
        search_pb2.Result(
            rank=rank,
            place=page_to_pb(page, context) if page.type == PageType.place else None,
            guide=page_to_pb(page, context) if page.type == PageType.guide else None,
            snippet=snippet,
        )
        for page, rank, snippet in pages
    ]
def _search_events(session, search_statement, title_only, next_rank, page_size, context):
    """
    Search event occurrences that have not yet ended (end_time >= now).

    The event title is the most important field, then the occurrence address and link, then its content.
    Each hit becomes a `search_pb2.Result` with `event` set.
    """
    rank, snippet, run_search = _gen_search_elements(
        search_statement,
        title_only,
        next_rank,
        page_size,
        [Event.title],
        [EventOccurrence.address, EventOccurrence.link],
        [],
        [EventOccurrence.content],
    )

    not_yet_ended = (
        select(EventOccurrence, rank, snippet)
        .join(Event, Event.id == EventOccurrence.event_id)
        .where(EventOccurrence.end_time >= func.now())
    )

    results = []
    for occurrence, occurrence_rank, occurrence_snippet in run_search(session, not_yet_ended):
        results.append(
            search_pb2.Result(
                rank=occurrence_rank,
                event=event_to_pb(session, occurrence, context),
                snippet=occurrence_snippet,
            )
        )
    return results
def _search_clusters(
    session, search_statement, title_only, next_rank, page_size, context, include_communities, include_groups
):
    """
    Search communities (official clusters) and/or groups (non-official clusters).

    The cluster name is the most important field, then the main page's address and title, then the cluster
    description, then the main page's content (using the highest-id PageVersion of each main page). Each hit
    becomes a `search_pb2.Result` with `community` set for official clusters, otherwise `group`. Returns an
    empty list when neither kind is requested.
    """
    if not (include_communities or include_groups):
        return []

    rank, snippet, run_search = _gen_search_elements(
        search_statement,
        title_only,
        next_rank,
        page_size,
        [Cluster.name],
        [PageVersion.address, PageVersion.title],
        [Cluster.description],
        [PageVersion.content],
    )

    main_page_versions = (
        select(func.max(PageVersion.id).label("id"))
        .join(Page, Page.id == PageVersion.page_id)
        .where(Page.type == PageType.main_page)
        .group_by(PageVersion.page_id)
        .subquery()
    )

    # when both kinds are requested, no official/non-official restriction is applied
    only_communities = include_communities and not include_groups
    only_groups = include_groups and not include_communities

    cluster_query = (
        select(Cluster, rank, snippet)
        .join(Page, Page.owner_cluster_id == Cluster.id)
        .join(PageVersion, PageVersion.page_id == Page.id)
        .join(main_page_versions, main_page_versions.c.id == PageVersion.id)
        .where(Cluster.is_official_cluster if only_communities else True)
        .where(~Cluster.is_official_cluster if only_groups else True)
    )

    results = []
    for cluster, cluster_rank, cluster_snippet in run_search(session, cluster_query):
        if cluster.is_official_cluster:
            community_pb = community_to_pb(cluster.official_cluster_for_node, context)
            group_pb = None
        else:
            community_pb = None
            group_pb = group_to_pb(cluster, context)
        results.append(
            search_pb2.Result(
                rank=cluster_rank,
                community=community_pb,
                group=group_pb,
                snippet=cluster_snippet,
            )
        )
    return results
class Search(search_pb2_grpc.SearchServicer):
    """gRPC servicer implementing the `Search` and `UserSearch` RPCs."""

    def Search(self, request, context):
        """
        Ranked search across users, pages, events and clusters.

        Each enabled source returns up to page_size + 1 rows with rank <= the page token (if any). The rows are
        merged, sorted by rank descending and truncated to page_size. The next page token is the rank of the
        first result that did not fit, so results tied at exactly that rank may be repeated on the next page.
        """
        page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
        # this is not an ideal page token, some results have equal rank (unlikely)
        next_rank = float(request.page_token) if request.page_token else None
        with session_scope() as session:
            # positional arguments shared by every per-type search helper
            common = (session, request.query, request.title_only, next_rank, page_size, context)
            all_results = (
                _search_users(*common, request.include_users)
                + _search_pages(*common, request.include_places, request.include_guides)
                + _search_events(*common)
                + _search_clusters(*common, request.include_communities, request.include_groups)
            )
            all_results.sort(key=lambda result: result.rank, reverse=True)
            return search_pb2.SearchRes(
                results=all_results[:page_size],
                next_page_token=str(all_results[page_size].rank) if len(all_results) > page_size else None,
            )

    def UserSearch(self, request, context):
        """
        Filtered listing of users visible to the caller, ordered by recommendation score (descending).

        Every filter present on the request narrows the result set. The page token is the encrypted
        recommendation score of the first user that did not fit on the previous page.

        Aborts with NOT_FOUND if `search_in_community_id` does not refer to an existing node.
        """
        with session_scope() as session:
            statement = select(User).where_users_visible(context)

            if request.HasField("query"):
                # case-insensitive substring match against the chosen text columns
                pattern = f"%{request.query.value}%"
                if request.query_name_only:
                    text_columns = [User.name, User.username]
                else:
                    text_columns = [
                        User.name,
                        User.username,
                        User.city,
                        User.hometown,
                        User.about_me,
                        User.my_travels,
                        User.things_i_like,
                        User.about_place,
                        User.additional_information,
                    ]
                statement = statement.where(or_(*[column.ilike(pattern) for column in text_columns]))

            if request.HasField("last_active"):
                raw_dt = to_aware_datetime(request.last_active)
                statement = statement.where(User.last_active >= last_active_coarsen(raw_dt))

            if request.HasField("gender"):
                statement = statement.where(User.gender.ilike(f"%{request.gender.value}%"))

            if len(request.hosting_status_filter) > 0:
                statement = statement.where(
                    User.hosting_status.in_([hostingstatus2sql[status] for status in request.hosting_status_filter])
                )
            if len(request.smoking_location_filter) > 0:
                statement = statement.where(
                    User.smoking_allowed.in_([smokinglocation2sql[loc] for loc in request.smoking_location_filter])
                )
            if len(request.sleeping_arrangement_filter) > 0:
                statement = statement.where(
                    User.sleeping_arrangement.in_(
                        [sleepingarrangement2sql[arr] for arr in request.sleeping_arrangement_filter]
                    )
                )
            if len(request.parking_details_filter) > 0:
                statement = statement.where(
                    User.parking_details.in_([parkingdetails2sql[det] for det in request.parking_details_filter])
                )

            if request.HasField("profile_completed"):
                statement = statement.where(User.has_completed_profile == request.profile_completed.value)
            if request.HasField("guests"):
                statement = statement.where(User.max_guests >= request.guests.value)

            # optional wrapper fields that filter by equality on the corresponding User column
            for field_name, column in (
                ("last_minute", User.last_minute),
                ("has_pets", User.has_pets),
                ("accepts_pets", User.accepts_pets),
                ("has_kids", User.has_kids),
                ("accepts_kids", User.accepts_kids),
                ("has_housemates", User.has_housemates),
                ("wheelchair_accessible", User.wheelchair_accessible),
                ("smokes_at_home", User.smokes_at_home),
                ("drinking_allowed", User.drinking_allowed),
                ("drinks_at_home", User.drinks_at_home),
                ("parking", User.parking),
                ("camping_ok", User.camping_ok),
            ):
                if request.HasField(field_name):
                    statement = statement.where(column == getattr(request, field_name).value)

            if request.HasField("search_in_area"):
                # EPSG4326 measures distance in decimal degrees.
                # We want to check whether two circles overlap, so check whether the distance between their
                # centers is less than the sum of their radii, divided by 111111 m ~= 1 degree (at the equator).
                search_point = create_coordinate(request.search_in_area.lat, request.search_in_area.lng)
                statement = statement.where(
                    func.ST_DWithin(
                        # old:
                        # User.geom, search_point, (User.geom_radius + request.search_in_area.radius) / 111111
                        # Optimization: use a fixed 1000 m instead of each user's geom_radius, so the db query
                        # doesn't need to look up the user's radius.
                        User.geom,
                        search_point,
                        (1000 + request.search_in_area.radius) / 111111,
                    )
                )

            if request.HasField("search_in_community_id"):
                # could do a join here as well, but this is just simpler
                node = session.execute(
                    select(Node).where(Node.id == request.search_in_community_id)
                ).scalar_one_or_none()
                if not node:
                    context.abort(grpc.StatusCode.NOT_FOUND, errors.COMMUNITY_NOT_FOUND)
                statement = statement.where(func.ST_Contains(node.geom, User.geom))

            if request.only_with_references:
                # Filter with a subquery rather than joining Reference: a join yields one row per reference,
                # which would duplicate users in the results and throw off the LIMIT-based pagination below.
                statement = statement.where(User.id.in_(select(Reference.to_user_id)))

            # TODO:
            # google.protobuf.StringValue language = 11;
            # bool friends_only = 13;
            # google.protobuf.UInt32Value age_min = 14;
            # google.protobuf.UInt32Value age_max = 15;

            page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
            # The page token holds the recommendation score of the first user on the next page. Without a token,
            # 1e10 serves as an "above every score" sentinel (assumes scores stay below it).
            next_recommendation_score = (
                float(decrypt_page_token(request.page_token)) if request.page_token else 1e10
            )

            statement = (
                statement.where(User.recommendation_score <= next_recommendation_score)
                .order_by(User.recommendation_score.desc())
                # fetch one extra user to know whether another page exists
                .limit(page_size + 1)
            )
            users = session.execute(statement).scalars().all()

            return search_pb2.UserSearchRes(
                results=[
                    search_pb2.Result(
                        rank=1,
                        user=user_model_to_pb(user, session, context),
                    )
                    for user in users[:page_size]
                ],
                next_page_token=(
                    encrypt_page_token(str(users[-1].recommendation_score)) if len(users) > page_size else None
                ),
            )