Coverage for src/couchers/servicers/search.py: 80%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2See //docs/search.md for overview.
3"""
4import grpc
5from sqlalchemy.sql import func, or_
7from couchers import errors
8from couchers.crypto import decrypt_page_token, encrypt_page_token
9from couchers.db import session_scope
10from couchers.models import Cluster, Event, EventOccurrence, Node, Page, PageType, PageVersion, Reference, User
11from couchers.servicers.api import (
12 hostingstatus2sql,
13 parkingdetails2sql,
14 sleepingarrangement2sql,
15 smokinglocation2sql,
16 user_model_to_pb,
17)
18from couchers.servicers.communities import community_to_pb
19from couchers.servicers.events import event_to_pb
20from couchers.servicers.groups import group_to_pb
21from couchers.servicers.pages import page_to_pb
22from couchers.sql import couchers_select as select
23from couchers.utils import create_coordinate, last_active_coarsen, to_aware_datetime
24from proto import search_pb2, search_pb2_grpc
# searches are a bit expensive, we'd rather send back a bunch of results at once than lots of small pages
MAX_PAGINATION_LENGTH: int = 50

# PostgreSQL text search configuration passed to to_tsvector / websearch_to_tsquery / ts_headline
REGCONFIG: str = "english"
# a row whose title has a trigram word_similarity above this value counts as a match on its own
TRI_SIMILARITY_THRESHOLD: float = 0.6
# multiplier on the trigram similarity when it is combined with the full-text rank
TRI_SIMILARITY_WEIGHT: int = 5
34def _join_with_space(coalesces):
35 # the objects in coalesces are not strings, so we can't do " ".join(coalesces). They're SQLAlchemy magic.
36 if not coalesces:
37 return ""
38 out = coalesces[0]
39 for coalesce in coalesces[1:]:
40 out += " " + coalesce
41 return out
def _build_tsv(A, B=None, C=None, D=None):
    """
    Given lists for A, B, C, and D, builds a weighted tsvector from them.

    Each group's fields are COALESCEd to "" (so NULL columns don't null out the document), joined with spaces,
    converted with to_tsvector(REGCONFIG, ...) and tagged with the weight letter of the group. A is always
    included; B, C and D are concatenated on only when non-empty (None or [] means "no fields").
    """

    def weighted(bits, weight):
        # one weight class: coalesce each field, join them, tokenize, then tag with the weight letter
        text = _join_with_space([func.coalesce(bit, "") for bit in bits])
        return func.setweight(func.to_tsvector(REGCONFIG, text), weight)

    tsv = weighted(A, "A")
    for bits, weight in ((B, "B"), (C, "C"), (D, "D")):
        if bits:
            tsv = tsv.concat(weighted(bits, weight))
    return tsv
def _build_doc(A, B=None, C=None, D=None):
    """
    Builds the raw document (without to_tsvector and weighting), used for extracting the snippet.

    Fields of A, then B, C and D (each COALESCEd to "") are joined with single spaces. B, C and D are skipped
    when None or empty.
    """
    doc = _join_with_space([func.coalesce(bit, "") for bit in A])
    for bits in (B, C, D):
        if bits:
            doc += " " + _join_with_space([func.coalesce(bit, "") for bit in bits])
    return doc
def _similarity(statement, text):
    """
    SQL expression for pg_trgm word_similarity between the search string and a text expression.

    Both sides are passed through unaccent() first so accented and unaccented spellings compare equal.
    """
    plain_statement = func.unaccent(statement)
    plain_text = func.unaccent(text)
    return func.word_similarity(plain_statement, plain_text)
def _gen_search_elements(statement, title_only, next_rank, page_size, A, B=None, C=None, D=None):
    """
    Given a search string and four sets of fields (A, B, C, D), generates postgres expressions for searching.

    The four sets are in decreasing order of "importance" for ranking. A should be the "title", the others can be
    anything; B, C and D may be omitted (None) or empty.

    If title_only=False, a row matches when the full-text query matches the weighted tsvector of all fields OR the
    trigram similarity between the title and the search string exceeds TRI_SIMILARITY_THRESHOLD. The rank is
    TRI_SIMILARITY_WEIGHT * similarity + ts_rank_cd.
    If title_only=True, we only perform a trigram search against A; the rank is the similarity alone.

    Returns (rank, snippet, execute_search_statement):
      - rank: SQL expression labelled "rank"
      - snippet: ts_headline over all fields with matches wrapped in "**", labelled "snippet"
      - execute_search_statement(session, orig_statement): adds the match filter, the `rank <= next_rank`
        pagination bound (only when next_rank is not None), orders by rank descending, limits to page_size + 1
        rows and returns all resulting rows.
    """
    # a postgres tsquery object; used for matching in full-text mode and always for highlighting the snippet
    tsq = func.websearch_to_tsquery(REGCONFIG, statement)

    # trigram based text similarity between title and search string
    sim = _similarity(statement, _build_doc(A))

    # the snippet with results highlighted, generated from the document built from all fields
    doc = _build_doc(A, B, C, D)
    snippet = func.ts_headline(REGCONFIG, doc, tsq, "StartSel=**,StopSel=**").label("snippet")

    if title_only:
        # trigram similarity alone decides both matching and ranking
        rank = sim.label("rank")
        match_condition = sim > TRI_SIMILARITY_THRESHOLD
    else:
        # the tsvector object that we want to search against with our tsquery
        tsv = _build_tsv(A, B, C, D)
        # ranking algo, weigh the similarity a lot, the text-based ranking less
        rank = (TRI_SIMILARITY_WEIGHT * sim + func.ts_rank_cd(tsv, tsq)).label("rank")
        match_condition = or_(tsv.op("@@")(tsq), sim > TRI_SIMILARITY_THRESHOLD)

    def execute_search_statement(session, orig_statement):
        """
        Does the right search filtering, limiting, and ordering for the initial statement
        """
        return session.execute(
            orig_statement.where(match_condition)
            .where(rank <= next_rank if next_rank is not None else True)
            .order_by(rank.desc())
            # one extra row lets the caller tell whether another page exists
            .limit(page_size + 1)
        ).all()

    return rank, snippet, execute_search_statement
158def _search_users(session, search_statement, title_only, next_rank, page_size, context, include_users):
159 if not include_users:
160 return []
161 rank, snippet, execute_search_statement = _gen_search_elements(
162 search_statement,
163 title_only,
164 next_rank,
165 page_size,
166 [User.username, User.name],
167 [User.city],
168 [User.about_me],
169 [User.my_travels, User.things_i_like, User.about_place, User.additional_information],
170 )
172 users = execute_search_statement(session, select(User, rank, snippet).where_users_visible(context))
174 return [
175 search_pb2.Result(
176 rank=rank,
177 user=user_model_to_pb(page, session, context),
178 snippet=snippet,
179 )
180 for page, rank, snippet in users
181 ]
184def _search_pages(session, search_statement, title_only, next_rank, page_size, context, include_places, include_guides):
185 rank, snippet, execute_search_statement = _gen_search_elements(
186 search_statement,
187 title_only,
188 next_rank,
189 page_size,
190 [PageVersion.title],
191 [PageVersion.address],
192 [],
193 [PageVersion.content],
194 )
195 if not include_places and not include_guides:
196 return []
198 latest_pages = (
199 select(func.max(PageVersion.id).label("id"))
200 .join(Page, Page.id == PageVersion.page_id)
201 .where(
202 or_(
203 (Page.type == PageType.place) if include_places else False,
204 (Page.type == PageType.guide) if include_guides else False,
205 )
206 )
207 .group_by(PageVersion.page_id)
208 .subquery()
209 )
211 pages = execute_search_statement(
212 session,
213 select(Page, rank, snippet)
214 .join(PageVersion, PageVersion.page_id == Page.id)
215 .join(latest_pages, latest_pages.c.id == PageVersion.id),
216 )
218 return [
219 search_pb2.Result(
220 rank=rank,
221 place=page_to_pb(page, context) if page.type == PageType.place else None,
222 guide=page_to_pb(page, context) if page.type == PageType.guide else None,
223 snippet=snippet,
224 )
225 for page, rank, snippet in pages
226 ]
def _search_events(session, search_statement, title_only, next_rank, page_size, context):
    """
    Searches event occurrences whose end_time is not in the past (end_time >= now()).

    Weight classes: A = the event title, B = occurrence address and link, D = occurrence content.
    Returns a list of search_pb2.Result with `event` set.
    """
    rank, snippet, execute_search_statement = _gen_search_elements(
        search_statement,
        title_only,
        next_rank,
        page_size,
        A=[Event.title],
        B=[EventOccurrence.address, EventOccurrence.link],
        C=[],
        D=[EventOccurrence.content],
    )

    not_yet_ended = (
        select(EventOccurrence, rank, snippet)
        .join(Event, Event.id == EventOccurrence.event_id)
        .where(EventOccurrence.end_time >= func.now())
    )
    rows = execute_search_statement(session, not_yet_ended)

    return [
        search_pb2.Result(
            rank=occurrence_rank,
            event=event_to_pb(session, occurrence, context),
            snippet=occurrence_snippet,
        )
        for occurrence, occurrence_rank, occurrence_snippet in rows
    ]
258def _search_clusters(
259 session, search_statement, title_only, next_rank, page_size, context, include_communities, include_groups
260):
261 if not include_communities and not include_groups:
262 return []
264 rank, snippet, execute_search_statement = _gen_search_elements(
265 search_statement,
266 title_only,
267 next_rank,
268 page_size,
269 [Cluster.name],
270 [PageVersion.address, PageVersion.title],
271 [Cluster.description],
272 [PageVersion.content],
273 )
275 latest_pages = (
276 select(func.max(PageVersion.id).label("id"))
277 .join(Page, Page.id == PageVersion.page_id)
278 .where(Page.type == PageType.main_page)
279 .group_by(PageVersion.page_id)
280 .subquery()
281 )
283 clusters = execute_search_statement(
284 session,
285 select(Cluster, rank, snippet)
286 .join(Page, Page.owner_cluster_id == Cluster.id)
287 .join(PageVersion, PageVersion.page_id == Page.id)
288 .join(latest_pages, latest_pages.c.id == PageVersion.id)
289 .where(Cluster.is_official_cluster if include_communities and not include_groups else True)
290 .where(~Cluster.is_official_cluster if not include_communities and include_groups else True),
291 )
293 return [
294 search_pb2.Result(
295 rank=rank,
296 community=community_to_pb(cluster.official_cluster_for_node, context)
297 if cluster.is_official_cluster
298 else None,
299 group=group_to_pb(cluster, context) if not cluster.is_official_cluster else None,
300 snippet=snippet,
301 )
302 for cluster, rank, snippet in clusters
303 ]
306class Search(search_pb2_grpc.SearchServicer):
307 def Search(self, request, context):
308 page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
309 # this is not an ideal page token, some results have equal rank (unlikely)
310 next_rank = float(request.page_token) if request.page_token else None
311 with session_scope() as session:
312 all_results = (
313 _search_users(
314 session,
315 request.query,
316 request.title_only,
317 next_rank,
318 page_size,
319 context,
320 request.include_users,
321 )
322 + _search_pages(
323 session,
324 request.query,
325 request.title_only,
326 next_rank,
327 page_size,
328 context,
329 request.include_places,
330 request.include_guides,
331 )
332 + _search_events(
333 session,
334 request.query,
335 request.title_only,
336 next_rank,
337 page_size,
338 context,
339 )
340 + _search_clusters(
341 session,
342 request.query,
343 request.title_only,
344 next_rank,
345 page_size,
346 context,
347 request.include_communities,
348 request.include_groups,
349 )
350 )
351 all_results.sort(key=lambda result: result.rank, reverse=True)
352 return search_pb2.SearchRes(
353 results=all_results[:page_size],
354 next_page_token=str(all_results[page_size].rank) if len(all_results) > page_size else None,
355 )
357 def UserSearch(self, request, context):
358 with session_scope() as session:
359 statement = select(User).where_users_visible(context)
360 if request.HasField("query"):
361 if request.query_name_only:
362 statement = statement.where(
363 or_(
364 User.name.ilike(f"%{request.query.value}%"), User.username.ilike(f"%{request.query.value}%")
365 )
366 )
367 else:
368 statement = statement.where(
369 or_(
370 User.name.ilike(f"%{request.query.value}%"),
371 User.username.ilike(f"%{request.query.value}%"),
372 User.city.ilike(f"%{request.query.value}%"),
373 User.hometown.ilike(f"%{request.query.value}%"),
374 User.about_me.ilike(f"%{request.query.value}%"),
375 User.my_travels.ilike(f"%{request.query.value}%"),
376 User.things_i_like.ilike(f"%{request.query.value}%"),
377 User.about_place.ilike(f"%{request.query.value}%"),
378 User.additional_information.ilike(f"%{request.query.value}%"),
379 )
380 )
382 if request.HasField("last_active"):
383 raw_dt = to_aware_datetime(request.last_active)
384 statement = statement.where(User.last_active >= last_active_coarsen(raw_dt))
386 if request.HasField("gender"):
387 statement = statement.where(User.gender.ilike(f"%{request.gender.value}%"))
389 if len(request.hosting_status_filter) > 0:
390 statement = statement.where(
391 User.hosting_status.in_([hostingstatus2sql[status] for status in request.hosting_status_filter])
392 )
393 if len(request.smoking_location_filter) > 0:
394 statement = statement.where(
395 User.smoking_allowed.in_([smokinglocation2sql[loc] for loc in request.smoking_location_filter])
396 )
397 if len(request.sleeping_arrangement_filter) > 0:
398 statement = statement.where(
399 User.sleeping_arrangement.in_(
400 [sleepingarrangement2sql[arr] for arr in request.sleeping_arrangement_filter]
401 )
402 )
403 if len(request.parking_details_filter) > 0:
404 statement = statement.where(
405 User.parking_details.in_([parkingdetails2sql[det] for det in request.parking_details_filter])
406 )
408 if request.HasField("guests"):
409 statement = statement.where(User.max_guests >= request.guests.value)
410 if request.HasField("last_minute"):
411 statement = statement.where(User.last_minute == request.last_minute.value)
412 if request.HasField("has_pets"):
413 statement = statement.where(User.has_pets == request.has_pets.value)
414 if request.HasField("accepts_pets"):
415 statement = statement.where(User.accepts_pets == request.accepts_pets.value)
416 if request.HasField("has_kids"):
417 statement = statement.where(User.has_kids == request.has_kids.value)
418 if request.HasField("accepts_kids"):
419 statement = statement.where(User.accepts_kids == request.accepts_kids.value)
420 if request.HasField("has_housemates"):
421 statement = statement.where(User.has_housemates == request.has_housemates.value)
422 if request.HasField("wheelchair_accessible"):
423 statement = statement.where(User.wheelchair_accessible == request.wheelchair_accessible.value)
424 if request.HasField("smokes_at_home"):
425 statement = statement.where(User.smokes_at_home == request.smokes_at_home.value)
426 if request.HasField("drinking_allowed"):
427 statement = statement.where(User.drinking_allowed == request.drinking_allowed.value)
428 if request.HasField("drinks_at_home"):
429 statement = statement.where(User.drinks_at_home == request.drinks_at_home.value)
430 if request.HasField("parking"):
431 statement = statement.where(User.parking == request.parking.value)
432 if request.HasField("camping_ok"):
433 statement = statement.where(User.camping_ok == request.camping_ok.value)
435 if request.HasField("search_in_area"):
436 # EPSG4326 measures distance in decimal degress
437 # we want to check whether two circles overlap, so check if the distance between their centers is less
438 # than the sum of their radii, divided by 111111 m ~= 1 degree (at the equator)
439 search_point = create_coordinate(request.search_in_area.lat, request.search_in_area.lng)
440 statement = statement.where(
441 func.ST_DWithin(
442 # old:
443 # User.geom, search_point, (User.geom_radius + request.search_in_area.radius) / 111111
444 # this is an optimization that speeds up the db queries since it doesn't need to look up the user's geom radius
445 User.geom,
446 search_point,
447 (1000 + request.search_in_area.radius) / 111111,
448 )
449 )
450 if request.HasField("search_in_community_id"):
451 # could do a join here as well, but this is just simpler
452 node = session.execute(
453 select(Node).where(Node.id == request.search_in_community_id)
454 ).scalar_one_or_none()
455 if not node:
456 context.abort(grpc.StatusCode.NOT_FOUND, errors.COMMUNITY_NOT_FOUND)
457 statement = statement.where(func.ST_Contains(node.geom, User.geom))
459 if request.only_with_references:
460 statement = statement.join(Reference, Reference.to_user_id == User.id)
462 # TODO:
463 # google.protobuf.StringValue language = 11;
464 # bool friends_only = 13;
465 # google.protobuf.UInt32Value age_min = 14;
466 # google.protobuf.UInt32Value age_max = 15;
468 page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
469 next_user_id = float(decrypt_page_token(request.page_token)) if request.page_token else 1e10
471 statement = (
472 statement.where(User.recommendation_score <= next_user_id)
473 .order_by(User.recommendation_score.desc())
474 .limit(page_size + 1)
475 )
476 users = session.execute(statement).scalars().all()
478 return search_pb2.UserSearchRes(
479 results=[
480 search_pb2.Result(
481 rank=1,
482 user=user_model_to_pb(user, session, context),
483 )
484 for user in users[:page_size]
485 ],
486 next_page_token=encrypt_page_token(str(users[-1].recommendation_score))
487 if len(users) > page_size
488 else None,
489 )