Coverage for app/backend/src/couchers/interceptors.py: 87%

334 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-21 09:29 +0000

1import logging 

2from collections.abc import Callable, Mapping 

3from copy import deepcopy 

4from dataclasses import dataclass, field 

5from datetime import datetime, timedelta 

6from functools import cache 

7from os import getpid 

8from threading import get_ident 

9from time import perf_counter_ns 

10from traceback import format_exception 

11from typing import Any, NoReturn, cast, overload 

12from zoneinfo import ZoneInfo 

13 

14import grpc 

15import sentry_sdk 

16from google.protobuf.descriptor import Descriptor, ServiceDescriptor 

17from google.protobuf.descriptor_pool import DescriptorPool 

18from google.protobuf.message import Message 

19from opentelemetry import trace 

20from sqlalchemy import Function, literal_column, select 

21from sqlalchemy.dialects.postgresql import insert as pg_insert 

22from sqlalchemy.sql import func 

23 

24from couchers.config import config 

25from couchers.constants import ( 

26 CALL_CANCELLED_ERROR_MESSAGE, 

27 COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE, 

28 MISSING_AUTH_LEVEL_ERROR_MESSAGE, 

29 NONEXISTENT_API_CALL_ERROR_MESSAGE, 

30 PERMISSION_DENIED_ERROR_MESSAGE, 

31 UNAUTHORIZED_ERROR_MESSAGE, 

32 UNKNOWN_ERROR_MESSAGE, 

33) 

34from couchers.context import CouchersContext, make_interactive_context, make_media_context 

35from couchers.db import session_scope 

36from couchers.descriptor_pool import get_descriptor_pool 

37from couchers.i18n import LocalizationContext 

38from couchers.i18n.locales import to_supported_locale 

39from couchers.metrics import ( 

40 observe_api_call, 

41 observe_in_servicer_duration_histogram, 

42 observe_in_servicer_perf_histograms, 

43 observe_in_servicer_pool_wait_histogram, 

44 observe_in_servicer_serde_histogram, 

45 observe_in_servicer_setup_errors_counter, 

46 observe_in_servicer_setup_histogram, 

47) 

48from couchers.models import APICall, ClientPlatform, User, UserActivity, UserSession 

49from couchers.perf import PerfResult, read_perf, start_perf 

50from couchers.proto import annotations_pb2 

51from couchers.proto.annotations_pb2 import AuthLevel 

52from couchers.utils import ( 

53 create_lang_cookie, 

54 create_session_cookies, 

55 generate_sofa_cookie, 

56 now, 

57 parse_api_key, 

58 parse_session_cookie, 

59 parse_sofa_cookie, 

60 parse_ui_lang_cookie, 

61 parse_user_id_cookie, 

62) 

63 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

65 

66 

@dataclass(frozen=True, slots=True, kw_only=True)
class UserAuthInfo:
    """
    Information about an authenticated user session.

    Immutable snapshot built by _try_get_and_update_user_details from the User and UserSession rows that match
    the presented token.
    """

    user_id: int
    is_jailed: bool
    is_editor: bool
    is_superuser: bool
    # expiry of the session the token belongs to
    token_expiry: datetime
    ui_language_preference: str | None
    timezone: str | None
    # left out of the generated repr
    token: str = field(repr=False)
    # whether the token was looked up as an API key rather than a cookie session
    is_api_key: bool

80 

81 

def _binned_now() -> Function[Any]:
    """
    SQL expression for the database's current time, binned into one-hour buckets.

    Builds date_bin(interval '1 hour', now(), '2000-01-01'::timestamptz).
    """
    bucket_width = literal_column("interval '1 hour'")
    bucket_origin = literal_column("'2000-01-01'::timestamptz")
    return func.date_bin(bucket_width, func.now(), bucket_origin)

88 

89 

def _try_get_and_update_user_details(
    token: str | None,
    is_api_key: bool,
    ip_address: str | None,
    user_agent: str | None,
    sofa: str | None,
    client_platform: ClientPlatform | None,
) -> UserAuthInfo | None:
    """
    Tries to get session and user info corresponding to this token.

    Also updates the user's last active time, token last active time, and increments API call count.

    The UserActivity row for the current one-hour bin and this (user, ip_address, user_agent, sofa) combination is
    upserted: inserted with api_calls=1, or incremented if it already exists.

    Returns UserAuthInfo if a valid session is found, None otherwise (including when no token is given).
    """
    if not token:
        return None

    with session_scope() as session:
        # the session must be valid, belong to a visible user, and be of the kind (API key or not) the caller says
        # the token is
        result = session.execute(
            select(User, UserSession, User.is_jailed)
            .select_from(UserSession)
            .join(User, User.id == UserSession.user_id)
            .where(User.is_visible)
            .where(UserSession.token == token)
            .where(UserSession.is_valid)
            .where(UserSession.is_api_key == is_api_key)
        ).one_or_none()

        if not result:
            return None

        user, user_session, is_jailed = result._tuple()

        # update user last active time if it's been a while
        if now() - user.last_active > timedelta(minutes=5):
            user.last_active = func.now()

        # let's update the token
        user_session.last_seen = func.now()
        user_session.api_calls += 1

        # upsert so concurrent requests for the same activity tuple don't race to insert and violate the index
        insert_stmt = pg_insert(UserActivity).values(
            user_id=user.id,
            period=_binned_now(),
            ip_address=ip_address,
            user_agent=user_agent,
            sofa=sofa,
            client_platform=client_platform,
            api_calls=1,
        )
        session.execute(
            insert_stmt.on_conflict_do_update(
                index_elements=[
                    UserActivity.user_id,
                    UserActivity.period,
                    UserActivity.ip_address,
                    UserActivity.user_agent,
                    UserActivity.sofa,
                ],
                set_={
                    "api_calls": UserActivity.api_calls + 1,
                    # keep the stored platform when this request did not declare one
                    "client_platform": func.coalesce(
                        insert_stmt.excluded.client_platform, UserActivity.client_platform
                    ),
                },
            )
        )

        # build before committing to avoid expire_on_commit reloading these attributes
        auth_info = UserAuthInfo(
            user_id=user.id,
            is_jailed=is_jailed,
            is_editor=user.is_editor,
            is_superuser=user.is_superuser,
            token_expiry=user_session.expiry,
            ui_language_preference=user.ui_language_preference,
            timezone=user.timezone,
            token=token,
            is_api_key=is_api_key,
        )

        session.commit()

        return auth_info

176 

177 

178def abort_handler[T, R]( 

179 message: str, 

180 status_code: grpc.StatusCode, 

181) -> grpc.RpcMethodHandler[T, R]: 

182 def f(request: Any, context: CouchersContext) -> NoReturn: 

183 context.abort(status_code, message) 

184 

185 return grpc.unary_unary_rpc_method_handler(f) 

186 

187 

188def unauthenticated_handler[T, R]( 

189 message: str = UNAUTHORIZED_ERROR_MESSAGE, 

190 status_code: grpc.StatusCode = grpc.StatusCode.UNAUTHENTICATED, 

191) -> grpc.RpcMethodHandler[T, R]: 

192 return abort_handler(message, status_code) 

193 

194 

@cache
def _descriptor_has_sensitive(descriptor: Descriptor) -> bool:
    """
    Report whether this message type, or any message type reachable through its fields, has a sensitive field.

    The type graph is walked iteratively; the visited set stops self-referencing message types from being
    walked twice. Results are memoised per descriptor by the cache decorator.
    """
    visited: set[Descriptor] = set()
    pending = [descriptor]
    while pending:
        current = pending.pop()
        if current in visited:
            continue
        visited.add(current)
        if any(fld.GetOptions().Extensions[annotations_pb2.sensitive] for fld in current.fields):
            return True
        # nothing sensitive directly on this type: queue its message-typed fields
        pending.extend(fld.message_type for fld in current.fields if fld.message_type is not None)
    return False

211 

212 

@dataclass(frozen=True, slots=True)
class _SanitizePlan:
    """Per-message-type instructions for _sanitize_message, computed by _sanitize_plan."""

    # names of fields marked sensitive; they are cleared with ClearField
    fields_to_clear: tuple[str, ...]
    # message-typed fields whose type transitively contains a sensitive field
    fields_to_recurse: tuple[tuple[str, bool], ...]  # (field name, is_repeated)

217 

218 

@cache
def _sanitize_plan(descriptor: Descriptor) -> _SanitizePlan:
    """
    Work out, for one message type, which fields to clear and which subfields to descend into.

    A field marked sensitive is cleared outright and never descended into; any other message-typed field is
    listed for recursion only if its type transitively contains a sensitive field.
    """
    all_fields = descriptor.fields
    sensitive_names = tuple(
        fld.name for fld in all_fields if fld.GetOptions().Extensions[annotations_pb2.sensitive]
    )
    already_cleared = set(sensitive_names)
    nested = tuple(
        (fld.name, fld.is_repeated)
        for fld in all_fields
        if fld.name not in already_cleared
        and fld.message_type is not None
        and _descriptor_has_sensitive(fld.message_type)
    )
    return _SanitizePlan(fields_to_clear=sensitive_names, fields_to_recurse=nested)

230 

231 

def _sanitize_message(message: Message) -> None:
    """
    Clear, in place, every field marked sensitive in this message and in its nested messages.

    The fields to clear and the subfields to descend into come from the cached plan for the message's type.

    Singular submessages are only descended into when they are actually set. Reading an unset submessage returns
    a default instance that is truthy, so a truthiness check never skips it; HasField does, and leaves the unset
    field untouched.
    """
    plan = _sanitize_plan(message.DESCRIPTOR)
    for name in plan.fields_to_clear:
        message.ClearField(name)
    for name, is_repeated in plan.fields_to_recurse:
        if is_repeated:
            # an empty repeated field simply yields nothing
            # NOTE(review): map fields presumably also report as repeated, and iterating a map container yields
            # keys rather than messages -- confirm no sensitive type is reachable through a map field
            for item in getattr(message, name):
                _sanitize_message(item)
        elif message.HasField(name):
            _sanitize_message(getattr(message, name))

245 

246 

@overload
def _sanitized_bytes(proto: Message) -> bytes: ...
@overload
def _sanitized_bytes(proto: None) -> None: ...
def _sanitized_bytes(proto: Message | None) -> bytes | None:
    """
    Serialize a message with all fields marked sensitive removed.

    Returns None when there is no message. Whether a type can contain sensitive data is a static property of the
    type and is cached, so messages of a type without any sensitive field are serialized as they are, with no
    copy and no traversal. Otherwise a deep copy is scrubbed and serialized; the caller's message is not modified.
    """
    if not proto:
        return None

    if _descriptor_has_sensitive(proto.DESCRIPTOR):
        # scrub a copy so the caller keeps the original, unsanitized message
        scrubbed = deepcopy(proto)
        _sanitize_message(scrubbed)
        return scrubbed.SerializeToString()

    return proto.SerializeToString()

267 

268 

def _store_log(
    *,
    method: str,
    status_code: str | None = None,
    duration: float,
    user_id: int | None,
    is_api_key: bool,
    request: Message,
    response: Message | None,
    traceback: str | None = None,
    perf_report: str | None = None,
    perf: PerfResult | None = None,
    client_platform: ClientPlatform | None = None,
    ip_address: str | None,
    user_agent: str | None,
    sofa: str | None,
) -> None:
    """
    Record one API call as an APICall row, in its own database session.

    Request and response are stored as sanitized serialized bytes; a response longer than 16 kB is cut to that
    length and flagged as truncated. The perf counters are stored when a perf result is given, otherwise as None.
    """
    max_response_bytes = 16 * 1024  # 16 kB

    req_bytes = _sanitized_bytes(request)
    res_bytes = _sanitized_bytes(response)

    response_truncated = False
    if res_bytes and len(res_bytes) > max_response_bytes:
        res_bytes, response_truncated = res_bytes[:max_response_bytes], True

    api_call = APICall(
        is_api_key=is_api_key,
        method=method,
        status_code=status_code,
        duration=duration,
        user_id=user_id,
        request=req_bytes,
        response=res_bytes,
        response_truncated=response_truncated,
        traceback=traceback,
        perf_report=perf_report,
        db_query_count=perf.db_query_count if perf else None,
        db_write_query_count=perf.db_write_query_count if perf else None,
        db_time_ms=perf.db_time_ms if perf else None,
        cpu_ms=perf.cpu_ms if perf else None,
        client_platform=client_platform,
        ip_address=ip_address,
        user_agent=user_agent,
        sofa=sofa,
    )
    with session_scope() as session:
        session.add(api_call)
    logger.debug(f"{user_id=}, {method=}, {duration=} ms")

317 

318 

# The continuation an interceptor receives: given the call details, it returns the next handler, or None.
type Cont[T, R] = Callable[[grpc.HandlerCallDetails], grpc.RpcMethodHandler[T, R] | None]

320 

321 

class CouchersMiddlewareInterceptor(grpc.ServerInterceptor):
    """
    1. Does auth: extracts a session token from a cookie, and authenticates a user with that.

    Sets context.user_id and context.token if authenticated, otherwise
    terminates the call with an UNAUTHENTICATED error code.

    2. Makes sure cookies are in sync.

    3. Injects a session to get a database transaction.

    4. Measures and logs the time it takes to service each incoming call.
    """

    def __init__(self) -> None:
        # descriptor pool used to look up each service's auth level
        self._pool = get_descriptor_pool()

    def intercept_service[T = Message, R = Message](
        self,
        continuation: Cont[T, R],
        handler_call_details: grpc.HandlerCallDetails,
    ) -> grpc.RpcMethodHandler[T, R]:
        """
        Wrap the handler of one incoming call with auth, session injection, logging, metrics and cookie handling.

        The setup phase (auth level lookup, header parsing, session lookup, permission check, handler lookup,
        sofa and localization) runs here. An AbortError or BadHeaders during setup is answered with an aborting
        handler; any other exception is counted, reported to Sentry and answered with INTERNAL. Otherwise the
        returned handler runs the real servicer method inside a database session.
        """
        # measured from here, so the logged duration includes the setup phase
        start = perf_counter_ns()

        method = handler_call_details.method

        # accounting for the auth/setup phase; the handler re-arms its own below
        start_perf()

        try:
            try:
                auth_level = find_auth_level(self._pool, method)
            except AbortError as ae:
                return abort_handler(ae.msg, ae.code)

            try:
                headers = parse_headers(dict(handler_call_details.invocation_metadata))
            except BadHeaders:
                return unauthenticated_handler(COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE)

            # if this is not present in prod, it's a Big Bug in config
            assert config.DEV or headers.ip_address is not None

            # None when there is no token or it matches no valid session
            auth_info = _try_get_and_update_user_details(
                headers.token,
                headers.is_api_key,
                headers.ip_address,
                headers.user_agent,
                headers.sofa,
                headers.client_platform,
            )

            try:
                check_permissions(auth_info, auth_level)
            except AbortError as ae:
                # despite the helper's name, the abort uses the message and status code carried by the error
                return unauthenticated_handler(ae.msg, ae.code)

            if not (handler := continuation(handler_call_details)):
                raise RuntimeError(f"No handler in '{method}'")

            if not (prev_function := handler.unary_unary):
                raise RuntimeError(f"No prev_function in '{method}', {handler}")

            # reuse the sofa sent by the client, otherwise mint one and remember the cookie to send back
            if headers.sofa:
                sofa = headers.sofa
                new_sofa_cookie = None
            else:
                sofa, new_sofa_cookie = generate_sofa_cookie()

            # a logged-in user's stored preference is used instead of the language cookie
            locale = to_supported_locale((auth_info.ui_language_preference if auth_info else headers.ui_lang) or "")
            loc_context = LocalizationContext(
                locale=locale,
                timezone=ZoneInfo((auth_info and auth_info.timezone) or "Etc/UTC"),
            )

            observe_in_servicer_setup_histogram(method, read_perf())
        except Exception as e:
            # anything unexpected during setup: count it, report it, and answer with a generic INTERNAL error
            observe_in_servicer_setup_errors_counter(method, type(e).__name__)
            sentry_sdk.set_tag("context", "servicer_setup")
            sentry_sdk.set_tag("method", method)
            sentry_sdk.capture_exception(e)
            return abort_handler(UNKNOWN_ERROR_MESSAGE, grpc.StatusCode.INTERNAL)

        def function_without_couchers_stuff(req: Message, grpc_context: grpc.ServicerContext) -> Message | None:
            """Run the servicer method with a Couchers context and a session; log, measure and set cookies."""
            couchers_context = make_interactive_context(
                grpc_context=grpc_context,
                user_id=auth_info.user_id if auth_info else None,
                is_api_key=auth_info.is_api_key if auth_info else False,
                token=auth_info.token if auth_info else None,
                localization=loc_context,
                sofa=sofa,
            )

            with session_scope() as session:
                # force the checkout now so its wait is timed here rather than hiding in the handler's first query
                pool_wait_start = perf_counter_ns()
                session.connection()
                observe_in_servicer_pool_wait_histogram(method, (perf_counter_ns() - pool_wait_start) / 1e9)
                start_perf()
                try:
                    _res = prev_function(req, couchers_context, session)  # type: ignore[call-arg, arg-type]
                    res = cast(Message, _res)
                    # flush so pending ORM writes execute (and are counted) before we snapshot; a handler that only
                    # session.add(...)s and returns would otherwise flush at commit, after read_perf()
                    session.flush()
                    perf = read_perf()
                    finished = perf_counter_ns()
                    duration = (finished - start) / 1e6  # ms
                    _store_log(
                        method=method,
                        duration=duration,
                        user_id=couchers_context._user_id,
                        is_api_key=cast(bool, couchers_context._is_api_key),
                        request=req,
                        response=res,
                        perf=perf,
                        client_platform=headers.client_platform,
                        ip_address=headers.ip_address,
                        user_agent=headers.user_agent,
                        sofa=sofa,
                    )
                    # the histogram takes seconds, duration is in ms
                    observe_in_servicer_duration_histogram(method, couchers_context._user_id, "", "", duration / 1000)
                    observe_api_call(method, headers.client_platform)
                    observe_in_servicer_perf_histograms(method, perf)
                except Exception as e:
                    perf = read_perf()
                    finished = perf_counter_ns()
                    duration = (finished - start) / 1e6  # ms

                    # name of the status code set on the grpc context, or None if no code was set
                    if couchers_context._grpc_context:
                        context_code = couchers_context._grpc_context.code()  # type: ignore[attr-defined]
                        code = getattr(context_code, "name", None)
                    else:
                        code = None

                    traceback = "".join(format_exception(type(e), e, e.__traceback__))
                    _store_log(
                        method=method,
                        status_code=code,
                        duration=duration,
                        user_id=couchers_context._user_id,
                        is_api_key=cast(bool, couchers_context._is_api_key),
                        request=req,
                        response=None,
                        traceback=traceback,
                        perf=perf,
                        client_platform=headers.client_platform,
                        ip_address=headers.ip_address,
                        user_agent=headers.user_agent,
                        sofa=sofa,
                    )
                    observe_in_servicer_duration_histogram(
                        method, couchers_context._user_id, code or "", type(e).__name__, duration / 1000
                    )
                    observe_api_call(method, headers.client_platform)
                    observe_in_servicer_perf_histograms(method, perf)

                    # only exceptions that did not set a status code are reported to Sentry
                    if not code:
                        sentry_sdk.set_tag("context", "servicer")
                        sentry_sdk.set_tag("method", method)
                        sentry_sdk.set_tag("user_agent", headers.user_agent)
                        sentry_sdk.set_tag("ui_lang", loc_context.locale)
                        sentry_sdk.set_user(
                            {
                                "id": couchers_context._user_id,
                                "ip_address": headers.ip_address,
                                # only the first 12 characters of the sofa are sent
                                "sofa": sofa[:12],
                            }
                        )
                        sentry_sdk.capture_exception(e)

                    raise e

                if auth_info and not auth_info.is_api_key:
                    # check the two cookies are in sync & that language preference cookie is correct
                    if headers.user_id != str(auth_info.user_id):
                        couchers_context.set_cookies(
                            create_session_cookies(auth_info.token, auth_info.user_id, auth_info.token_expiry)
                        )
                    if auth_info.ui_language_preference and auth_info.ui_language_preference != headers.ui_lang:
                        couchers_context.set_cookies(create_lang_cookie(auth_info.ui_language_preference))

                if new_sofa_cookie:
                    couchers_context.set_cookies([new_sofa_cookie])

                # the call is no longer active, so abort instead of returning a response
                if not grpc_context.is_active():
                    grpc_context.abort(grpc.StatusCode.INTERNAL, CALL_CANCELLED_ERROR_MESSAGE)

                couchers_context._send_cookies()

                return res

        def timed_serde[A, B](fn: Callable[[A], B], direction: str) -> Callable[[A], B]:
            """Wrap a (de)serializer so each call's duration, in seconds, is observed under `direction`."""

            def wrapped(arg: A) -> B:
                t0 = perf_counter_ns()
                result = fn(arg)
                observe_in_servicer_serde_histogram(method, direction, (perf_counter_ns() - t0) / 1e9)
                return result

            return wrapped

        # always set for our generated-proto methods, but grpc types them as optional
        assert handler.request_deserializer is not None and handler.response_serializer is not None
        return grpc.unary_unary_rpc_method_handler(
            function_without_couchers_stuff,
            request_deserializer=timed_serde(handler.request_deserializer, "deserialize"),
            response_serializer=timed_serde(handler.response_serializer, "serialize"),
        )

530 

531 

@dataclass(frozen=True, slots=True, kw_only=True)
class CouchersHeaders:
    """The values parse_headers extracts from a call's invocation metadata."""

    # session token or API key, if any; left out of the generated repr
    token: str | None = field(repr=False)
    # True when the token came from the `authorization` header rather than a cookie
    is_api_key: bool
    # from the `x-couchers-real-ip` header
    ip_address: str | None
    # from the `user-agent` header
    user_agent: str | None
    # from the `x-couchers-client-platform` header; None if absent or not a ClientPlatform member name
    client_platform: ClientPlatform | None
    # the next three come from cookies
    ui_lang: str | None
    user_id: str | None
    sofa: str | None

542 

543 

def parse_headers(headers: Mapping[str, str | bytes]) -> CouchersHeaders:
    """
    Extract the token, client details and cookie values from a call's metadata.

    The token comes from the `cookie` header (a session) or from the `authorization` header (an API key).
    Header values that are not str are treated as absent.

    Raises BadHeaders if both `cookie` and `authorization` are present.
    """
    has_cookie = "cookie" in headers
    has_authorization = "authorization" in headers

    if has_cookie and has_authorization:
        # for security reasons, only one of "cookie" or "authorization" can be present
        raise BadHeaders("Both cookies and authorization are present in headers")

    if has_cookie:
        # the session token is passed in cookies, i.e., in the `cookie` header
        token = parse_session_cookie(headers)
        is_api_key = False
    elif has_authorization:
        # the session token is passed in the `authorization` header
        token = parse_api_key(headers)
        is_api_key = True
    else:
        # no session found
        token = None
        is_api_key = False

    def text_header(name: str) -> str | None:
        # only str values count; bytes values are dropped
        value = headers.get(name)
        return value if isinstance(value, str) else None

    # the client (web app or native app) declares its platform via this header
    platform_name = text_header("x-couchers-client-platform")
    if platform_name is not None and platform_name in ClientPlatform.__members__:
        client_platform = ClientPlatform[platform_name]
    else:
        client_platform = None

    return CouchersHeaders(
        token=token,
        is_api_key=is_api_key,
        ip_address=text_header("x-couchers-real-ip"),
        user_agent=text_header("user-agent"),
        client_platform=client_platform,
        ui_lang=parse_ui_lang_cookie(headers),
        user_id=parse_user_id_cookie(headers),
        sofa=parse_sofa_cookie(headers),
    )

583 

584 

class BadHeaders(Exception):
    """Raised by parse_headers when both `cookie` and `authorization` headers are present."""

    pass

587 

588 

class AbortError(Exception):
    """Carries the message and gRPC status code a call should be aborted with."""

    def __init__(self, msg: str, code: grpc.StatusCode):
        self.msg, self.code = msg, code

593 

594 

def find_auth_level(pool: DescriptorPool, method: str) -> AuthLevel.ValueType:
    """
    Look up the auth level declared on the service that `method` belongs to.

    `method` is the full gRPC method path, e.g. "/org.couchers.api.core.API/GetUser". The level is read from the
    service's `auth_level` option and validated before being returned.

    Raises AbortError with UNIMPLEMENTED if the path is malformed or names a service the pool does not know,
    and (via validate_auth_level) with INTERNAL if the service declares no usable auth level.
    """
    # method is of the form "/org.couchers.api.core.API/GetUser"
    path_parts = method.split("/")
    if len(path_parts) != 3:
        # a path of any other shape cannot name an API call; report it like an unknown call instead of letting
        # a ValueError from unpacking surface as an internal error
        raise AbortError(NONEXISTENT_API_CALL_ERROR_MESSAGE, grpc.StatusCode.UNIMPLEMENTED)
    service_name = path_parts[1]

    try:
        service: ServiceDescriptor = pool.FindServiceByName(service_name)  # type: ignore[no-untyped-call]
        service_options = service.GetOptions()
    except KeyError:
        raise AbortError(NONEXISTENT_API_CALL_ERROR_MESSAGE, grpc.StatusCode.UNIMPLEMENTED) from None

    level = service_options.Extensions[annotations_pb2.auth_level]

    validate_auth_level(level)

    return level

610 

611 

def validate_auth_level(auth_level: AuthLevel.ValueType) -> None:
    """
    Check that a service's auth level is one of the known, usable levels.

    Raises AbortError with INTERNAL for AUTH_LEVEL_UNKNOWN (the level was never set) and for any other value
    outside the known set.
    """
    known_levels = (
        annotations_pb2.AUTH_LEVEL_OPEN,
        annotations_pb2.AUTH_LEVEL_JAILED,
        annotations_pb2.AUTH_LEVEL_SECURE,
        annotations_pb2.AUTH_LEVEL_EDITOR,
        annotations_pb2.AUTH_LEVEL_ADMIN,
    )
    # if unknown auth level, then it wasn't set and something's wrong; anything else unrecognised fails the same way
    level_is_usable = auth_level != annotations_pb2.AUTH_LEVEL_UNKNOWN and auth_level in known_levels
    if not level_is_usable:
        raise AbortError(MISSING_AUTH_LEVEL_ERROR_MESSAGE, grpc.StatusCode.INTERNAL)

625 

626 

def check_permissions(auth_info: UserAuthInfo | None, auth_level: AuthLevel.ValueType) -> None:
    """
    Decide whether a caller may use a service with the given auth level.

    Returns None when access is allowed and raises AbortError otherwise:
    - no session and the service is not open: UNAUTHENTICATED
    - admin service without superuser, or editor service without editor: PERMISSION_DENIED
    - jailed user on a service that is neither open nor jailed: UNAUTHENTICATED, with the permission denied message
    """
    if not auth_info:
        # if this isn't an open service, fail
        if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
            raise AbortError(UNAUTHORIZED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED)
        return

    # a valid user session was found - check permissions
    lacks_admin = auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not auth_info.is_superuser
    lacks_editor = auth_level == annotations_pb2.AUTH_LEVEL_EDITOR and not auth_info.is_editor
    if lacks_admin or lacks_editor:
        raise AbortError(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.PERMISSION_DENIED)

    # if the user is jailed and this isn't an open or jailed service, fail
    allowed_while_jailed = auth_level in (
        annotations_pb2.AUTH_LEVEL_OPEN,
        annotations_pb2.AUTH_LEVEL_JAILED,
    )
    if auth_info.is_jailed and not allowed_while_jailed:
        raise AbortError(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED)

646 

647 

648class MediaInterceptor(grpc.ServerInterceptor): 

649 """ 

650 Extracts an "Authorization: Bearer <hex>" header and calls the 

651 is_authorized function. Terminates the call with an HTTP error 

652 code if not authorized. 

653 

654 Also adds a session to called APIs. 

655 """ 

656 

657 def __init__(self, is_authorized: Callable[[str], bool]): 

658 self._is_authorized = is_authorized 

659 

660 def intercept_service[T, R]( 

661 self, 

662 continuation: Cont[T, R], 

663 handler_call_details: grpc.HandlerCallDetails, 

664 ) -> grpc.RpcMethodHandler[T, R]: 

665 handler = continuation(handler_call_details) 

666 if not handler: 666 ↛ 667line 666 didn't jump to line 667 because the condition on line 666 was never true

667 raise RuntimeError("No handler") 

668 

669 prev_func = handler.unary_unary 

670 if not prev_func: 670 ↛ 671line 670 didn't jump to line 671 because the condition on line 670 was never true

671 raise RuntimeError(f"No prev_function, {handler}") 

672 

673 metadata = dict(handler_call_details.invocation_metadata) 

674 

675 token = parse_api_key(metadata) 

676 

677 if not token or not self._is_authorized(token): 677 ↛ 678line 677 didn't jump to line 678 because the condition on line 677 was never true

678 return unauthenticated_handler() 

679 

680 def function_without_session(request: T, grpc_context: grpc.ServicerContext) -> R: 

681 with session_scope() as session: 

682 return prev_func(request, make_media_context(grpc_context), session) # type: ignore[call-arg, arg-type] 

683 

684 return grpc.unary_unary_rpc_method_handler( 

685 function_without_session, 

686 request_deserializer=handler.request_deserializer, 

687 response_serializer=handler.response_serializer, 

688 ) 

689 

690 

691class OTelInterceptor(grpc.ServerInterceptor): 

692 """ 

693 OpenTelemetry tracing 

694 """ 

695 

696 def __init__(self) -> None: 

697 self.tracer = trace.get_tracer(__name__) 

698 

699 def intercept_service[T, R]( 

700 self, 

701 continuation: Cont[T, R], 

702 handler_call_details: grpc.HandlerCallDetails, 

703 ) -> grpc.RpcMethodHandler[T, R]: 

704 handler = continuation(handler_call_details) 

705 if not handler: 

706 raise RuntimeError("No handler") 

707 

708 prev_func = handler.unary_unary 

709 if not prev_func: 

710 raise RuntimeError(f"No prev_function, {handler}") 

711 

712 method = handler_call_details.method 

713 

714 # method is of the form "/org.couchers.api.core.API/GetUser" 

715 _, service_name, method_name = method.split("/") 

716 

717 headers = dict(handler_call_details.invocation_metadata) 

718 

719 def tracing_function(request: T, context: grpc.ServicerContext) -> R: 

720 with self.tracer.start_as_current_span("handler") as rollspan: 

721 rollspan.set_attribute("rpc.method_full", method) 

722 rollspan.set_attribute("rpc.service", service_name) 

723 rollspan.set_attribute("rpc.method", method_name) 

724 

725 rollspan.set_attribute("rpc.thread", get_ident()) 

726 rollspan.set_attribute("rpc.pid", getpid()) 

727 

728 res = prev_func(request, context) 

729 

730 rollspan.set_attribute("web.user_agent", headers.get("user-agent") or "") 

731 rollspan.set_attribute("web.ip_address", headers.get("x-couchers-real-ip") or "") 

732 

733 return res 

734 

735 return grpc.unary_unary_rpc_method_handler( 

736 tracing_function, 

737 request_deserializer=handler.request_deserializer, 

738 response_serializer=handler.response_serializer, 

739 ) 

740 

741 

742class ErrorSanitizationInterceptor(grpc.ServerInterceptor): 

743 """ 

744 If the call resulted in a non-gRPC error, this strips away the error details. 

745 

746 It's important to put this first, so that it does not interfere with other interceptors. 

747 """ 

748 

749 def intercept_service[T, R]( 

750 self, 

751 continuation: Cont[T, R], 

752 handler_call_details: grpc.HandlerCallDetails, 

753 ) -> grpc.RpcMethodHandler[T, R]: 

754 handler = continuation(handler_call_details) 

755 if not handler: 755 ↛ 756line 755 didn't jump to line 756 because the condition on line 755 was never true

756 raise RuntimeError("No handler") 

757 

758 prev_func = handler.unary_unary 

759 if not prev_func: 759 ↛ 760line 759 didn't jump to line 760 because the condition on line 759 was never true

760 raise RuntimeError(f"No prev_function, {handler}") 

761 

762 def sanitizing_function(req: T, context: grpc.ServicerContext) -> R: 

763 try: 

764 res = prev_func(req, context) 

765 except Exception as e: 

766 code = context.code() # type: ignore[attr-defined] 

767 # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None 

768 if not code: 

769 logger.exception(e) 

770 logger.info("Probably an unknown error! Sanitizing...") 

771 context.abort(grpc.StatusCode.INTERNAL, UNKNOWN_ERROR_MESSAGE) 

772 else: 

773 logger.warning(f"RPC error: {code} in method {handler_call_details.method}") 

774 raise e 

775 return res 

776 

777 return grpc.unary_unary_rpc_method_handler( 

778 sanitizing_function, 

779 request_deserializer=handler.request_deserializer, 

780 response_serializer=handler.response_serializer, 

781 )