Coverage for src/couchers/servicers/search.py: 80%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

160 statements  

1""" 

2See //docs/search.md for overview. 

3""" 

4 

5import grpc 

6from sqlalchemy.sql import func, or_ 

7 

8from couchers import errors 

9from couchers.crypto import decrypt_page_token, encrypt_page_token 

10from couchers.db import session_scope 

11from couchers.models import Cluster, Event, EventOccurrence, Node, Page, PageType, PageVersion, Reference, User 

12from couchers.servicers.api import ( 

13 hostingstatus2sql, 

14 parkingdetails2sql, 

15 sleepingarrangement2sql, 

16 smokinglocation2sql, 

17 user_model_to_pb, 

18) 

19from couchers.servicers.communities import community_to_pb 

20from couchers.servicers.events import event_to_pb 

21from couchers.servicers.groups import group_to_pb 

22from couchers.servicers.pages import page_to_pb 

23from couchers.sql import couchers_select as select 

24from couchers.utils import create_coordinate, last_active_coarsen, to_aware_datetime 

25from proto import search_pb2, search_pb2_grpc 

26 

# Upper bound on the page size of the search endpoints: searches are a bit expensive, so we'd rather send back
# a bunch of results at once than lots of small pages.
MAX_PAGINATION_LENGTH = 50

# text search configuration handed to to_tsvector, websearch_to_tsquery and ts_headline
REGCONFIG = "english"
# a row is kept when the similarity between the query and its title exceeds this value
TRI_SIMILARITY_THRESHOLD = 0.6
# factor the title similarity is multiplied by before ts_rank_cd is added to form the rank
TRI_SIMILARITY_WEIGHT = 5

33 

34 

35def _join_with_space(coalesces): 

36 # the objects in coalesces are not strings, so we can't do " ".join(coalesces). They're SQLAlchemy magic. 

37 if not coalesces: 

38 return "" 

39 out = coalesces[0] 

40 for coalesce in coalesces[1:]: 

41 out += " " + coalesce 

42 return out 

43 

44 

def _build_tsv(A, B=(), C=(), D=()):
    """
    Given sequences of fields for A, B, C, and D, builds a weighted tsvector expression from them.

    Each group is coalesced (NULL -> ""), joined with spaces, passed through to_tsvector and tagged with the
    weight of the same name. A is always included; B, C and D are only concatenated on when non-empty.
    """

    def weighted(bits, weight):
        # one group of fields turned into a tsvector carrying the given weight label
        text = _join_with_space([func.coalesce(bit, "") for bit in bits])
        return func.setweight(func.to_tsvector(REGCONFIG, text), weight)

    tsv = weighted(A, "A")
    for bits, weight in ((B, "B"), (C, "C"), (D, "D")):
        if bits:
            tsv = tsv.concat(weighted(bits, weight))
    return tsv

63 

64 

def _build_doc(A, B=(), C=(), D=()):
    """
    Builds the raw document (without to_tsvector and weighting), used for extracting the snippet.

    Every field is coalesced (NULL -> "") and the fields are joined with single spaces, group A first, followed
    by whichever of B, C and D are non-empty.
    """
    doc = _join_with_space([func.coalesce(bit, "") for bit in A])
    for bits in (B, C, D):
        if bits:
            doc += " " + _join_with_space([func.coalesce(bit, "") for bit in bits])
    return doc

77 

78 

def _similarity(statement, text):
    """
    Returns the SQL expression word_similarity(unaccent(statement), unaccent(text)).
    """
    query_side = func.unaccent(statement)
    text_side = func.unaccent(text)
    return func.word_similarity(query_side, text_side)

81 

82 

def _gen_search_elements(statement, title_only, next_rank, page_size, A, B=(), C=(), D=()):
    """
    Given a search string and four sets of fields, (A, B, C, D), generates a bunch of postgres expressions for
    full text search.

    The four sets are in decreasing order of "importance" for ranking.

    A should be the "title", the others can be anything.

    If title_only=True, we only perform a trigram search against A; the other sets are then only used to build
    the snippet.

    Returns a tuple (rank, snippet, execute_search_statement):
    * rank: labelled expression ("rank") used for ordering and pagination
    * snippet: labelled expression ("snippet"), a ts_headline of the document with matches wrapped in **
    * execute_search_statement(session, orig_statement): applies the search filter, the pagination cut-off
      (rank <= next_rank, when next_rank is not None), descending rank order and a limit of page_size + 1 to
      orig_statement, executes it and returns all rows
    """
    # a postgres tsquery object; matched against the tsvector (full search) and used for highlighting (always)
    tsq = func.websearch_to_tsquery(REGCONFIG, statement)

    # document to generate snippet from
    doc = _build_doc(A, B, C, D)

    # trigram based text similarity between the title and the search string
    sim = _similarity(statement, _build_doc(A))
    similar_enough = sim > TRI_SIMILARITY_THRESHOLD

    if title_only:
        # only the title similarity counts, both for matching and for ranking
        rank = sim.label("rank")
        matches = similar_enough
    else:
        # the tsvector object that we want to search against with our tsquery
        tsv = _build_tsv(A, B, C, D)
        # ranking algo, weigh the similarity a lot, the text-based ranking less
        rank = (TRI_SIMILARITY_WEIGHT * sim + func.ts_rank_cd(tsv, tsq)).label("rank")
        matches = or_(tsv.op("@@")(tsq), similar_enough)

    # the snippet with results highlighted
    snippet = func.ts_headline(REGCONFIG, doc, tsq, "StartSel=**,StopSel=**").label("snippet")

    def execute_search_statement(session, orig_statement):
        """
        Does the right search filtering, limiting, and ordering for the initial statement
        """
        filtered = orig_statement.where(matches)
        if next_rank is not None:
            filtered = filtered.where(rank <= next_rank)
        # one row more than the page size is fetched so the caller can tell whether there is a next page
        return session.execute(filtered.order_by(rank.desc()).limit(page_size + 1)).all()

    return rank, snippet, execute_search_statement

157 

158 

def _search_users(session, search_statement, title_only, next_rank, page_size, context, include_users):
    """
    Searches users visible in the given context, with username and name as the title fields.

    Returns a list of search_pb2.Result with the user field set, or an empty list if include_users is falsy.
    """
    if not include_users:
        return []

    rank, snippet, execute_search_statement = _gen_search_elements(
        search_statement,
        title_only,
        next_rank,
        page_size,
        A=[User.username, User.name],
        B=[User.city],
        C=[User.about_me],
        D=[User.my_travels, User.things_i_like, User.about_place, User.additional_information],
    )

    visible_users = select(User, rank, snippet).where_users_visible(context)

    results = []
    for user, user_rank, user_snippet in execute_search_statement(session, visible_users):
        results.append(
            search_pb2.Result(
                rank=user_rank,
                user=user_model_to_pb(user, session, context),
                snippet=user_snippet,
            )
        )
    return results

183 

184 

def _search_pages(session, search_statement, title_only, next_rank, page_size, context, include_places, include_guides):
    """
    Searches the latest version of place and/or guide pages, with the page title as the title field.

    Returns a list of search_pb2.Result with either the place or the guide field set depending on the page
    type, or an empty list if neither include_places nor include_guides is set.
    """
    # nothing requested: bail out before building any search expressions
    if not include_places and not include_guides:
        return []

    rank, snippet, execute_search_statement = _gen_search_elements(
        search_statement,
        title_only,
        next_rank,
        page_size,
        [PageVersion.title],
        [PageVersion.address],
        [],
        [PageVersion.content],
    )

    wanted_types = []
    if include_places:
        wanted_types.append(PageType.place)
    if include_guides:
        wanted_types.append(PageType.guide)

    # highest PageVersion id per page, restricted to the requested page types
    latest_pages = (
        select(func.max(PageVersion.id).label("id"))
        .join(Page, Page.id == PageVersion.page_id)
        .where(Page.type.in_(wanted_types))
        .group_by(PageVersion.page_id)
        .subquery()
    )

    pages = execute_search_statement(
        session,
        select(Page, rank, snippet)
        .join(PageVersion, PageVersion.page_id == Page.id)
        .join(latest_pages, latest_pages.c.id == PageVersion.id),
    )

    return [
        search_pb2.Result(
            rank=rank,
            place=page_to_pb(page, context) if page.type == PageType.place else None,
            guide=page_to_pb(page, context) if page.type == PageType.guide else None,
            snippet=snippet,
        )
        for page, rank, snippet in pages
    ]

228 

229 

def _search_events(session, search_statement, title_only, next_rank, page_size, context):
    """
    Searches event occurrences whose end_time is not before now(), with the event title as the title field.

    Returns a list of search_pb2.Result with the event field set.
    """
    rank, snippet, execute_search_statement = _gen_search_elements(
        search_statement,
        title_only,
        next_rank,
        page_size,
        A=[Event.title],
        B=[EventOccurrence.address, EventOccurrence.link],
        C=[],
        D=[EventOccurrence.content],
    )

    not_ended = (
        select(EventOccurrence, rank, snippet)
        .join(Event, Event.id == EventOccurrence.event_id)
        .where(EventOccurrence.end_time >= func.now())
    )

    results = []
    for occurrence, occurrence_rank, occurrence_snippet in execute_search_statement(session, not_ended):
        results.append(
            search_pb2.Result(
                rank=occurrence_rank,
                event=event_to_pb(session, occurrence, context),
                snippet=occurrence_snippet,
            )
        )
    return results

257 

258 

def _search_clusters(
    session, search_statement, title_only, next_rank, page_size, context, include_communities, include_groups
):
    """
    Searches clusters together with the latest version of their main page, with the cluster name as the title.

    Official clusters are returned in the community field, all other clusters in the group field. When exactly
    one of include_communities / include_groups is set, only that kind is searched; when neither is set an
    empty list is returned.
    """
    if not (include_communities or include_groups):
        return []

    rank, snippet, execute_search_statement = _gen_search_elements(
        search_statement,
        title_only,
        next_rank,
        page_size,
        A=[Cluster.name],
        B=[PageVersion.address, PageVersion.title],
        C=[Cluster.description],
        D=[PageVersion.content],
    )

    # highest PageVersion id per main page
    latest_pages = (
        select(func.max(PageVersion.id).label("id"))
        .join(Page, Page.id == PageVersion.page_id)
        .where(Page.type == PageType.main_page)
        .group_by(PageVersion.page_id)
        .subquery()
    )

    only_communities = include_communities and not include_groups
    only_groups = include_groups and not include_communities
    official_filter = Cluster.is_official_cluster if only_communities else True
    unofficial_filter = ~Cluster.is_official_cluster if only_groups else True

    cluster_statement = (
        select(Cluster, rank, snippet)
        .join(Page, Page.owner_cluster_id == Cluster.id)
        .join(PageVersion, PageVersion.page_id == Page.id)
        .join(latest_pages, latest_pages.c.id == PageVersion.id)
        .where(official_filter)
        .where(unofficial_filter)
    )

    results = []
    for cluster, cluster_rank, cluster_snippet in execute_search_statement(session, cluster_statement):
        if cluster.is_official_cluster:
            community = community_to_pb(cluster.official_cluster_for_node, context)
            group = None
        else:
            community = None
            group = group_to_pb(cluster, context)
        results.append(
            search_pb2.Result(
                rank=cluster_rank,
                community=community,
                group=group,
                snippet=cluster_snippet,
            )
        )
    return results

305 

306 

class Search(search_pb2_grpc.SearchServicer):
    """
    Search servicer: a combined full text search over several kinds of content, and a filtered user search.
    """

    def Search(self, request, context):
        """
        Runs the user, page, event and cluster searches for request.query and merges their results.

        The merged results are sorted by descending rank and cut to the page size. If there are more results,
        next_page_token is the rank of the first result that did not fit on this page.
        """
        page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
        # this is not an ideal page token, some results have equal rank (unlikely)
        next_rank = float(request.page_token) if request.page_token else None
        with session_scope() as session:
            # arguments shared by all the per-type search helpers
            common = (session, request.query, request.title_only, next_rank, page_size, context)
            all_results = (
                _search_users(*common, request.include_users)
                + _search_pages(*common, request.include_places, request.include_guides)
                + _search_events(*common)
                + _search_clusters(*common, request.include_communities, request.include_groups)
            )
            all_results.sort(key=lambda result: result.rank, reverse=True)
            return search_pb2.SearchRes(
                results=all_results[:page_size],
                next_page_token=str(all_results[page_size].rank) if len(all_results) > page_size else None,
            )

    def UserSearch(self, request, context):
        """
        Returns users visible in the given context that satisfy every filter set on the request.

        Results are ordered by descending recommendation score. If there are more results, next_page_token is
        the encrypted recommendation score of the first user that did not fit on this page.

        Aborts with NOT_FOUND if search_in_community_id is set but no such node exists.
        """
        with session_scope() as session:
            statement = select(User).where_users_visible(context)

            if request.HasField("query"):
                pattern = f"%{request.query.value}%"
                searched_columns = [User.name, User.username]
                if not request.query_name_only:
                    searched_columns += [
                        User.city,
                        User.hometown,
                        User.about_me,
                        User.my_travels,
                        User.things_i_like,
                        User.about_place,
                        User.additional_information,
                    ]
                statement = statement.where(or_(*[column.ilike(pattern) for column in searched_columns]))

            if request.HasField("last_active"):
                raw_dt = to_aware_datetime(request.last_active)
                statement = statement.where(User.last_active >= last_active_coarsen(raw_dt))

            if request.HasField("gender"):
                statement = statement.where(User.gender.ilike(f"%{request.gender.value}%"))

            if len(request.hosting_status_filter) > 0:
                statement = statement.where(
                    User.hosting_status.in_([hostingstatus2sql[status] for status in request.hosting_status_filter])
                )
            if len(request.smoking_location_filter) > 0:
                statement = statement.where(
                    User.smoking_allowed.in_([smokinglocation2sql[loc] for loc in request.smoking_location_filter])
                )
            if len(request.sleeping_arrangement_filter) > 0:
                statement = statement.where(
                    User.sleeping_arrangement.in_(
                        [sleepingarrangement2sql[arr] for arr in request.sleeping_arrangement_filter]
                    )
                )
            if len(request.parking_details_filter) > 0:
                statement = statement.where(
                    User.parking_details.in_([parkingdetails2sql[det] for det in request.parking_details_filter])
                )

            if request.HasField("profile_completed"):
                statement = statement.where(User.has_completed_profile == request.profile_completed.value)
            if request.HasField("guests"):
                statement = statement.where(User.max_guests >= request.guests.value)

            # wrapper fields that are compared for equality against the User column of the same name
            exact_match_fields = (
                "last_minute",
                "has_pets",
                "accepts_pets",
                "has_kids",
                "accepts_kids",
                "has_housemates",
                "wheelchair_accessible",
                "smokes_at_home",
                "drinking_allowed",
                "drinks_at_home",
                "parking",
                "camping_ok",
            )
            for field in exact_match_fields:
                if request.HasField(field):
                    statement = statement.where(getattr(User, field) == getattr(request, field).value)

            if request.HasField("search_in_area"):
                # EPSG4326 measures distance in decimal degrees
                # we want to check whether two circles overlap, so check if the distance between their centers is
                # less than the sum of their radii, divided by 111111 m ~= 1 degree (at the equator)
                search_point = create_coordinate(request.search_in_area.lat, request.search_in_area.lng)
                statement = statement.where(
                    func.ST_DWithin(
                        # old:
                        # User.geom, search_point, (User.geom_radius + request.search_in_area.radius) / 111111
                        # this is an optimization that speeds up the db queries since it doesn't need to look up
                        # the user's geom radius
                        User.geom,
                        search_point,
                        (1000 + request.search_in_area.radius) / 111111,
                    )
                )
            if request.HasField("search_in_community_id"):
                # could do a join here as well, but this is just simpler
                node = session.execute(
                    select(Node).where(Node.id == request.search_in_community_id)
                ).scalar_one_or_none()
                if not node:
                    context.abort(grpc.StatusCode.NOT_FOUND, errors.COMMUNITY_NOT_FOUND)
                statement = statement.where(func.ST_Contains(node.geom, User.geom))

            if request.only_with_references:
                # a plain join would yield one row per reference, so a user with several references would show
                # up several times and eat into the page limit; filter on membership instead
                statement = statement.where(User.id.in_(select(Reference.to_user_id)))

            # TODO:
            # google.protobuf.StringValue language = 11;
            # bool friends_only = 13;
            # google.protobuf.UInt32Value age_min = 14;
            # google.protobuf.UInt32Value age_max = 15;

            page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
            # the page token carries a recommendation score; without a token fall back to a large cut-off
            next_score = float(decrypt_page_token(request.page_token)) if request.page_token else 1e10

            statement = (
                statement.where(User.recommendation_score <= next_score)
                .order_by(User.recommendation_score.desc())
                .limit(page_size + 1)
            )
            # one user more than the page size is fetched to find out whether there is a next page
            users = session.execute(statement).scalars().all()

            return search_pb2.UserSearchRes(
                results=[
                    search_pb2.Result(
                        rank=1,
                        user=user_model_to_pb(user, session, context),
                    )
                    for user in users[:page_size]
                ],
                next_page_token=(
                    encrypt_page_token(str(users[-1].recommendation_score)) if len(users) > page_size else None
                ),
            )