Coverage for app / backend / src / couchers / interceptors.py: 86%

262 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-03 06:18 +0000

1import logging 

2from collections.abc import Callable, Mapping 

3from copy import deepcopy 

4from dataclasses import dataclass, field 

5from datetime import datetime, timedelta 

6from os import getpid 

7from threading import get_ident 

8from time import perf_counter_ns 

9from traceback import format_exception 

10from typing import Any, NoReturn, cast, overload 

11 

12import grpc 

13import sentry_sdk 

14from google.protobuf.descriptor import ServiceDescriptor 

15from google.protobuf.descriptor_pool import DescriptorPool 

16from google.protobuf.message import Message 

17from opentelemetry import trace 

18from sqlalchemy import Function, select 

19from sqlalchemy.sql import and_, func 

20 

21from couchers.constants import ( 

22 CALL_CANCELLED_ERROR_MESSAGE, 

23 COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE, 

24 MISSING_AUTH_LEVEL_ERROR_MESSAGE, 

25 NONEXISTENT_API_CALL_ERROR_MESSAGE, 

26 PERMISSION_DENIED_ERROR_MESSAGE, 

27 UNAUTHORIZED_ERROR_MESSAGE, 

28 UNKNOWN_ERROR_MESSAGE, 

29) 

30from couchers.context import CouchersContext, make_interactive_context, make_media_context 

31from couchers.db import session_scope 

32from couchers.descriptor_pool import get_descriptor_pool 

33from couchers.metrics import observe_in_servicer_duration_histogram 

34from couchers.models import APICall, User, UserActivity, UserSession 

35from couchers.proto import annotations_pb2 

36from couchers.proto.annotations_pb2 import AuthLevel 

37from couchers.utils import ( 

38 create_lang_cookie, 

39 create_session_cookies, 

40 generate_sofa_cookie, 

41 now, 

42 parse_api_key, 

43 parse_session_cookie, 

44 parse_sofa_cookie, 

45 parse_ui_lang_cookie, 

46 parse_user_id_cookie, 

47) 

48 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

50 

51 

@dataclass(frozen=True, slots=True, kw_only=True)
class UserAuthInfo:
    """
    Information about an authenticated user session.

    Immutable, keyword-only record built by _try_get_and_update_user_details from the
    matching User and UserSession rows.
    """

    # id of the user that owns the session
    user_id: int
    # copied from the User row; consulted by check_permissions
    is_jailed: bool
    is_editor: bool
    is_superuser: bool
    # expiry of the session the token belongs to
    token_expiry: datetime
    # language stored on the user, if any
    ui_language_preference: str | None
    # the session token itself; excluded from repr() so that it does not end up in logs
    token: str = field(repr=False)
    # whether the token was presented as an API key rather than as a session cookie
    is_api_key: bool

64 

65 

def _binned_now() -> Function[Any]:
    """
    Builds the SQL expression date_bin('1 hour', now(), '2000-01-01').

    The expression is evaluated by the database, not in Python. It is used in this module
    as the value of UserActivity.period.
    """
    stride = "1 hour"
    origin = "2000-01-01"
    return func.date_bin(stride, func.now(), origin)

68 

69 

def _try_get_and_update_user_details(
    token: str | None, is_api_key: bool, ip_address: str | None, user_agent: str | None
) -> UserAuthInfo | None:
    """
    Looks up the session belonging to this token and records the activity of its user.

    A session only matches if the user is visible, the session is valid, and the session's
    is_api_key flag equals the one given.

    As a side effect this bumps the user's last active time (at most every 5 minutes), the
    session's last seen time and API call count, and the API call count of the activity row
    for the current (user, period, IP address, user agent) combination, creating that row
    if it does not exist yet.

    Returns UserAuthInfo if a valid session is found, None otherwise (including when no
    token was given).
    """
    if not token:
        return None

    # join condition picking out the activity row for this user/period/ip/user agent, if any
    same_activity_bucket = and_(
        UserActivity.user_id == User.id,
        UserActivity.period == _binned_now(),
        UserActivity.ip_address == ip_address,
        UserActivity.user_agent == user_agent,
    )
    statement = (
        select(User, UserSession, UserActivity)
        .select_from(UserSession)
        .join(User, User.id == UserSession.user_id)
        .outerjoin(UserActivity, same_activity_bucket)
        .where(User.is_visible)
        .where(UserSession.token == token)
        .where(UserSession.is_valid)
        .where(UserSession.is_api_key == is_api_key)
    )

    with session_scope() as session:
        row = session.execute(statement).one_or_none()

        if not row:
            return None

        user, user_session, user_activity = row._tuple()

        # only touch the user row if its last active time is more than 5 minutes old
        idle_for = now() - user.last_active
        if idle_for > timedelta(minutes=5):
            user.last_active = func.now()

        # the session itself is updated on every call
        user_session.last_seen = func.now()
        user_session.api_calls += 1

        if user_activity:
            user_activity.api_calls += 1
        else:
            # first call in this period from this ip/user agent
            new_activity = UserActivity(
                user_id=user.id,
                period=_binned_now(),
                ip_address=ip_address,
                user_agent=user_agent,
                api_calls=1,
            )
            session.add(new_activity)

        session.commit()

        return UserAuthInfo(
            user_id=user.id,
            is_jailed=user.is_jailed,
            is_editor=user.is_editor,
            is_superuser=user.is_superuser,
            token_expiry=user_session.expiry,
            ui_language_preference=user.ui_language_preference,
            token=token,
            is_api_key=is_api_key,
        )

141 

142 

143def abort_handler[T, R]( 

144 message: str, 

145 status_code: grpc.StatusCode, 

146) -> grpc.RpcMethodHandler[T, R]: 

147 def f(request: Any, context: CouchersContext) -> NoReturn: 

148 context.abort(status_code, message) 

149 

150 return grpc.unary_unary_rpc_method_handler(f) 

151 

152 

153def unauthenticated_handler[T, R]( 

154 message: str = UNAUTHORIZED_ERROR_MESSAGE, 

155 status_code: grpc.StatusCode = grpc.StatusCode.UNAUTHENTICATED, 

156) -> grpc.RpcMethodHandler[T, R]: 

157 return abort_handler(message, status_code) 

158 

159 

@overload
def _sanitized_bytes(proto: Message) -> bytes: ...
@overload
def _sanitized_bytes(proto: None) -> None: ...
def _sanitized_bytes(proto: Message | None) -> bytes | None:
    """
    Remove fields marked sensitive and return serialized bytes

    Works on a deep copy, so the message passed in is never modified. Fields carrying the
    `sensitive` annotation are cleared at every nesting level: in singular sub-messages,
    in repeated sub-messages and in message-typed map values.

    Returns None if no message is given.
    """
    if not proto:
        return None

    new_proto = deepcopy(proto)

    def _sanitize_message(message: Message) -> None:
        for name, descriptor in message.DESCRIPTOR.fields_by_name.items():
            if descriptor.GetOptions().Extensions[annotations_pb2.sensitive]:
                # the whole field is dropped, so there is nothing left to recurse into
                message.ClearField(name)
                continue

            message_type = descriptor.message_type
            if not message_type:
                # scalar field
                continue

            if message_type.GetOptions().map_entry:
                # Map fields are also reported as repeated, but iterating a map yields its
                # keys rather than messages. Only message-typed values can contain
                # sensitive fields.
                if message_type.fields_by_name["value"].message_type:
                    for value in getattr(message, name).values():
                        _sanitize_message(value)
            elif descriptor.is_repeated:
                for item in getattr(message, name):
                    _sanitize_message(item)
            elif message.HasField(name):
                # Only descend into sub-messages that are actually set: an unset one has
                # nothing to clear, and descending into it would never terminate for
                # self-referential message types.
                _sanitize_message(getattr(message, name))

    _sanitize_message(new_proto)

    return new_proto.SerializeToString()

190 

191 

def _store_log(
    *,
    method: str,
    status_code: str | None = None,
    duration: float,
    user_id: int | None,
    is_api_key: bool,
    request: Message,
    response: Message | None,
    traceback: str | None = None,
    perf_report: str | None = None,
    ip_address: str | None,
    user_agent: str | None,
    sofa: str | None,
) -> None:
    """
    Stores one APICall row describing a finished call, and writes a debug log line.

    Request and response are stored as serialized protos with sensitive fields removed.
    A serialized response longer than 16 kB is cut to its first 16 kB, and the row is
    flagged as truncated.
    """
    req_bytes = _sanitized_bytes(request)
    res_bytes = _sanitized_bytes(response)

    max_response_length = 16 * 1024  # 16 kB
    response_truncated = False
    if res_bytes is not None and len(res_bytes) > max_response_length:
        res_bytes, response_truncated = res_bytes[:max_response_length], True

    api_call = APICall(
        is_api_key=is_api_key,
        method=method,
        status_code=status_code,
        duration=duration,
        user_id=user_id,
        request=req_bytes,
        response=res_bytes,
        response_truncated=response_truncated,
        traceback=traceback,
        perf_report=perf_report,
        ip_address=ip_address,
        user_agent=user_agent,
        sofa=sofa,
    )
    with session_scope() as session:
        session.add(api_call)

    logger.debug(f"{user_id=}, {method=}, {duration=} ms")

233 

234 

# Type of the `continuation` callable handed to intercept_service: takes the call details
# and returns the next handler in the chain, or None.
type Cont[T, R] = Callable[[grpc.HandlerCallDetails], grpc.RpcMethodHandler[T, R] | None]

236 

237 

class CouchersMiddlewareInterceptor(grpc.ServerInterceptor):
    """
    1. Does auth: extracts a session token from a cookie, and authenticates a user with that.

    Sets context.user_id and context.token if authenticated, otherwise
    terminates the call with an UNAUTHENTICATED error code.

    2. Makes sure cookies are in sync.

    3. Injects a session to get a database transaction.

    4. Measures and logs the time it takes to service each incoming call.

    The token may alternatively come from the `authorization` header (an API key), see
    parse_headers. Only unary-unary handlers are supported.
    """

    def __init__(self) -> None:
        # descriptor pool used to look up the auth level annotated on each service
        self._pool = get_descriptor_pool()

    def intercept_service[T = Message, R = Message](
        self,
        continuation: Cont[T, R],
        handler_call_details: grpc.HandlerCallDetails,
    ) -> grpc.RpcMethodHandler[T, R]:
        """
        Authenticates and authorizes the call, then wraps the next handler so that it is
        given a CouchersContext and a database session, and so that the call is logged.

        Returns an aborting handler instead if the service is unknown or misconfigured, the
        headers are contradictory, or the caller lacks permission.

        Raises RuntimeError if there is no next handler or it is not unary-unary.
        """
        # the clock starts here, so the logged duration includes the auth lookup below
        start = perf_counter_ns()

        method = handler_call_details.method

        try:
            auth_level = find_auth_level(self._pool, method)
        except AbortError as ae:
            return abort_handler(ae.msg, ae.code)

        try:
            headers = parse_headers(dict(handler_call_details.invocation_metadata))
        except BadHeaders:
            return unauthenticated_handler(COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE)

        # None if there is no token or it doesn't match a valid session
        auth_info = _try_get_and_update_user_details(
            headers.token, headers.is_api_key, headers.ip_address, headers.user_agent
        )

        try:
            check_permissions(auth_info, auth_level)
        except AbortError as ae:
            # despite the helper's name, the call is aborted with the code carried by the error
            return unauthenticated_handler(ae.msg, ae.code)

        if not (handler := continuation(handler_call_details)):
            raise RuntimeError(f"No handler in '{method}'")

        if not (prev_function := handler.unary_unary):
            raise RuntimeError(f"No prev_function in '{method}', {handler}")

        # reuse the sofa value sent by the client, otherwise generate one and remember the
        # cookie so it can be set on the response
        if headers.sofa:
            sofa = headers.sofa
            new_sofa_cookie = None
        else:
            sofa, new_sofa_cookie = generate_sofa_cookie()

        def function_without_couchers_stuff(req: Message, grpc_context: grpc.ServicerContext) -> Message | None:
            couchers_context = make_interactive_context(
                grpc_context=grpc_context,
                user_id=auth_info.user_id if auth_info else None,
                is_api_key=auth_info.is_api_key if auth_info else False,
                token=auth_info.token if auth_info else None,
                # the language stored on the user wins over the language cookie
                ui_language_preference=(auth_info.ui_language_preference if auth_info else None) or headers.ui_lang,
            )

            with session_scope() as session:
                try:
                    _res = prev_function(req, couchers_context, session)  # type: ignore[call-arg, arg-type]
                    res = cast(Message, _res)
                    finished = perf_counter_ns()
                    duration = (finished - start) / 1e6  # ms
                    _store_log(
                        method=method,
                        duration=duration,
                        user_id=couchers_context._user_id,
                        is_api_key=cast(bool, couchers_context._is_api_key),
                        request=req,
                        response=res,
                        ip_address=headers.ip_address,
                        user_agent=headers.user_agent,
                        sofa=sofa,
                    )
                    # duration is in ms, the histogram is fed duration / 1000 (presumably seconds)
                    observe_in_servicer_duration_histogram(method, couchers_context._user_id, "", "", duration / 1000)
                except Exception as e:
                    # note that this also runs if storing the log or the metric above failed
                    finished = perf_counter_ns()
                    duration = (finished - start) / 1e6  # ms

                    if couchers_context._grpc_context:
                        context_code = couchers_context._grpc_context.code()  # type: ignore[attr-defined]
                        # name of the status code set on the context, or None if no code was set
                        code = getattr(context_code, "name", None)
                    else:
                        code = None

                    traceback = "".join(format_exception(type(e), e, e.__traceback__))
                    _store_log(
                        method=method,
                        status_code=code,
                        duration=duration,
                        user_id=couchers_context._user_id,
                        is_api_key=cast(bool, couchers_context._is_api_key),
                        request=req,
                        response=None,
                        traceback=traceback,
                        ip_address=headers.ip_address,
                        user_agent=headers.user_agent,
                        sofa=sofa,
                    )
                    observe_in_servicer_duration_histogram(
                        method, couchers_context._user_id, code or "", type(e).__name__, duration / 1000
                    )

                    # no status code on the context: report the exception to Sentry
                    if not code:
                        sentry_sdk.set_tag("context", "servicer")
                        sentry_sdk.set_tag("method", method)
                        sentry_sdk.capture_exception(e)

                    raise e

                # NOTE(review): indentation was lost in the source this was recovered from;
                # the statements below are assumed to be inside the session scope - confirm
                if auth_info and not auth_info.is_api_key:
                    # check the two cookies are in sync & that language preference cookie is correct
                    if headers.user_id != str(auth_info.user_id):
                        couchers_context.set_cookies(
                            create_session_cookies(auth_info.token, auth_info.user_id, auth_info.token_expiry)
                        )
                    if auth_info.ui_language_preference and auth_info.ui_language_preference != headers.ui_lang:
                        couchers_context.set_cookies(create_lang_cookie(auth_info.ui_language_preference))

                if new_sofa_cookie:
                    couchers_context.set_cookies([new_sofa_cookie])

                # the call is no longer active: fail it instead of returning a response
                if not grpc_context.is_active():
                    grpc_context.abort(grpc.StatusCode.INTERNAL, CALL_CANCELLED_ERROR_MESSAGE)

                couchers_context._send_cookies()

                return res

        return grpc.unary_unary_rpc_method_handler(
            function_without_couchers_stuff,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

381 

382 

@dataclass(frozen=True, slots=True, kw_only=True)
class CouchersHeaders:
    """
    The request headers this module cares about, as extracted by parse_headers.
    """

    # session token or API key, None if the request carried neither; excluded from repr()
    token: str | None = field(repr=False)
    # True if the token came from the `authorization` header rather than from a cookie
    is_api_key: bool
    # value of the `x-couchers-real-ip` header, None if missing or not a str
    ip_address: str | None
    # value of the `user-agent` header, None if missing or not a str
    user_agent: str | None
    # result of parse_ui_lang_cookie
    ui_lang: str | None
    # result of parse_user_id_cookie; kept as a string, it is compared against str(user_id)
    user_id: str | None
    # result of parse_sofa_cookie
    sofa: str | None

392 

393 

def parse_headers(headers: Mapping[str, str | bytes]) -> CouchersHeaders:
    """
    Extracts the token, client info and cookie values from the call's metadata.

    The token is taken from the `cookie` header (a session) or from the `authorization`
    header (an API key). If neither is present, the token is None.

    Raises BadHeaders if both `cookie` and `authorization` are present.
    """
    has_cookie = "cookie" in headers
    has_authorization = "authorization" in headers

    if has_cookie and has_authorization:
        # for security reasons, only one of "cookie" or "authorization" can be present
        raise BadHeaders("Both cookies and authorization are present in headers")

    if has_cookie:
        # the session token is passed in cookies, i.e., in the `cookie` header
        token = parse_session_cookie(headers)
        is_api_key = False
    elif has_authorization:
        # the session token is passed in the `authorization` header
        token = parse_api_key(headers)
        is_api_key = True
    else:
        # no session found
        token = None
        is_api_key = False

    raw_ip_address = headers.get("x-couchers-real-ip")
    raw_user_agent = headers.get("user-agent")

    ui_lang = parse_ui_lang_cookie(headers)
    user_id = parse_user_id_cookie(headers)
    sofa = parse_sofa_cookie(headers)

    # header values may be bytes, only str values are kept
    return CouchersHeaders(
        token=token,
        is_api_key=is_api_key,
        ip_address=raw_ip_address if isinstance(raw_ip_address, str) else None,
        user_agent=raw_user_agent if isinstance(raw_user_agent, str) else None,
        ui_lang=ui_lang,
        user_id=user_id,
        sofa=sofa,
    )

424 

425 

class BadHeaders(Exception):
    """
    Raised by parse_headers when both cookies and an authorization header are present.
    """

428 

429 

class AbortError(Exception):
    """
    Signals that the call should be aborted with the given message and gRPC status code.

    Raised by find_auth_level, validate_auth_level and check_permissions.
    """

    def __init__(self, msg: str, code: grpc.StatusCode):
        # same args as recorded implicitly when the base initializer is not called
        super().__init__(msg, code)
        self.code = code
        self.msg = msg

434 

435 

def find_auth_level(pool: DescriptorPool, method: str) -> AuthLevel.ValueType:
    """
    Returns the auth level annotated on the service that the given method belongs to.

    `method` is the full method path, of the form "/org.couchers.api.core.API/GetUser".

    Raises AbortError with UNIMPLEMENTED if the path does not have that form or the service
    is not in the pool, and with INTERNAL if the service has no valid auth level (see
    validate_auth_level).
    """
    # The path is supplied by the client, so it might not have exactly two slashes. Treat
    # that like any other call to something that doesn't exist rather than letting a
    # ValueError escape from the interceptor.
    parts = method.split("/")
    if len(parts) != 3:
        raise AbortError(NONEXISTENT_API_CALL_ERROR_MESSAGE, grpc.StatusCode.UNIMPLEMENTED)
    service_name = parts[1]

    try:
        service: ServiceDescriptor = pool.FindServiceByName(service_name)  # type: ignore[no-untyped-call]
        service_options = service.GetOptions()
    except KeyError:
        raise AbortError(NONEXISTENT_API_CALL_ERROR_MESSAGE, grpc.StatusCode.UNIMPLEMENTED) from None

    level = service_options.Extensions[annotations_pb2.auth_level]

    validate_auth_level(level)

    return level

451 

452 

def validate_auth_level(auth_level: AuthLevel.ValueType) -> None:
    """
    Checks that the auth level is one of the known, usable levels.

    Raises AbortError with INTERNAL for AUTH_LEVEL_UNKNOWN (if unknown auth level, then it
    wasn't set and something's wrong) as well as for any other unrecognized value.
    """
    usable_levels = {
        annotations_pb2.AUTH_LEVEL_OPEN,
        annotations_pb2.AUTH_LEVEL_JAILED,
        annotations_pb2.AUTH_LEVEL_SECURE,
        annotations_pb2.AUTH_LEVEL_EDITOR,
        annotations_pb2.AUTH_LEVEL_ADMIN,
    }
    if auth_level in usable_levels:
        return

    # AUTH_LEVEL_UNKNOWN is not in the set above, so it ends up here too
    raise AbortError(MISSING_AUTH_LEVEL_ERROR_MESSAGE, grpc.StatusCode.INTERNAL)

466 

467 

def check_permissions(auth_info: UserAuthInfo | None, auth_level: AuthLevel.ValueType) -> None:
    """
    Checks whether a caller with the given auth info may call a service of the given level.

    Returns None if the call is allowed, raises AbortError otherwise:

    * no valid session: only open services are allowed (UNAUTHENTICATED)
    * admin services need a superuser, editor services need an editor (PERMISSION_DENIED)
    * jailed users may only call open and jailed services (UNAUTHENTICATED)
    """
    if not auth_info:
        # if this isn't an open service, fail
        if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
            raise AbortError(UNAUTHORIZED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED)
        return

    # a valid user session was found - check permissions
    needs_superuser = auth_level == annotations_pb2.AUTH_LEVEL_ADMIN
    if needs_superuser and not auth_info.is_superuser:
        raise AbortError(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.PERMISSION_DENIED)

    needs_editor = auth_level == annotations_pb2.AUTH_LEVEL_EDITOR
    if needs_editor and not auth_info.is_editor:
        raise AbortError(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.PERMISSION_DENIED)

    # if the user is jailed and this isn't an open or jailed service, fail
    allowed_while_jailed = [
        annotations_pb2.AUTH_LEVEL_OPEN,
        annotations_pb2.AUTH_LEVEL_JAILED,
    ]
    if auth_info.is_jailed and auth_level not in allowed_while_jailed:
        # note: permission denied message, but UNAUTHENTICATED status code
        raise AbortError(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED)

487 

488 

489class MediaInterceptor(grpc.ServerInterceptor): 

490 """ 

491 Extracts an "Authorization: Bearer <hex>" header and calls the 

492 is_authorized function. Terminates the call with an HTTP error 

493 code if not authorized. 

494 

495 Also adds a session to called APIs. 

496 """ 

497 

498 def __init__(self, is_authorized: Callable[[str], bool]): 

499 self._is_authorized = is_authorized 

500 

501 def intercept_service[T, R]( 

502 self, 

503 continuation: Cont[T, R], 

504 handler_call_details: grpc.HandlerCallDetails, 

505 ) -> grpc.RpcMethodHandler[T, R]: 

506 handler = continuation(handler_call_details) 

507 if not handler: 507 ↛ 508line 507 didn't jump to line 508 because the condition on line 507 was never true

508 raise RuntimeError("No handler") 

509 

510 prev_func = handler.unary_unary 

511 if not prev_func: 511 ↛ 512line 511 didn't jump to line 512 because the condition on line 511 was never true

512 raise RuntimeError(f"No prev_function, {handler}") 

513 

514 metadata = dict(handler_call_details.invocation_metadata) 

515 

516 token = parse_api_key(metadata) 

517 

518 if not token or not self._is_authorized(token): 518 ↛ 519line 518 didn't jump to line 519 because the condition on line 518 was never true

519 return unauthenticated_handler() 

520 

521 def function_without_session(request: T, grpc_context: grpc.ServicerContext) -> R: 

522 with session_scope() as session: 

523 return prev_func(request, make_media_context(grpc_context), session) # type: ignore[call-arg, arg-type] 

524 

525 return grpc.unary_unary_rpc_method_handler( 

526 function_without_session, 

527 request_deserializer=handler.request_deserializer, 

528 response_serializer=handler.response_serializer, 

529 ) 

530 

531 

532class OTelInterceptor(grpc.ServerInterceptor): 

533 """ 

534 OpenTelemetry tracing 

535 """ 

536 

537 def __init__(self) -> None: 

538 self.tracer = trace.get_tracer(__name__) 

539 

540 def intercept_service[T, R]( 

541 self, 

542 continuation: Cont[T, R], 

543 handler_call_details: grpc.HandlerCallDetails, 

544 ) -> grpc.RpcMethodHandler[T, R]: 

545 handler = continuation(handler_call_details) 

546 if not handler: 

547 raise RuntimeError("No handler") 

548 

549 prev_func = handler.unary_unary 

550 if not prev_func: 

551 raise RuntimeError(f"No prev_function, {handler}") 

552 

553 method = handler_call_details.method 

554 

555 # method is of the form "/org.couchers.api.core.API/GetUser" 

556 _, service_name, method_name = method.split("/") 

557 

558 headers = dict(handler_call_details.invocation_metadata) 

559 

560 def tracing_function(request: T, context: grpc.ServicerContext) -> R: 

561 with self.tracer.start_as_current_span("handler") as rollspan: 

562 rollspan.set_attribute("rpc.method_full", method) 

563 rollspan.set_attribute("rpc.service", service_name) 

564 rollspan.set_attribute("rpc.method", method_name) 

565 

566 rollspan.set_attribute("rpc.thread", get_ident()) 

567 rollspan.set_attribute("rpc.pid", getpid()) 

568 

569 res = prev_func(request, context) 

570 

571 rollspan.set_attribute("web.user_agent", headers.get("user-agent") or "") 

572 rollspan.set_attribute("web.ip_address", headers.get("x-couchers-real-ip") or "") 

573 

574 return res 

575 

576 return grpc.unary_unary_rpc_method_handler( 

577 tracing_function, 

578 request_deserializer=handler.request_deserializer, 

579 response_serializer=handler.response_serializer, 

580 ) 

581 

582 

583class ErrorSanitizationInterceptor(grpc.ServerInterceptor): 

584 """ 

585 If the call resulted in a non-gRPC error, this strips away the error details. 

586 

587 It's important to put this first, so that it does not interfere with other interceptors. 

588 """ 

589 

590 def intercept_service[T, R]( 

591 self, 

592 continuation: Cont[T, R], 

593 handler_call_details: grpc.HandlerCallDetails, 

594 ) -> grpc.RpcMethodHandler[T, R]: 

595 handler = continuation(handler_call_details) 

596 if not handler: 596 ↛ 597line 596 didn't jump to line 597 because the condition on line 596 was never true

597 raise RuntimeError("No handler") 

598 

599 prev_func = handler.unary_unary 

600 if not prev_func: 600 ↛ 601line 600 didn't jump to line 601 because the condition on line 600 was never true

601 raise RuntimeError(f"No prev_function, {handler}") 

602 

603 def sanitizing_function(req: T, context: grpc.ServicerContext) -> R: 

604 try: 

605 res = prev_func(req, context) 

606 except Exception as e: 

607 code = context.code() # type: ignore[attr-defined] 

608 # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None 

609 if not code: 

610 logger.exception(e) 

611 logger.info("Probably an unknown error! Sanitizing...") 

612 context.abort(grpc.StatusCode.INTERNAL, UNKNOWN_ERROR_MESSAGE) 

613 else: 

614 logger.warning(f"RPC error: {code} in method {handler_call_details.method}") 

615 raise e 

616 return res 

617 

618 return grpc.unary_unary_rpc_method_handler( 

619 sanitizing_function, 

620 request_deserializer=handler.request_deserializer, 

621 response_serializer=handler.response_serializer, 

622 )