Coverage for src/couchers/servicers/search.py: 80%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

158 statements  

1""" 

2See //docs/search.md for overview. 

3""" 

4import grpc 

5from sqlalchemy.sql import func, or_ 

6 

7from couchers import errors 

8from couchers.crypto import decrypt_page_token, encrypt_page_token 

9from couchers.db import session_scope 

10from couchers.models import Cluster, Event, EventOccurrence, Node, Page, PageType, PageVersion, Reference, User 

11from couchers.servicers.api import ( 

12 hostingstatus2sql, 

13 parkingdetails2sql, 

14 sleepingarrangement2sql, 

15 smokinglocation2sql, 

16 user_model_to_pb, 

17) 

18from couchers.servicers.communities import community_to_pb 

19from couchers.servicers.events import event_to_pb 

20from couchers.servicers.groups import group_to_pb 

21from couchers.servicers.pages import page_to_pb 

22from couchers.sql import couchers_select as select 

23from couchers.utils import create_coordinate, last_active_coarsen, to_aware_datetime 

24from proto import search_pb2, search_pb2_grpc 

25 

# searches are a bit expensive, we'd rather send back a bunch of results at once than lots of small pages
# (this is the upper bound on a page, and also the page size used when the request does not set one)
MAX_PAGINATION_LENGTH = 50

# text search configuration handed to to_tsvector, websearch_to_tsquery and ts_headline
REGCONFIG = "english"
# a row matches on similarity alone when the similarity of its title to the query exceeds this value
TRI_SIMILARITY_THRESHOLD = 0.6
# factor applied to the title similarity before it is added to ts_rank_cd to form the rank
TRI_SIMILARITY_WEIGHT = 5

32 

33 

34def _join_with_space(coalesces): 

35 # the objects in coalesces are not strings, so we can't do " ".join(coalesces). They're SQLAlchemy magic. 

36 if not coalesces: 

37 return "" 

38 out = coalesces[0] 

39 for coalesce in coalesces[1:]: 

40 out += " " + coalesce 

41 return out 

42 

43 

def _build_tsv(A, B=(), C=(), D=()):
    """
    Given sequences of fields for A, B, C, and D, builds a weighted tsvector expression from them.

    Each group is coalesced to "" field by field, joined with spaces, turned into a tsvector using REGCONFIG and
    given the weight matching its letter. A is always included; B, C and D are left out when empty.
    """

    def _weighted(bits, weight):
        # one group -> setweight(to_tsvector(REGCONFIG, "bit1 bit2 ..."), weight)
        text = _join_with_space([func.coalesce(bit, "") for bit in bits])
        return func.setweight(func.to_tsvector(REGCONFIG, text), weight)

    tsv = _weighted(A, "A")
    for bits, weight in ((B, "B"), (C, "C"), (D, "D")):
        if bits:
            tsv = tsv.concat(_weighted(bits, weight))
    return tsv

62 

63 

def _build_doc(A, B=(), C=(), D=()):
    """
    Builds the raw document (without to_tsvector and weighting), used for extracting snippet

    Every field is coalesced to "" and the fields are joined with single spaces, group A first, followed by each
    of B, C and D that is not empty.
    """
    doc = _join_with_space([func.coalesce(bit, "") for bit in A])
    for bits in (B, C, D):
        if bits:
            doc += " " + _join_with_space([func.coalesce(bit, "") for bit in bits])
    return doc

76 

77 

def _similarity(statement, text):
    """
    Returns the word_similarity expression between statement and text, both passed through unaccent first.
    """
    plain_statement = func.unaccent(statement)
    plain_text = func.unaccent(text)
    return func.word_similarity(plain_statement, plain_text)

80 

81 

def _gen_search_elements(statement, title_only, next_rank, page_size, A, B=(), C=(), D=()):
    """
    Given an sql statement and four sets of fields, (A, B, C, D), generates a bunch of postgres expressions for full text search.

    The four sets are in decreasing order of "importance" for ranking.

    A should be the "title", the others can be anything.

    If title_only=True, we only perform a trigram search against A only

    Returns a tuple (rank, snippet, execute_search_statement):

    * rank: labelled "rank" expression; the title similarity alone when title_only, otherwise the weighted title
      similarity plus ts_rank_cd over all four sets
    * snippet: labelled "snippet" expression, a ts_headline over the whole document with matches wrapped in **
    * execute_search_statement(session, orig_statement): applies the match filter, the next_rank pagination
      filter (only when next_rank is not None), orders by rank descending, limits to page_size + 1 rows (one
      extra, so the caller can tell whether another page exists) and returns all rows
    """
    # a postgres tsquery object; always used for the snippet, and for matching/ranking unless title_only
    tsq = func.websearch_to_tsquery(REGCONFIG, statement)

    # document to generate snippet from
    doc = _build_doc(A, B, C, D)

    title = _build_doc(A)

    # trigram based text similarity between title and sql statement string
    sim = _similarity(statement, title)

    # the snippet with results highlighted
    snippet = func.ts_headline(REGCONFIG, doc, tsq, "StartSel=**,StopSel=**").label("snippet")

    if title_only:
        # only the title similarity counts, both for matching and for ranking
        rank = sim.label("rank")
        matches = sim > TRI_SIMILARITY_THRESHOLD
    else:
        # the tsvector object that we want to search against with our tsquery
        tsv = _build_tsv(A, B, C, D)

        # ranking algo, weigh the similarity a lot, the text-based ranking less
        rank = (TRI_SIMILARITY_WEIGHT * sim + func.ts_rank_cd(tsv, tsq)).label("rank")
        matches = or_(tsv.op("@@")(tsq), sim > TRI_SIMILARITY_THRESHOLD)

    def execute_search_statement(session, orig_statement):
        """
        Does the right search filtering, limiting, and ordering for the initial statement
        """
        query = orig_statement.where(matches)
        if next_rank is not None:
            query = query.where(rank <= next_rank)
        return session.execute(query.order_by(rank.desc()).limit(page_size + 1)).all()

    return rank, snippet, execute_search_statement

156 

157 

def _search_users(session, search_statement, title_only, next_rank, page_size, context, include_users):
    """
    Searches users visible in the given context and returns them as a list of search_pb2.Result.

    Returns an empty list without querying when include_users is falsy.
    """
    if not include_users:
        return []

    # username and name act as the title; the remaining profile fields follow in decreasing weight
    title_fields = [User.username, User.name]
    location_fields = [User.city]
    about_fields = [User.about_me]
    other_fields = [User.my_travels, User.things_i_like, User.about_place, User.additional_information]

    rank, snippet, execute_search_statement = _gen_search_elements(
        search_statement, title_only, next_rank, page_size, title_fields, location_fields, about_fields, other_fields
    )

    visible_users = select(User, rank, snippet).where_users_visible(context)
    rows = execute_search_statement(session, visible_users)

    results = []
    for user, user_rank, user_snippet in rows:
        results.append(
            search_pb2.Result(
                rank=user_rank,
                user=user_model_to_pb(user, session, context),
                snippet=user_snippet,
            )
        )
    return results

182 

183 

def _search_pages(session, search_statement, title_only, next_rank, page_size, context, include_places, include_guides):
    """
    Searches the latest version of place and/or guide pages and returns them as a list of search_pb2.Result.

    Only page types whose include_* flag is set are searched; returns an empty list without building or running
    any query when neither flag is set. Each result fills `place` or `guide` depending on the page's type.
    """
    # bail out before building any search expressions, same as the other _search_* helpers
    if not include_places and not include_guides:
        return []

    rank, snippet, execute_search_statement = _gen_search_elements(
        search_statement,
        title_only,
        next_rank,
        page_size,
        [PageVersion.title],
        [PageVersion.address],
        [],
        [PageVersion.content],
    )

    # one condition per requested page type; the guard above guarantees at least one
    type_filters = []
    if include_places:
        type_filters.append(Page.type == PageType.place)
    if include_guides:
        type_filters.append(Page.type == PageType.guide)

    # id of the newest (max id) version of every page of a requested type
    latest_pages = (
        select(func.max(PageVersion.id).label("id"))
        .join(Page, Page.id == PageVersion.page_id)
        .where(or_(*type_filters))
        .group_by(PageVersion.page_id)
        .subquery()
    )

    pages = execute_search_statement(
        session,
        select(Page, rank, snippet)
        .join(PageVersion, PageVersion.page_id == Page.id)
        .join(latest_pages, latest_pages.c.id == PageVersion.id),
    )

    return [
        search_pb2.Result(
            rank=rank,
            place=page_to_pb(page, context) if page.type == PageType.place else None,
            guide=page_to_pb(page, context) if page.type == PageType.guide else None,
            snippet=snippet,
        )
        for page, rank, snippet in pages
    ]

227 

228 

def _search_events(session, search_statement, title_only, next_rank, page_size, context):
    """
    Searches event occurrences whose end_time is not before now() and returns them as a list of search_pb2.Result.
    """
    # the event title is the title; address and link come next, the occurrence content is weighted lowest
    title_fields = [Event.title]
    location_fields = [EventOccurrence.address, EventOccurrence.link]
    content_fields = [EventOccurrence.content]

    rank, snippet, execute_search_statement = _gen_search_elements(
        search_statement, title_only, next_rank, page_size, title_fields, location_fields, [], content_fields
    )

    base_query = select(EventOccurrence, rank, snippet)
    base_query = base_query.join(Event, Event.id == EventOccurrence.event_id)
    base_query = base_query.where(EventOccurrence.end_time >= func.now())
    rows = execute_search_statement(session, base_query)

    results = []
    for occurrence, occurrence_rank, occurrence_snippet in rows:
        results.append(
            search_pb2.Result(
                rank=occurrence_rank,
                event=event_to_pb(session, occurrence, context),
                snippet=occurrence_snippet,
            )
        )
    return results

256 

257 

def _search_clusters(
    session, search_statement, title_only, next_rank, page_size, context, include_communities, include_groups
):
    """
    Searches clusters through the latest version of their main page and returns a list of search_pb2.Result.

    Official clusters are returned as `community` results, all other clusters as `group` results. When only one
    of include_communities / include_groups is set, the query is restricted to that kind; when neither is set an
    empty list is returned without querying.
    """
    if not (include_communities or include_groups):
        return []

    # the cluster name is the title; the main page's address/title, the description and the page content follow
    title_fields = [Cluster.name]
    page_header_fields = [PageVersion.address, PageVersion.title]
    description_fields = [Cluster.description]
    content_fields = [PageVersion.content]

    rank, snippet, execute_search_statement = _gen_search_elements(
        search_statement,
        title_only,
        next_rank,
        page_size,
        title_fields,
        page_header_fields,
        description_fields,
        content_fields,
    )

    # id of the newest (max id) version of every main page
    newest_version_ids = select(func.max(PageVersion.id).label("id"))
    newest_version_ids = newest_version_ids.join(Page, Page.id == PageVersion.page_id)
    newest_version_ids = newest_version_ids.where(Page.type == PageType.main_page)
    latest_pages = newest_version_ids.group_by(PageVersion.page_id).subquery()

    only_communities = include_communities and not include_groups
    only_groups = not include_communities and include_groups

    base_query = select(Cluster, rank, snippet)
    base_query = base_query.join(Page, Page.owner_cluster_id == Cluster.id)
    base_query = base_query.join(PageVersion, PageVersion.page_id == Page.id)
    base_query = base_query.join(latest_pages, latest_pages.c.id == PageVersion.id)
    base_query = base_query.where(Cluster.is_official_cluster if only_communities else True)
    base_query = base_query.where(~Cluster.is_official_cluster if only_groups else True)

    rows = execute_search_statement(session, base_query)

    results = []
    for cluster, cluster_rank, cluster_snippet in rows:
        if cluster.is_official_cluster:
            community = community_to_pb(cluster.official_cluster_for_node, context)
        else:
            community = None
        if not cluster.is_official_cluster:
            group = group_to_pb(cluster, context)
        else:
            group = None
        results.append(
            search_pb2.Result(
                rank=cluster_rank,
                community=community,
                group=group,
                snippet=cluster_snippet,
            )
        )
    return results

304 

305 

class Search(search_pb2_grpc.SearchServicer):
    """
    Servicer with two endpoints: Search (ranked text search across several kinds of content) and UserSearch
    (filter-based search over users).
    """

    def Search(self, request, context):
        """
        Runs the query against users, pages, events and clusters and returns one page of the merged results.

        Each kind is searched separately with the same rank cutoff and page size, the results are concatenated
        and sorted by rank descending, and the first page_size of them are returned. The page token is the rank
        of the first result that did not fit on the page, as a string.
        """
        page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
        # this is not an ideal page token, some results have equal rank (unlikely)
        next_rank = float(request.page_token) if request.page_token else None
        with session_scope() as session:
            all_results = (
                _search_users(
                    session,
                    request.query,
                    request.title_only,
                    next_rank,
                    page_size,
                    context,
                    request.include_users,
                )
                + _search_pages(
                    session,
                    request.query,
                    request.title_only,
                    next_rank,
                    page_size,
                    context,
                    request.include_places,
                    request.include_guides,
                )
                + _search_events(
                    session,
                    request.query,
                    request.title_only,
                    next_rank,
                    page_size,
                    context,
                )
                + _search_clusters(
                    session,
                    request.query,
                    request.title_only,
                    next_rank,
                    page_size,
                    context,
                    request.include_communities,
                    request.include_groups,
                )
            )
            all_results.sort(key=lambda result: result.rank, reverse=True)
            return search_pb2.SearchRes(
                results=all_results[:page_size],
                next_page_token=str(all_results[page_size].rank) if len(all_results) > page_size else None,
            )

    def UserSearch(self, request, context):
        """
        Returns one page of users visible in this context that satisfy every filter set on the request.

        Results are ordered by recommendation_score descending. The page token is the encrypted
        recommendation_score of the first user that did not fit on the page.

        Aborts with NOT_FOUND if search_in_community_id does not refer to an existing node.
        """
        with session_scope() as session:
            statement = select(User).where_users_visible(context)
            if request.HasField("query"):
                # case-insensitive substring match against any of the searched columns
                pattern = f"%{request.query.value}%"
                searched_columns = [User.name, User.username]
                if not request.query_name_only:
                    searched_columns += [
                        User.city,
                        User.hometown,
                        User.about_me,
                        User.my_travels,
                        User.things_i_like,
                        User.about_place,
                        User.additional_information,
                    ]
                statement = statement.where(or_(*[column.ilike(pattern) for column in searched_columns]))

            if request.HasField("last_active"):
                raw_dt = to_aware_datetime(request.last_active)
                statement = statement.where(User.last_active >= last_active_coarsen(raw_dt))

            if request.HasField("gender"):
                statement = statement.where(User.gender.ilike(f"%{request.gender.value}%"))

            # repeated enum filters: the user's value has to be one of the requested ones
            if len(request.hosting_status_filter) > 0:
                statement = statement.where(
                    User.hosting_status.in_([hostingstatus2sql[status] for status in request.hosting_status_filter])
                )
            if len(request.smoking_location_filter) > 0:
                statement = statement.where(
                    User.smoking_allowed.in_([smokinglocation2sql[loc] for loc in request.smoking_location_filter])
                )
            if len(request.sleeping_arrangement_filter) > 0:
                statement = statement.where(
                    User.sleeping_arrangement.in_(
                        [sleepingarrangement2sql[arr] for arr in request.sleeping_arrangement_filter]
                    )
                )
            if len(request.parking_details_filter) > 0:
                statement = statement.where(
                    User.parking_details.in_([parkingdetails2sql[det] for det in request.parking_details_filter])
                )

            if request.HasField("guests"):
                statement = statement.where(User.max_guests >= request.guests.value)

            # optional fields that share their name with the User column and are compared for equality
            for field in (
                "last_minute",
                "has_pets",
                "accepts_pets",
                "has_kids",
                "accepts_kids",
                "has_housemates",
                "wheelchair_accessible",
                "smokes_at_home",
                "drinking_allowed",
                "drinks_at_home",
                "parking",
                "camping_ok",
            ):
                if request.HasField(field):
                    statement = statement.where(getattr(User, field) == getattr(request, field).value)

            if request.HasField("search_in_area"):
                # EPSG4326 measures distance in decimal degrees
                # we want to check whether two circles overlap, so check if the distance between their centers is less
                # than the sum of their radii, divided by 111111 m ~= 1 degree (at the equator)
                search_point = create_coordinate(request.search_in_area.lat, request.search_in_area.lng)
                statement = statement.where(
                    func.ST_DWithin(
                        # old:
                        # User.geom, search_point, (User.geom_radius + request.search_in_area.radius) / 111111
                        # this is an optimization that speeds up the db queries since it doesn't need to look up the user's geom radius
                        User.geom,
                        search_point,
                        (1000 + request.search_in_area.radius) / 111111,
                    )
                )
            if request.HasField("search_in_community_id"):
                # could do a join here as well, but this is just simpler
                node = session.execute(
                    select(Node).where(Node.id == request.search_in_community_id)
                ).scalar_one_or_none()
                if not node:
                    context.abort(grpc.StatusCode.NOT_FOUND, errors.COMMUNITY_NOT_FOUND)
                statement = statement.where(func.ST_Contains(node.geom, User.geom))

            if request.only_with_references:
                # a subquery rather than a join: joining yields one row per reference, which would repeat users
                # in the results and throw off the LIMIT-based pagination below
                statement = statement.where(User.id.in_(select(Reference.to_user_id)))

            # TODO:
            # google.protobuf.StringValue language = 11;
            # bool friends_only = 13;
            # google.protobuf.UInt32Value age_min = 14;
            # google.protobuf.UInt32Value age_max = 15;

            page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
            # without a page token, start from a score above anything we expect to see
            next_recommendation_score = float(decrypt_page_token(request.page_token)) if request.page_token else 1e10

            # fetch one row more than the page holds, its score becomes the next page token
            statement = (
                statement.where(User.recommendation_score <= next_recommendation_score)
                .order_by(User.recommendation_score.desc())
                .limit(page_size + 1)
            )
            users = session.execute(statement).scalars().all()

            return search_pb2.UserSearchRes(
                results=[
                    search_pb2.Result(
                        rank=1,
                        user=user_model_to_pb(user, session, context),
                    )
                    for user in users[:page_size]
                ],
                next_page_token=encrypt_page_token(str(users[-1].recommendation_score))
                if len(users) > page_size
                else None,
            )

489 )