Coverage for src / couchers / interceptors.py: 85%

255 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-13 12:05 +0000

1import logging 

2from collections.abc import Callable, Mapping 

3from copy import deepcopy 

4from dataclasses import dataclass, field 

5from datetime import datetime, timedelta 

6from os import getpid 

7from threading import get_ident 

8from time import perf_counter_ns 

9from traceback import format_exception 

10from typing import Any, NoReturn, cast, overload 

11 

12import grpc 

13import sentry_sdk 

14from google.protobuf.descriptor import ServiceDescriptor 

15from google.protobuf.descriptor_pool import DescriptorPool 

16from google.protobuf.message import Message 

17from opentelemetry import trace 

18from sqlalchemy import Function, select 

19from sqlalchemy.sql import and_, func 

20 

21from couchers.constants import ( 

22 CALL_CANCELLED_ERROR_MESSAGE, 

23 COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE, 

24 MISSING_AUTH_LEVEL_ERROR_MESSAGE, 

25 NONEXISTENT_API_CALL_ERROR_MESSAGE, 

26 PERMISSION_DENIED_ERROR_MESSAGE, 

27 UNAUTHORIZED_ERROR_MESSAGE, 

28 UNKNOWN_ERROR_MESSAGE, 

29) 

30from couchers.context import CouchersContext, make_interactive_context, make_media_context 

31from couchers.db import session_scope 

32from couchers.descriptor_pool import get_descriptor_pool 

33from couchers.metrics import observe_in_servicer_duration_histogram 

34from couchers.models import APICall, User, UserActivity, UserSession 

35from couchers.proto import annotations_pb2 

36from couchers.proto.annotations_pb2 import AuthLevel 

37from couchers.utils import ( 

38 create_lang_cookie, 

39 create_session_cookies, 

40 now, 

41 parse_api_key, 

42 parse_session_cookie, 

43 parse_ui_lang_cookie, 

44 parse_user_id_cookie, 

45) 

46 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

48 

49 

@dataclass(frozen=True, slots=True, kw_only=True)
class UserAuthInfo:
    """
    Information about an authenticated user session.

    Immutable, keyword-only snapshot built by _try_get_and_update_user_details from the matching user and session.
    """

    # id of the user owning the session
    user_id: int
    # copied from the user's is_jailed / is_editor / is_superuser flags; consulted by check_permissions
    is_jailed: bool
    is_editor: bool
    is_superuser: bool
    # expiry of the session the token belongs to
    token_expiry: datetime
    # the user's stored UI language preference, if any
    ui_language_preference: str | None
    # the session token itself; kept out of repr() via repr=False
    token: str = field(repr=False)
    # whether the token was looked up as an API key rather than a cookie session
    is_api_key: bool

62 

63 

def _binned_now() -> Function[Any]:
    """
    Build a SQL expression for the database's current time, binned into 1-hour buckets.

    Renders a ``date_bin`` call with a "1 hour" stride, ``now()`` as the source and "2000-01-01" as the origin.
    """
    stride, origin = "1 hour", "2000-01-01"
    return func.date_bin(stride, func.now(), origin)

66 

67 

def _try_get_and_update_user_details(
    token: str | None, is_api_key: bool, ip_address: str | None, user_agent: str | None
) -> UserAuthInfo | None:
    """
    Look up the session and user belonging to a token, and record the activity.

    On a match this refreshes the user's last active time (only if it is more than 5 minutes old), updates the
    session's last seen time and API call count, and increments (or creates) the UserActivity row for the current
    hour bin, IP address and user agent.

    Returns UserAuthInfo if a valid session is found, None if there is no token or no matching session.
    """
    if not token:
        return None

    # matches the activity row for this user in the current hour bin with the same ip and user agent, if any
    activity_match = and_(
        UserActivity.user_id == User.id,
        UserActivity.period == _binned_now(),
        UserActivity.ip_address == ip_address,
        UserActivity.user_agent == user_agent,
    )
    query = (
        select(User, UserSession, UserActivity)
        .select_from(UserSession)
        .join(User, User.id == UserSession.user_id)
        .outerjoin(UserActivity, activity_match)
        .where(User.is_visible)
        .where(UserSession.token == token)
        .where(UserSession.is_valid)
        .where(UserSession.is_api_key == is_api_key)
    )

    with session_scope() as session:
        row = session.execute(query).one_or_none()

        if not row:
            return None

        user, user_session, activity = row._tuple()

        # avoid writing the user row on every call: only refresh last_active once it's gone stale
        is_stale = now() - user.last_active > timedelta(minutes=5)
        if is_stale:
            user.last_active = func.now()

        # the session itself is touched on every call
        user_session.last_seen = func.now()
        user_session.api_calls += 1

        if not activity:
            # first call in this hour bin for this ip/user agent combination
            new_activity = UserActivity(
                user_id=user.id,
                period=_binned_now(),
                ip_address=ip_address,
                user_agent=user_agent,
                api_calls=1,
            )
            session.add(new_activity)
        else:
            activity.api_calls += 1

        session.commit()

        return UserAuthInfo(
            user_id=user.id,
            is_jailed=user.is_jailed,
            is_editor=user.is_editor,
            is_superuser=user.is_superuser,
            token_expiry=user_session.expiry,
            ui_language_preference=user.ui_language_preference,
            token=token,
            is_api_key=is_api_key,
        )

139 

140 

141def abort_handler[T, R]( 

142 message: str, 

143 status_code: grpc.StatusCode, 

144) -> grpc.RpcMethodHandler[T, R]: 

145 def f(request: Any, context: CouchersContext) -> NoReturn: 

146 context.abort(status_code, message) 

147 

148 return grpc.unary_unary_rpc_method_handler(f) 

149 

150 

151def unauthenticated_handler[T, R]( 

152 message: str = UNAUTHORIZED_ERROR_MESSAGE, 

153 status_code: grpc.StatusCode = grpc.StatusCode.UNAUTHENTICATED, 

154) -> grpc.RpcMethodHandler[T, R]: 

155 return abort_handler(message, status_code) 

156 

157 

@overload
def _sanitized_bytes(proto: Message) -> bytes: ...
@overload
def _sanitized_bytes(proto: None) -> None: ...
def _sanitized_bytes(proto: Message | None) -> bytes | None:
    """
    Remove fields marked sensitive and return serialized bytes

    Works on a deep copy, so the message passed in is left untouched. Fields carrying the ``sensitive`` option are
    cleared, and nested messages (singular fields, repeated fields and map values) are sanitized recursively.

    Returns None if no message was given.
    """
    if not proto:
        return None

    new_proto = deepcopy(proto)

    def _sanitize_message(message: Message) -> None:
        for name, descriptor in message.DESCRIPTOR.fields_by_name.items():
            if descriptor.GetOptions().Extensions[annotations_pb2.sensitive]:
                message.ClearField(name)
                # the field is empty now, nothing left to recurse into
                continue

            field_type = descriptor.message_type
            if not field_type:
                # scalar field
                continue

            submessage = getattr(message, name)
            if not submessage:
                continue

            if field_type.GetOptions().map_entry:
                # map fields also report a message type (the synthetic entry type) and count as repeated, but
                # iterating one yields its keys: walk the values instead, and only when those are messages
                if field_type.fields_by_name["value"].message_type:
                    for msg in submessage.values():
                        _sanitize_message(msg)
            elif descriptor.is_repeated:
                for msg in submessage:
                    _sanitize_message(msg)
            else:
                _sanitize_message(submessage)

    _sanitize_message(new_proto)

    return new_proto.SerializeToString()

188 

189 

def _store_log(
    *,
    method: str,
    status_code: str | None = None,
    duration: float,
    user_id: int | None,
    is_api_key: bool,
    request: Message,
    response: Message | None,
    traceback: str | None = None,
    perf_report: str | None = None,
    ip_address: str | None,
    user_agent: str | None,
) -> None:
    """
    Persist one APICall row describing a serviced call, then emit a debug log line.

    Request and response are stored as sanitized serialized bytes (see _sanitized_bytes). Responses longer than
    16 kB are cut to that length and the row is flagged as truncated.
    """
    req_bytes = _sanitized_bytes(request)
    res_bytes = _sanitized_bytes(response)
    with session_scope() as session:
        max_res_bytes = 16 * 1024  # 16 kB
        response_truncated = res_bytes is not None and len(res_bytes) > max_res_bytes
        if response_truncated:
            res_bytes = res_bytes[:max_res_bytes]

        api_call = APICall(
            is_api_key=is_api_key,
            method=method,
            status_code=status_code,
            duration=duration,
            user_id=user_id,
            request=req_bytes,
            response=res_bytes,
            response_truncated=response_truncated,
            traceback=traceback,
            perf_report=perf_report,
            ip_address=ip_address,
            user_agent=user_agent,
        )
        session.add(api_call)
        logger.debug(f"{user_id=}, {method=}, {duration=} ms")

229 

230 

# The continuation handed to intercept_service: maps the call details to the next handler in the chain, or None.
type Cont[T, R] = Callable[[grpc.HandlerCallDetails], grpc.RpcMethodHandler[T, R] | None]

232 

233 

class CouchersMiddlewareInterceptor(grpc.ServerInterceptor):
    """
    1. Does auth: extracts a session token from a cookie (or an API key from the authorization header), and
    authenticates a user with that.

    Passes the user id and token on to the servicer through the context if authenticated; terminates the call
    with an error code if the service's auth level is not satisfied.

    2. Makes sure cookies are in sync.

    3. Injects a session to get a database transaction.

    4. Measures and logs the time it takes to service each incoming call.
    """

    def __init__(self) -> None:
        # descriptor pool used to look up the auth level option of the called service
        self._pool = get_descriptor_pool()

    def intercept_service[T = Message, R = Message](
        self,
        continuation: Cont[T, R],
        handler_call_details: grpc.HandlerCallDetails,
    ) -> grpc.RpcMethodHandler[T, R]:
        """
        Authenticate the call and wrap the next handler with context, session, logging and cookie handling.

        Returns an aborting handler instead if the method is unknown, the headers are bad, or permissions are
        insufficient. Raises RuntimeError if the continuation yields no unary-unary handler.
        """
        # started here, so the measured duration includes the auth lookup below
        start = perf_counter_ns()

        method = handler_call_details.method

        try:
            auth_level = find_auth_level(self._pool, method)
        except AbortError as ae:
            return abort_handler(ae.msg, ae.code)

        try:
            headers = parse_headers(dict(handler_call_details.invocation_metadata))
        except BadHeaders:
            return unauthenticated_handler(COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE)

        # None if there is no token or it doesn't match a valid session; also records the user's activity
        auth_info = _try_get_and_update_user_details(
            headers.token, headers.is_api_key, headers.ip_address, headers.user_agent
        )

        try:
            check_permissions(auth_info, auth_level)
        except AbortError as ae:
            # the code from the error is passed through, so this may also be PERMISSION_DENIED
            return unauthenticated_handler(ae.msg, ae.code)

        if not (handler := continuation(handler_call_details)):
            raise RuntimeError(f"No handler in '{method}'")

        if not (prev_function := handler.unary_unary):
            raise RuntimeError(f"No prev_function in '{method}', {handler}")

        def function_without_couchers_stuff(req: Message, grpc_context: grpc.ServicerContext) -> Message | None:
            # wraps the servicer method: it gets a couchers context and a db session on top of the request
            couchers_context = make_interactive_context(
                grpc_context=grpc_context,
                user_id=auth_info.user_id if auth_info else None,
                is_api_key=auth_info.is_api_key if auth_info else False,
                token=auth_info.token if auth_info else None,
                ui_language_preference=auth_info.ui_language_preference if auth_info else None,
            )

            with session_scope() as session:
                try:
                    _res = prev_function(req, couchers_context, session)  # type: ignore[call-arg, arg-type]
                    res = cast(Message, _res)
                    finished = perf_counter_ns()
                    duration = (finished - start) / 1e6  # ms
                    _store_log(
                        method=method,
                        duration=duration,
                        user_id=couchers_context._user_id,
                        is_api_key=cast(bool, couchers_context._is_api_key),
                        request=req,
                        response=res,
                        ip_address=headers.ip_address,
                        user_agent=headers.user_agent,
                    )
                    # the histogram takes the duration in seconds
                    observe_in_servicer_duration_histogram(method, couchers_context._user_id, "", "", duration / 1000)
                except Exception as e:
                    finished = perf_counter_ns()
                    duration = (finished - start) / 1e6  # ms

                    # name of the status code set on the grpc context, None if none was set
                    if couchers_context._grpc_context:
                        context_code = couchers_context._grpc_context.code()  # type: ignore[attr-defined]
                        code = getattr(context_code, "name", None)
                    else:
                        code = None

                    traceback = "".join(format_exception(type(e), e, e.__traceback__))
                    _store_log(
                        method=method,
                        status_code=code,
                        duration=duration,
                        user_id=couchers_context._user_id,
                        is_api_key=cast(bool, couchers_context._is_api_key),
                        request=req,
                        response=None,
                        traceback=traceback,
                        ip_address=headers.ip_address,
                        user_agent=headers.user_agent,
                    )
                    observe_in_servicer_duration_histogram(
                        method, couchers_context._user_id, code or "", type(e).__name__, duration / 1000
                    )

                    # no status code on the context: report the exception to sentry
                    if not code:
                        sentry_sdk.set_tag("context", "servicer")
                        sentry_sdk.set_tag("method", method)
                        sentry_sdk.capture_exception(e)

                    raise e

            if auth_info and not auth_info.is_api_key:
                # check the two cookies are in sync & that language preference cookie is correct
                if headers.user_id != str(auth_info.user_id):
                    couchers_context.set_cookies(
                        create_session_cookies(auth_info.token, auth_info.user_id, auth_info.token_expiry)
                    )
                if auth_info.ui_language_preference and auth_info.ui_language_preference != headers.ui_lang:
                    couchers_context.set_cookies(create_lang_cookie(auth_info.ui_language_preference))

            # the call is no longer active (presumably cancelled, going by the error message name): abort it
            if not grpc_context.is_active():
                grpc_context.abort(grpc.StatusCode.INTERNAL, CALL_CANCELLED_ERROR_MESSAGE)

            couchers_context._send_cookies()

            return res

        # keep the (de)serializers of the wrapped handler
        return grpc.unary_unary_rpc_method_handler(
            function_without_couchers_stuff,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

366 

367 

@dataclass(frozen=True, slots=True, kw_only=True)
class CouchersHeaders:
    """
    The parts of the request headers the interceptors care about, as produced by parse_headers.
    """

    # session token from the cookie or authorization header, None if neither is present; kept out of repr()
    token: str | None = field(repr=False)
    # True if the token came from the authorization header
    is_api_key: bool
    # value of the "x-couchers-real-ip" header, None if missing or not a str
    ip_address: str | None
    # value of the "user-agent" header, None if missing or not a str
    user_agent: str | None
    # result of parse_ui_lang_cookie
    ui_lang: str | None
    # result of parse_user_id_cookie; compared against str(user id) to detect out-of-sync cookies
    user_id: str | None

376 

377 

def parse_headers(headers: Mapping[str, str | bytes]) -> CouchersHeaders:
    """
    Extract the session token and the client details from the request headers.

    The token comes from the `cookie` header (a normal session) or from the `authorization` header (an API key).

    Raises BadHeaders if both of these headers are present.
    """
    has_cookie = "cookie" in headers
    has_authorization = "authorization" in headers

    if has_cookie and has_authorization:
        # for security reasons, only one of "cookie" or "authorization" can be present
        raise BadHeaders("Both cookies and authorization are present in headers")

    if has_cookie:
        # the session token is passed in cookies, i.e., in the `cookie` header
        token = parse_session_cookie(headers)
        is_api_key = False
    elif has_authorization:
        # the session token is passed in the `authorization` header
        token = parse_api_key(headers)
        is_api_key = True
    else:
        # no session found
        token = None
        is_api_key = False

    raw_ip_address = headers.get("x-couchers-real-ip")
    raw_user_agent = headers.get("user-agent")

    ui_lang = parse_ui_lang_cookie(headers)
    user_id = parse_user_id_cookie(headers)

    # header values may also be bytes: only str values are passed on
    return CouchersHeaders(
        token=token,
        is_api_key=is_api_key,
        ip_address=raw_ip_address if isinstance(raw_ip_address, str) else None,
        user_agent=raw_user_agent if isinstance(raw_user_agent, str) else None,
        ui_lang=ui_lang,
        user_id=user_id,
    )

406 

407 

class BadHeaders(Exception):
    """Raised by parse_headers when both the cookie and the authorization header are present."""

    pass

410 

411 

class AbortError(Exception):
    """
    Signals that a call should be aborted with the given message and gRPC status code.

    Raised by the auth helpers in this module; the interceptor turns it into an aborting handler.
    """

    def __init__(self, msg: str, code: grpc.StatusCode) -> None:
        # message and status code to abort the call with
        self.msg = msg
        self.code = code

416 

417 

def find_auth_level(pool: DescriptorPool, method: str) -> AuthLevel.ValueType:
    """
    Look up the auth level configured on the service that a method belongs to.

    The service name is taken from the full method name and resolved in the given descriptor pool; the level is
    read from the service's ``auth_level`` option.

    Raises:
        AbortError: with UNIMPLEMENTED if the method name is malformed or the service is not in the pool, or with
            INTERNAL if the service has no valid auth level (see validate_auth_level).
    """
    # method is of the form "/org.couchers.api.core.API/GetUser"
    try:
        _, service_name, _ = method.split("/")
    except ValueError:
        # not exactly three "/"-separated parts, so it can't name any of our calls: treat it like an unknown one
        # instead of letting the ValueError escape from the interceptor
        raise AbortError(NONEXISTENT_API_CALL_ERROR_MESSAGE, grpc.StatusCode.UNIMPLEMENTED) from None

    try:
        service: ServiceDescriptor = pool.FindServiceByName(service_name)  # type: ignore[no-untyped-call]
        service_options = service.GetOptions()
    except KeyError:
        raise AbortError(NONEXISTENT_API_CALL_ERROR_MESSAGE, grpc.StatusCode.UNIMPLEMENTED) from None

    level = service_options.Extensions[annotations_pb2.auth_level]

    validate_auth_level(level)

    return level

433 

434 

def validate_auth_level(auth_level: AuthLevel.ValueType) -> None:
    """
    Make sure the auth level is one of the known, explicitly set levels.

    Raises AbortError with an INTERNAL code if the level is unknown (i.e. it wasn't set) or not recognised.
    """
    recognised_levels = (
        annotations_pb2.AUTH_LEVEL_OPEN,
        annotations_pb2.AUTH_LEVEL_JAILED,
        annotations_pb2.AUTH_LEVEL_SECURE,
        annotations_pb2.AUTH_LEVEL_EDITOR,
        annotations_pb2.AUTH_LEVEL_ADMIN,
    )

    # if unknown auth level, then it wasn't set and something's wrong; same for any value we don't recognise
    was_not_set = auth_level == annotations_pb2.AUTH_LEVEL_UNKNOWN
    if was_not_set or auth_level not in recognised_levels:
        raise AbortError(MISSING_AUTH_LEVEL_ERROR_MESSAGE, grpc.StatusCode.INTERNAL)

448 

449 

def check_permissions(auth_info: UserAuthInfo | None, auth_level: AuthLevel.ValueType) -> None:
    """
    Check that the caller may use a service with the given auth level.

    Without a session only open services are allowed. With a session, admin services need a superuser, editor
    services need an editor, and jailed users may only use open and jailed services.

    Raises AbortError if the caller is not allowed; returns None otherwise.
    """
    if not auth_info:
        # no session: if this isn't an open service, fail
        if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
            raise AbortError(UNAUTHORIZED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED)
        return

    # a valid user session was found - check permissions
    needs_superuser = auth_level == annotations_pb2.AUTH_LEVEL_ADMIN
    if needs_superuser and not auth_info.is_superuser:
        raise AbortError(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.PERMISSION_DENIED)

    needs_editor = auth_level == annotations_pb2.AUTH_LEVEL_EDITOR
    if needs_editor and not auth_info.is_editor:
        raise AbortError(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.PERMISSION_DENIED)

    # if the user is jailed and this isn't an open or jailed service, fail
    allowed_when_jailed = (annotations_pb2.AUTH_LEVEL_OPEN, annotations_pb2.AUTH_LEVEL_JAILED)
    if auth_info.is_jailed and auth_level not in allowed_when_jailed:
        # NOTE(review): permission-denied message paired with the UNAUTHENTICATED code, unlike the checks above;
        # kept as is, confirm whether clients rely on this combination
        raise AbortError(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED)

469 

470 

471class MediaInterceptor(grpc.ServerInterceptor): 

472 """ 

473 Extracts an "Authorization: Bearer <hex>" header and calls the 

474 is_authorized function. Terminates the call with an HTTP error 

475 code if not authorized. 

476 

477 Also adds a session to called APIs. 

478 """ 

479 

480 def __init__(self, is_authorized: Callable[[str], bool]): 

481 self._is_authorized = is_authorized 

482 

483 def intercept_service[T, R]( 

484 self, 

485 continuation: Cont[T, R], 

486 handler_call_details: grpc.HandlerCallDetails, 

487 ) -> grpc.RpcMethodHandler[T, R]: 

488 handler = continuation(handler_call_details) 

489 if not handler: 489 ↛ 490line 489 didn't jump to line 490 because the condition on line 489 was never true

490 raise RuntimeError("No handler") 

491 

492 prev_func = handler.unary_unary 

493 if not prev_func: 493 ↛ 494line 493 didn't jump to line 494 because the condition on line 493 was never true

494 raise RuntimeError(f"No prev_function, {handler}") 

495 

496 metadata = dict(handler_call_details.invocation_metadata) 

497 

498 token = parse_api_key(metadata) 

499 

500 if not token or not self._is_authorized(token): 500 ↛ 501line 500 didn't jump to line 501 because the condition on line 500 was never true

501 return unauthenticated_handler() 

502 

503 def function_without_session(request: T, grpc_context: grpc.ServicerContext) -> R: 

504 with session_scope() as session: 

505 return prev_func(request, make_media_context(grpc_context), session) # type: ignore[call-arg, arg-type] 

506 

507 return grpc.unary_unary_rpc_method_handler( 

508 function_without_session, 

509 request_deserializer=handler.request_deserializer, 

510 response_serializer=handler.response_serializer, 

511 ) 

512 

513 

514class OTelInterceptor(grpc.ServerInterceptor): 

515 """ 

516 OpenTelemetry tracing 

517 """ 

518 

519 def __init__(self) -> None: 

520 self.tracer = trace.get_tracer(__name__) 

521 

522 def intercept_service[T, R]( 

523 self, 

524 continuation: Cont[T, R], 

525 handler_call_details: grpc.HandlerCallDetails, 

526 ) -> grpc.RpcMethodHandler[T, R]: 

527 handler = continuation(handler_call_details) 

528 if not handler: 

529 raise RuntimeError("No handler") 

530 

531 prev_func = handler.unary_unary 

532 if not prev_func: 

533 raise RuntimeError(f"No prev_function, {handler}") 

534 

535 method = handler_call_details.method 

536 

537 # method is of the form "/org.couchers.api.core.API/GetUser" 

538 _, service_name, method_name = method.split("/") 

539 

540 headers = dict(handler_call_details.invocation_metadata) 

541 

542 def tracing_function(request: T, context: grpc.ServicerContext) -> R: 

543 with self.tracer.start_as_current_span("handler") as rollspan: 

544 rollspan.set_attribute("rpc.method_full", method) 

545 rollspan.set_attribute("rpc.service", service_name) 

546 rollspan.set_attribute("rpc.method", method_name) 

547 

548 rollspan.set_attribute("rpc.thread", get_ident()) 

549 rollspan.set_attribute("rpc.pid", getpid()) 

550 

551 res = prev_func(request, context) 

552 

553 rollspan.set_attribute("web.user_agent", headers.get("user-agent") or "") 

554 rollspan.set_attribute("web.ip_address", headers.get("x-couchers-real-ip") or "") 

555 

556 return res 

557 

558 return grpc.unary_unary_rpc_method_handler( 

559 tracing_function, 

560 request_deserializer=handler.request_deserializer, 

561 response_serializer=handler.response_serializer, 

562 ) 

563 

564 

565class ErrorSanitizationInterceptor(grpc.ServerInterceptor): 

566 """ 

567 If the call resulted in a non-gRPC error, this strips away the error details. 

568 

569 It's important to put this first, so that it does not interfere with other interceptors. 

570 """ 

571 

572 def intercept_service[T, R]( 

573 self, 

574 continuation: Cont[T, R], 

575 handler_call_details: grpc.HandlerCallDetails, 

576 ) -> grpc.RpcMethodHandler[T, R]: 

577 handler = continuation(handler_call_details) 

578 if not handler: 578 ↛ 579line 578 didn't jump to line 579 because the condition on line 578 was never true

579 raise RuntimeError("No handler") 

580 

581 prev_func = handler.unary_unary 

582 if not prev_func: 582 ↛ 583line 582 didn't jump to line 583 because the condition on line 582 was never true

583 raise RuntimeError(f"No prev_function, {handler}") 

584 

585 def sanitizing_function(req: T, context: grpc.ServicerContext) -> R: 

586 try: 

587 res = prev_func(req, context) 

588 except Exception as e: 

589 code = context.code() # type: ignore[attr-defined] 

590 # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None 

591 if not code: 

592 logger.exception(e) 

593 logger.info("Probably an unknown error! Sanitizing...") 

594 context.abort(grpc.StatusCode.INTERNAL, UNKNOWN_ERROR_MESSAGE) 

595 else: 

596 logger.warning(f"RPC error: {code} in method {handler_call_details.method}") 

597 raise e 

598 return res 

599 

600 return grpc.unary_unary_rpc_method_handler( 

601 sanitizing_function, 

602 request_deserializer=handler.request_deserializer, 

603 response_serializer=handler.response_serializer, 

604 )