Coverage for src/couchers/servicers/search.py: 82%

173 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-07-22 16:44 +0000

1""" 

2See //docs/search.md for overview. 

3""" 

4 

5import grpc 

6from sqlalchemy.sql import func, or_ 

7 

8from couchers import errors 

9from couchers.crypto import decrypt_page_token, encrypt_page_token 

10from couchers.db import session_scope 

11from couchers.models import Cluster, Event, EventOccurrence, Node, Page, PageType, PageVersion, Reference, User 

12from couchers.servicers.api import ( 

13 hostingstatus2sql, 

14 meetupstatus2sql, 

15 parkingdetails2sql, 

16 sleepingarrangement2sql, 

17 smokinglocation2sql, 

18 user_model_to_pb, 

19) 

20from couchers.servicers.communities import community_to_pb 

21from couchers.servicers.events import event_to_pb 

22from couchers.servicers.groups import group_to_pb 

23from couchers.servicers.pages import page_to_pb 

24from couchers.sql import couchers_select as select 

25from couchers.utils import create_coordinate, last_active_coarsen, to_aware_datetime 

26from proto import search_pb2, search_pb2_grpc 

27 

# searches are a bit expensive, we'd rather send back a bunch of results at once than lots of small pages
# (used both as the default and as the upper bound of the page size of a request)
MAX_PAGINATION_LENGTH = 50

# text search configuration passed to to_tsvector, websearch_to_tsquery, and ts_headline
REGCONFIG = "english"
# a row matches on similarity alone when the word similarity between query and title exceeds this value
TRI_SIMILARITY_THRESHOLD = 0.6
# factor the title similarity is multiplied by before the text search rank is added to it
TRI_SIMILARITY_WEIGHT = 5

34 

35 

36def _join_with_space(coalesces): 

37 # the objects in coalesces are not strings, so we can't do " ".join(coalesces). They're SQLAlchemy magic. 

38 if not coalesces: 

39 return "" 

40 out = coalesces[0] 

41 for coalesce in coalesces[1:]: 

42 out += " " + coalesce 

43 return out 

44 

45 

def _build_tsv(A, B=None, C=None, D=None):
    """
    Builds a weighted tsvector expression from up to four lists of fields.

    The fields in A always contribute, with weight "A". B, C, and D are concatenated onto the result with weights
    "B", "C", and "D" respectively, but only if they are given and non-empty.
    """

    def weighted_tsv(bits, weight):
        # missing (NULL) fields are replaced by the empty string before joining
        text = _join_with_space([func.coalesce(bit, "") for bit in bits])
        return func.setweight(func.to_tsvector(REGCONFIG, text), weight)

    tsv = weighted_tsv(A, "A")
    for bits, weight in ((B, "B"), (C, "C"), (D, "D")):
        if bits:
            tsv = tsv.concat(weighted_tsv(bits, weight))
    return tsv

67 

68 

def _build_doc(A, B=None, C=None, D=None):
    """
    Builds the plain document expression (no to_tsvector, no weights) that the snippet is extracted from.

    The fields of A come first, followed by those of each non-empty group out of B, C, and D, all separated by
    spaces.
    """
    doc = _join_with_space([func.coalesce(bit, "") for bit in A])
    for bits in (B, C, D):
        if not bits:
            continue
        doc = doc + (" " + _join_with_space([func.coalesce(bit, "") for bit in bits]))
    return doc

84 

85 

def _similarity(statement, text):
    """
    Expression for the postgres word_similarity of statement and text, after both were passed through unaccent.
    """
    unaccented_statement = func.unaccent(statement)
    unaccented_text = func.unaccent(text)
    return func.word_similarity(unaccented_statement, unaccented_text)

88 

89 

def _gen_search_elements(statement, title_only, next_rank, page_size, A, B=None, C=None, D=None):
    """
    Generates the postgres expressions needed to run a text search for `statement` over four sets of fields.

    The sets (A, B, C, D) are in decreasing order of "importance" for ranking; A should be the "title", the others
    can be anything.

    With title_only=True, rows match only through trigram similarity against A and are ranked by that similarity.
    Otherwise rows match through full text search over all four sets or through title similarity, and are ranked by
    the weighted similarity plus the full text rank.

    Returns a tuple (rank, snippet, execute_search_statement): the labelled rank and snippet expressions to put into
    a select, and a function taking (session, orig_statement) that filters, orders (best rank first) and limits that
    select to page_size + 1 rows, then executes it and returns all rows.
    """
    B, C, D = B or [], C or [], D or []

    # trigram based text similarity between the title and the search string
    sim = _similarity(statement, _build_doc(A))

    # the tsquery is needed in both modes: always for the headline, and for matching when not title_only
    tsq = func.websearch_to_tsquery(REGCONFIG, statement)

    # the snippet with results highlighted, generated from the raw document
    snippet = func.ts_headline(REGCONFIG, _build_doc(A, B, C, D), tsq, "StartSel=**,StopSel=**").label("snippet")

    if title_only:
        rank = sim.label("rank")
        matches = sim > TRI_SIMILARITY_THRESHOLD
    else:
        # the tsvector object that we want to search against with our tsquery
        tsv = _build_tsv(A, B, C, D)
        # ranking algo, weigh the similarity a lot, the text-based ranking less
        rank = (TRI_SIMILARITY_WEIGHT * sim + func.ts_rank_cd(tsv, tsq)).label("rank")
        matches = or_(tsv.op("@@")(tsq), sim > TRI_SIMILARITY_THRESHOLD)

    def execute_search_statement(session, orig_statement):
        """
        Does the right search filtering, limiting, and ordering for the initial statement
        """
        filtered = orig_statement.where(matches)
        # when continuing from an earlier page, only consider rows ranked at or below the page token
        filtered = filtered.where(rank <= next_rank if next_rank is not None else True)
        # one extra row is fetched beyond the page size
        return session.execute(filtered.order_by(rank.desc()).limit(page_size + 1)).all()

    return rank, snippet, execute_search_statement

163 

164 

165def _search_users(session, search_statement, title_only, next_rank, page_size, context, include_users): 

166 if not include_users: 

167 return [] 

168 rank, snippet, execute_search_statement = _gen_search_elements( 

169 search_statement, 

170 title_only, 

171 next_rank, 

172 page_size, 

173 [User.username, User.name], 

174 [User.city], 

175 [User.about_me], 

176 [User.my_travels, User.things_i_like, User.about_place, User.additional_information], 

177 ) 

178 

179 users = execute_search_statement(session, select(User, rank, snippet).where_users_visible(context)) 

180 

181 return [ 

182 search_pb2.Result( 

183 rank=rank, 

184 user=user_model_to_pb(page, session, context), 

185 snippet=snippet, 

186 ) 

187 for page, rank, snippet in users 

188 ] 

189 

190 

191def _search_pages(session, search_statement, title_only, next_rank, page_size, context, include_places, include_guides): 

192 rank, snippet, execute_search_statement = _gen_search_elements( 

193 search_statement, 

194 title_only, 

195 next_rank, 

196 page_size, 

197 [PageVersion.title], 

198 [PageVersion.address], 

199 [], 

200 [PageVersion.content], 

201 ) 

202 if not include_places and not include_guides: 

203 return [] 

204 

205 latest_pages = ( 

206 select(func.max(PageVersion.id).label("id")) 

207 .join(Page, Page.id == PageVersion.page_id) 

208 .where( 

209 or_( 

210 (Page.type == PageType.place) if include_places else False, 

211 (Page.type == PageType.guide) if include_guides else False, 

212 ) 

213 ) 

214 .group_by(PageVersion.page_id) 

215 .subquery() 

216 ) 

217 

218 pages = execute_search_statement( 

219 session, 

220 select(Page, rank, snippet) 

221 .join(PageVersion, PageVersion.page_id == Page.id) 

222 .join(latest_pages, latest_pages.c.id == PageVersion.id), 

223 ) 

224 

225 return [ 

226 search_pb2.Result( 

227 rank=rank, 

228 place=page_to_pb(page, context) if page.type == PageType.place else None, 

229 guide=page_to_pb(page, context) if page.type == PageType.guide else None, 

230 snippet=snippet, 

231 ) 

232 for page, rank, snippet in pages 

233 ] 

234 

235 

def _search_events(session, search_statement, title_only, next_rank, page_size, context):
    """
    Searches event occurrences whose end time has not passed yet.

    The event title is the "title" of the search; address, link, and content of the occurrence are searched as well.
    Returns a list of search_pb2.Result with `event` set.
    """
    rank, snippet, execute_search_statement = _gen_search_elements(
        search_statement,
        title_only,
        next_rank,
        page_size,
        A=[Event.title],
        B=[EventOccurrence.address, EventOccurrence.link],
        C=[],
        D=[EventOccurrence.content],
    )

    not_yet_ended = (
        select(EventOccurrence, rank, snippet)
        .join(Event, Event.id == EventOccurrence.event_id)
        .where(EventOccurrence.end_time >= func.now())
    )
    rows = execute_search_statement(session, not_yet_ended)

    return [
        search_pb2.Result(
            rank=row_rank,
            event=event_to_pb(session, occurrence, context),
            snippet=row_snippet,
        )
        for occurrence, row_rank, row_snippet in rows
    ]

263 

264 

265def _search_clusters( 

266 session, search_statement, title_only, next_rank, page_size, context, include_communities, include_groups 

267): 

268 if not include_communities and not include_groups: 

269 return [] 

270 

271 rank, snippet, execute_search_statement = _gen_search_elements( 

272 search_statement, 

273 title_only, 

274 next_rank, 

275 page_size, 

276 [Cluster.name], 

277 [PageVersion.address, PageVersion.title], 

278 [Cluster.description], 

279 [PageVersion.content], 

280 ) 

281 

282 latest_pages = ( 

283 select(func.max(PageVersion.id).label("id")) 

284 .join(Page, Page.id == PageVersion.page_id) 

285 .where(Page.type == PageType.main_page) 

286 .group_by(PageVersion.page_id) 

287 .subquery() 

288 ) 

289 

290 clusters = execute_search_statement( 

291 session, 

292 select(Cluster, rank, snippet) 

293 .join(Page, Page.owner_cluster_id == Cluster.id) 

294 .join(PageVersion, PageVersion.page_id == Page.id) 

295 .join(latest_pages, latest_pages.c.id == PageVersion.id) 

296 .where(Cluster.is_official_cluster if include_communities and not include_groups else True) 

297 .where(~Cluster.is_official_cluster if not include_communities and include_groups else True), 

298 ) 

299 

300 return [ 

301 search_pb2.Result( 

302 rank=rank, 

303 community=( 

304 community_to_pb(cluster.official_cluster_for_node, context) if cluster.is_official_cluster else None 

305 ), 

306 group=group_to_pb(cluster, context) if not cluster.is_official_cluster else None, 

307 snippet=snippet, 

308 ) 

309 for cluster, rank, snippet in clusters 

310 ] 

311 

312 

class Search(search_pb2_grpc.SearchServicer):
    """
    Servicer for the search API: ranked text search across entity types, and filtered user search.
    """

    def Search(self, request, context):
        """
        Text search over users, pages (places/guides), events, and clusters (communities/groups).

        Every entity type is searched separately with the same query and page size. The results are merged and
        sorted by descending rank; the first page_size of them are returned. If there are more, the rank of the
        first result that did not fit is returned as the token for the next page.
        """
        page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
        # this is not an ideal page token, some results have equal rank (unlikely)
        next_rank = float(request.page_token) if request.page_token else None
        with session_scope() as session:
            all_results = (
                _search_users(
                    session,
                    request.query,
                    request.title_only,
                    next_rank,
                    page_size,
                    context,
                    request.include_users,
                )
                + _search_pages(
                    session,
                    request.query,
                    request.title_only,
                    next_rank,
                    page_size,
                    context,
                    request.include_places,
                    request.include_guides,
                )
                + _search_events(
                    session,
                    request.query,
                    request.title_only,
                    next_rank,
                    page_size,
                    context,
                )
                + _search_clusters(
                    session,
                    request.query,
                    request.title_only,
                    next_rank,
                    page_size,
                    context,
                    request.include_communities,
                    request.include_groups,
                )
            )
            all_results.sort(key=lambda result: result.rank, reverse=True)
            return search_pb2.SearchRes(
                results=all_results[:page_size],
                next_page_token=str(all_results[page_size].rank) if len(all_results) > page_size else None,
            )

    def UserSearch(self, request, context):
        """
        Searches the users visible in the given context, narrowed down by whichever filters are set on the request.

        Results are ordered by descending recommendation score and paginated on that score through an encrypted
        page token. Aborts with NOT_FOUND if `search_in_community_id` does not refer to an existing node.
        """
        with session_scope() as session:
            statement = select(User).where_users_visible(context)
            if request.HasField("query"):
                # case insensitive substring match against the name fields, and the profile texts unless name only
                pattern = f"%{request.query.value}%"
                searched_columns = [User.name, User.username]
                if not request.query_name_only:
                    searched_columns += [
                        User.city,
                        User.hometown,
                        User.about_me,
                        User.my_travels,
                        User.things_i_like,
                        User.about_place,
                        User.additional_information,
                    ]
                statement = statement.where(or_(*[column.ilike(pattern) for column in searched_columns]))

            if request.HasField("last_active"):
                raw_dt = to_aware_datetime(request.last_active)
                statement = statement.where(User.last_active >= last_active_coarsen(raw_dt))

            if request.HasField("gender"):
                statement = statement.where(User.gender.ilike(f"%{request.gender.value}%"))

            if len(request.hosting_status_filter) > 0:
                statement = statement.where(
                    User.hosting_status.in_([hostingstatus2sql[status] for status in request.hosting_status_filter])
                )
            if len(request.meetup_status_filter) > 0:
                statement = statement.where(
                    User.meetup_status.in_([meetupstatus2sql[status] for status in request.meetup_status_filter])
                )
            if len(request.smoking_location_filter) > 0:
                statement = statement.where(
                    User.smoking_allowed.in_([smokinglocation2sql[loc] for loc in request.smoking_location_filter])
                )
            if len(request.sleeping_arrangement_filter) > 0:
                statement = statement.where(
                    User.sleeping_arrangement.in_(
                        [sleepingarrangement2sql[arr] for arr in request.sleeping_arrangement_filter]
                    )
                )
            if len(request.parking_details_filter) > 0:
                statement = statement.where(
                    User.parking_details.in_([parkingdetails2sql[det] for det in request.parking_details_filter])
                )
            if request.HasField("profile_completed"):
                statement = statement.where(User.has_completed_profile == request.profile_completed.value)
            if request.HasField("guests"):
                statement = statement.where(User.max_guests >= request.guests.value)

            # wrapper fields that are compared for equality against the User column of the same name
            for field in (
                "last_minute",
                "has_pets",
                "accepts_pets",
                "has_kids",
                "accepts_kids",
                "has_housemates",
                "wheelchair_accessible",
                "smokes_at_home",
                "drinking_allowed",
                "drinks_at_home",
                "parking",
                "camping_ok",
            ):
                if request.HasField(field):
                    statement = statement.where(getattr(User, field) == getattr(request, field).value)

            if request.HasField("search_in_area"):
                # EPSG4326 measures distance in decimal degress
                # we want to check whether two circles overlap, so check if the distance between their centers is less
                # than the sum of their radii, divided by 111111 m ~= 1 degree (at the equator)
                search_point = create_coordinate(request.search_in_area.lat, request.search_in_area.lng)
                statement = statement.where(
                    func.ST_DWithin(
                        # old:
                        # User.geom, search_point, (User.geom_radius + request.search_in_area.radius) / 111111
                        # this is an optimization that speeds up the db queries since it doesn't need to look up the
                        # user's geom radius
                        User.geom,
                        search_point,
                        (1000 + request.search_in_area.radius) / 111111,
                    )
                )
            if request.HasField("search_in_rectangle"):
                statement = statement.where(
                    func.ST_Within(
                        User.geom,
                        func.ST_MakeEnvelope(
                            request.search_in_rectangle.lng_min,
                            request.search_in_rectangle.lat_min,
                            request.search_in_rectangle.lng_max,
                            request.search_in_rectangle.lat_max,
                            4326,
                        ),
                    )
                )
            if request.HasField("search_in_community_id"):
                # could do a join here as well, but this is just simpler
                node = session.execute(
                    select(Node).where(Node.id == request.search_in_community_id)
                ).scalar_one_or_none()
                if not node:
                    context.abort(grpc.StatusCode.NOT_FOUND, errors.COMMUNITY_NOT_FOUND)
                statement = statement.where(func.ST_Contains(node.geom, User.geom))

            if request.only_with_references:
                # join against the *distinct* set of referenced user ids: joining the references table directly
                # gives one row per reference, so users with several references would show up several times and
                # eat into the page limit
                referenced_users = select(Reference.to_user_id.label("user_id")).distinct().subquery()
                statement = statement.join(referenced_users, referenced_users.c.user_id == User.id)

            # TODO:
            # google.protobuf.StringValue language = 11;
            # bool friends_only = 13;
            # google.protobuf.UInt32Value age_min = 14;
            # google.protobuf.UInt32Value age_max = 15;

            page_size = min(MAX_PAGINATION_LENGTH, request.page_size or MAX_PAGINATION_LENGTH)
            next_recommendation_score = float(decrypt_page_token(request.page_token)) if request.page_token else 1e10

            # fetch one user more than the page size: if it exists, its score becomes the next page token
            statement = (
                statement.where(User.recommendation_score <= next_recommendation_score)
                .order_by(User.recommendation_score.desc())
                .limit(page_size + 1)
            )
            users = session.execute(statement).scalars().all()

            return search_pb2.UserSearchRes(
                results=[
                    search_pb2.Result(
                        rank=1,
                        user=user_model_to_pb(user, session, context),
                    )
                    for user in users[:page_size]
                ],
                next_page_token=(
                    encrypt_page_token(str(users[-1].recommendation_score)) if len(users) > page_size else None
                ),
            )

515 ), 

516 )