Coverage for src/couchers/interceptors.py: 84%

225 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-02 11:17 +0000

1import logging 

2from collections.abc import Callable 

3from copy import deepcopy 

4from dataclasses import dataclass 

5from datetime import datetime, timedelta 

6from os import getpid 

7from threading import get_ident 

8from time import perf_counter_ns 

9from traceback import format_exception 

10from typing import Any, NoReturn, cast 

11 

12import grpc 

13import sentry_sdk 

14from google.protobuf.descriptor import ServiceDescriptor 

15from google.protobuf.message import Message 

16from opentelemetry import trace 

17from sqlalchemy import Function, select 

18from sqlalchemy.sql import and_, func 

19 

20from couchers.constants import ( 

21 CALL_CANCELLED_ERROR_MESSAGE, 

22 COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE, 

23 MISSING_AUTH_LEVEL_ERROR_MESSAGE, 

24 NONEXISTENT_API_CALL_ERROR_MESSAGE, 

25 PERMISSION_DENIED_ERROR_MESSAGE, 

26 UNAUTHORIZED_ERROR_MESSAGE, 

27 UNKNOWN_ERROR_MESSAGE, 

28) 

29from couchers.context import CouchersContext, make_interactive_context, make_media_context 

30from couchers.db import session_scope 

31from couchers.descriptor_pool import get_descriptor_pool 

32from couchers.metrics import observe_in_servicer_duration_histogram 

33from couchers.models import APICall, User, UserActivity, UserSession 

34from couchers.proto import annotations_pb2 

35from couchers.utils import ( 

36 create_lang_cookie, 

37 create_session_cookies, 

38 now, 

39 parse_api_key, 

40 parse_session_cookie, 

41 parse_ui_lang_cookie, 

42 parse_user_id_cookie, 

43) 

44 

# module-level logger, named after this module
logger = logging.getLogger(name=__name__)

46 

47 

48@dataclass(frozen=True, slots=True) 

49class UserAuthInfo: 

50 """ 

51 Information about an authenticated user session. 

52 

53 Returned by _try_get_and_update_user_details when a valid session is found. 

54 """ 

55 

56 user_id: int 

57 is_jailed: bool 

58 is_editor: bool 

59 is_superuser: bool 

60 token_expiry: datetime 

61 ui_language_preference: str | None 

62 

63 

def _binned_now() -> Function[Any]:
    """
    SQL expression for the database's current time, binned into one-hour buckets.

    Uses the SQL `date_bin` function with a one-hour stride and an origin of 2000-01-01.
    All timestamps that fall in the same bucket therefore map to the same value.
    """
    stride, origin = "1 hour", "2000-01-01"
    return func.date_bin(stride, func.now(), origin)

66 

67 

def _try_get_and_update_user_details(
    token: str | None, is_api_key: bool, ip_address: str | None, user_agent: str | None
) -> UserAuthInfo | None:
    """
    Resolve `token` to a session and record that the session was used.

    The token must belong to a valid session of the requested kind (API key or not) whose
    user is visible. When such a session is found this also:
      * sets User.last_active to the DB's now() if it is more than five minutes old,
      * sets UserSession.last_seen to now() and increments UserSession.api_calls,
      * increments api_calls on the UserActivity row for the current hour bucket and this
        IP address / user agent, or inserts that row with api_calls=1.

    Returns a UserAuthInfo snapshot, or None when `token` is empty or matches no such session.
    """
    if not token:
        return None

    with session_scope() as session:
        # join condition selecting this caller's activity row for the current hour bucket
        activity_in_this_hour = and_(
            UserActivity.user_id == User.id,
            UserActivity.period == _binned_now(),
            UserActivity.ip_address == ip_address,
            UserActivity.user_agent == user_agent,
        )
        lookup = (
            select(User, UserSession, UserActivity)
            .join(User, User.id == UserSession.user_id)
            .outerjoin(UserActivity, activity_in_this_hour)
            .where(User.is_visible)
            .where(UserSession.token == token)
            .where(UserSession.is_valid)
            .where(UserSession.is_api_key == is_api_key)
        )
        row = session.execute(lookup).one_or_none()

        if not row:
            return None

        user, user_session, activity = row

        # only refresh last_active once it is more than five minutes old
        if now() - user.last_active > timedelta(minutes=5):
            user.last_active = func.now()

        # record that the session was used for this call
        user_session.last_seen = func.now()
        user_session.api_calls += 1

        if not activity:
            # first call in this hour bucket from this IP / user agent
            session.add(
                UserActivity(
                    user_id=user.id,
                    period=_binned_now(),
                    ip_address=ip_address,
                    user_agent=user_agent,
                    api_calls=1,
                )
            )
        else:
            activity.api_calls += 1

        session.commit()

        return UserAuthInfo(
            user_id=user.id,
            is_jailed=user.is_jailed,
            is_editor=user.is_editor,
            is_superuser=user.is_superuser,
            token_expiry=user_session.expiry,
            ui_language_preference=user.ui_language_preference,
        )

136 

137 

138def abort_handler[T, R]( 

139 message: str, 

140 status_code: grpc.StatusCode, 

141) -> grpc.RpcMethodHandler[T, R]: 

142 def f(request: Any, context: CouchersContext) -> NoReturn: 

143 context.abort(status_code, message) 

144 

145 return grpc.unary_unary_rpc_method_handler(f) 

146 

147 

148def unauthenticated_handler[T, R]( 

149 message: str = UNAUTHORIZED_ERROR_MESSAGE, 

150 status_code: grpc.StatusCode = grpc.StatusCode.UNAUTHENTICATED, 

151) -> grpc.RpcMethodHandler[T, R]: 

152 return abort_handler(message, status_code) 

153 

154 

def _sanitized_bytes(proto: Message | None) -> bytes | None:
    """
    Serialize `proto` after clearing every field annotated as `sensitive`, recursively.

    The message is deep-copied first, so the caller's object is never modified.

    Only fields that are actually present (as reported by ``ListFields``) are visited.
    Walking every *declared* field instead would touch unset singular sub-messages, and
    calling ``ClearField`` inside one of those marks it as present in its parent. That would
    add empty sub-messages to the output, could switch the active member of a oneof, and
    recurses forever on self-referential message types. Skipping absent fields loses nothing:
    an absent field is not serialized anyway.

    Map fields are walked through their values when the values are messages. Iterating a
    map container yields its keys, which are not messages.

    Returns None when `proto` is falsy (e.g. None).
    """
    if not proto:
        return None

    new_proto = deepcopy(proto)

    def _sanitize_message(message: Message) -> None:
        # ListFields() returns a list, so clearing fields while looping over it is safe
        for field, value in message.ListFields():
            if field.is_extension:
                # extensions were never part of the sanitized set, and ClearField() cannot
                # address them by name
                continue
            if field.GetOptions().Extensions[annotations_pb2.sensitive]:
                message.ClearField(field.name)
                continue
            if not field.message_type:
                # scalar field: nothing nested to sanitize
                continue
            if field.message_type.GetOptions().map_entry:
                # map<K, V>: only message-typed values can contain sensitive fields
                if field.message_type.fields_by_name["value"].message_type:
                    for submessage in value.values():
                        _sanitize_message(submessage)
            elif field.is_repeated:
                for submessage in value:
                    _sanitize_message(submessage)
            else:
                _sanitize_message(value)

    _sanitize_message(new_proto)

    return new_proto.SerializeToString()

181 

182 

def _store_log(
    *,
    method: str,
    status_code: str | None,
    duration: float,
    user_id: int | None,
    is_api_key: bool,
    request: Message,
    response: Message | None,
    traceback: str | None,
    perf_report: str | None,
    ip_address: str | None,
    user_agent: str | None,
) -> None:
    """
    Persist one APICall row describing a finished RPC.

    `request` and `response` are stored with their sensitive fields stripped (see
    _sanitized_bytes). The serialized response is capped at 16 kB; when it is cut short,
    the row is flagged with response_truncated=True.

    Args:
        method: full gRPC method path, e.g. "/org.couchers.api.core.API/GetUser".
        status_code: name of the gRPC status the call ended with (e.g. "NOT_FOUND"), or None
            when no status code was set (the success path passes None).
        duration: wall time of the call in milliseconds.
        user_id: authenticated user, if any.
        is_api_key: whether the caller authenticated with an API key.
        request: the request message.
        response: the response message, or None if the call failed.
        traceback: formatted exception traceback for failed calls.
        perf_report: optional profiling report.
        ip_address: caller IP as reported by the proxy header, if any.
        user_agent: caller user agent, if any.
    """
    req_bytes = _sanitized_bytes(request)
    res_bytes = _sanitized_bytes(response)

    # pure in-memory work: do it before opening the DB transaction
    truncate_res_bytes_length = 16 * 1024  # 16 kB
    response_truncated = False
    if res_bytes is not None and len(res_bytes) > truncate_res_bytes_length:
        res_bytes = res_bytes[:truncate_res_bytes_length]
        response_truncated = True

    with session_scope() as session:
        session.add(
            APICall(
                is_api_key=is_api_key,
                method=method,
                status_code=status_code,
                duration=duration,
                user_id=user_id,
                request=req_bytes,
                response=res_bytes,
                response_truncated=response_truncated,
                traceback=traceback,
                perf_report=perf_report,
                ip_address=ip_address,
                user_agent=user_agent,
            )
        )
    # renders exactly like f"{user_id=}, {method=}, {duration=} ms", but lazily
    logger.debug("user_id=%r, method=%r, duration=%r ms", user_id, method, duration)

222 

223 

# Type of the `continuation` callable passed to the interceptors' intercept_service methods.
# Given the call details, it returns the next RPC method handler or None. The interceptors in
# this module raise RuntimeError("No handler"...) when they get None back.
type Cont[T, R] = Callable[[grpc.HandlerCallDetails], grpc.RpcMethodHandler[T, R] | None]

225 

226 

227class CouchersMiddlewareInterceptor(grpc.ServerInterceptor): 

228 """ 

229 1. Does auth: extracts a session token from a cookie, and authenticates a user with that. 

230 

231 Sets context.user_id and context.token if authenticated, otherwise 

232 terminates the call with an UNAUTHENTICATED error code. 

233 

234 2. Makes sure cookies are in sync. 

235 

236 3. Injects a session to get a database transaction. 

237 

238 4. Measures and logs the time it takes to service each incoming call. 

239 """ 

240 

241 def __init__(self) -> None: 

242 self._pool = get_descriptor_pool() 

243 

244 def intercept_service[T = Message, R = Message]( 

245 self, 

246 continuation: Cont[T, R], 

247 handler_call_details: grpc.HandlerCallDetails, 

248 ) -> grpc.RpcMethodHandler[T, R]: 

249 start = perf_counter_ns() 

250 

251 method = handler_call_details.method 

252 # method is of the form "/org.couchers.api.core.API/GetUser" 

253 _, service_name, method_name = method.split("/") 

254 

255 try: 

256 service: ServiceDescriptor = self._pool.FindServiceByName(service_name) # type: ignore[no-untyped-call] 

257 service_options = service.GetOptions() 

258 except KeyError: 

259 return abort_handler(NONEXISTENT_API_CALL_ERROR_MESSAGE, grpc.StatusCode.UNIMPLEMENTED) 

260 

261 auth_level = service_options.Extensions[annotations_pb2.auth_level] 

262 

263 # if unknown auth level, then it wasn't set and something's wrong 

264 if auth_level == annotations_pb2.AUTH_LEVEL_UNKNOWN: 

265 return abort_handler(MISSING_AUTH_LEVEL_ERROR_MESSAGE, grpc.StatusCode.INTERNAL) 

266 

267 assert auth_level in [ 

268 annotations_pb2.AUTH_LEVEL_OPEN, 

269 annotations_pb2.AUTH_LEVEL_JAILED, 

270 annotations_pb2.AUTH_LEVEL_SECURE, 

271 annotations_pb2.AUTH_LEVEL_EDITOR, 

272 annotations_pb2.AUTH_LEVEL_ADMIN, 

273 ] 

274 

275 headers = dict(handler_call_details.invocation_metadata) 

276 

277 if "cookie" in headers and "authorization" in headers: 

278 # for security reasons, only one of "cookie" or "authorization" can be present 

279 return unauthenticated_handler(COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE) 

280 elif "cookie" in headers: 

281 # the session token is passed in cookies, i.e. in the `cookie` header 

282 token, is_api_key = parse_session_cookie(headers), False 

283 elif "authorization" in headers: 

284 # the session token is passed in the `authorization` header 

285 token, is_api_key = parse_api_key(headers), True 

286 else: 

287 # no session found 

288 token, is_api_key = None, False 

289 

290 ip_address = cast(str | None, headers.get("x-couchers-real-ip")) 

291 user_agent = cast(str | None, headers.get("user-agent")) 

292 

293 auth_info = _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent) 

294 

295 if not auth_info: 

296 # Invalid or no session - clear credentials 

297 token = None 

298 is_api_key = False 

299 

300 # if this isn't an open service, fail 

301 if auth_level != annotations_pb2.AUTH_LEVEL_OPEN: 

302 return unauthenticated_handler(UNAUTHORIZED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED) 

303 else: 

304 # a valid user session was found - check permissions 

305 if auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not auth_info.is_superuser: 

306 return unauthenticated_handler(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.PERMISSION_DENIED) 

307 

308 if auth_level == annotations_pb2.AUTH_LEVEL_EDITOR and not auth_info.is_editor: 

309 return unauthenticated_handler(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.PERMISSION_DENIED) 

310 

311 # if the user is jailed and this is isn't an open or jailed service, fail 

312 if auth_info.is_jailed and auth_level not in [ 

313 annotations_pb2.AUTH_LEVEL_OPEN, 

314 annotations_pb2.AUTH_LEVEL_JAILED, 

315 ]: 

316 return unauthenticated_handler(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED) 

317 

318 handler = continuation(handler_call_details) 

319 if not handler: 319 ↛ 320line 319 didn't jump to line 320 because the condition on line 319 was never true

320 raise RuntimeError(f"No handler in '{method}'") 

321 

322 prev_function = handler.unary_unary 

323 if not prev_function: 323 ↛ 324line 323 didn't jump to line 324 because the condition on line 323 was never true

324 raise RuntimeError(f"No prev_function in '{method}', {handler}") 

325 

326 def function_without_couchers_stuff(req: Message, grpc_context: grpc.ServicerContext) -> Message | None: 

327 couchers_context: CouchersContext = make_interactive_context( 

328 grpc_context=grpc_context, 

329 user_id=auth_info.user_id if auth_info else None, 

330 is_api_key=is_api_key, 

331 token=token, 

332 ui_language_preference=auth_info.ui_language_preference if auth_info else None, 

333 ) 

334 with session_scope() as session: 

335 try: 

336 _res = prev_function(req, couchers_context, session) # type: ignore[call-arg, arg-type] 

337 res = cast(Message, _res) 

338 finished = perf_counter_ns() 

339 duration = (finished - start) / 1e6 # ms 

340 _store_log( 

341 method=method, 

342 status_code=None, 

343 duration=duration, 

344 user_id=couchers_context._user_id, 

345 is_api_key=cast(bool, couchers_context._is_api_key), 

346 request=req, 

347 response=res, 

348 traceback=None, 

349 perf_report=None, 

350 ip_address=ip_address, 

351 user_agent=user_agent, 

352 ) 

353 observe_in_servicer_duration_histogram(method, couchers_context._user_id, "", "", duration / 1000) 

354 except Exception as e: 

355 finished = perf_counter_ns() 

356 duration = (finished - start) / 1e6 # ms 

357 code = getattr(couchers_context._grpc_context.code(), "name", None) # type: ignore[union-attr] 

358 traceback = "".join(format_exception(type(e), e, e.__traceback__)) 

359 _store_log( 

360 method=method, 

361 status_code=code, 

362 duration=duration, 

363 user_id=couchers_context._user_id, 

364 is_api_key=cast(bool, couchers_context._is_api_key), 

365 request=req, 

366 response=None, 

367 traceback=traceback, 

368 perf_report=None, 

369 ip_address=ip_address, 

370 user_agent=user_agent, 

371 ) 

372 observe_in_servicer_duration_histogram( 

373 method, couchers_context._user_id, code or "", type(e).__name__, duration / 1000 

374 ) 

375 

376 if not code: 

377 sentry_sdk.set_tag("context", "servicer") 

378 sentry_sdk.set_tag("method", method) 

379 sentry_sdk.capture_exception(e) 

380 

381 raise e 

382 

383 if auth_info and not is_api_key: 

384 # Sanity check. If auth_info is present, then we should have a token. 

385 if token is None: 385 ↛ 386line 385 didn't jump to line 386 because the condition on line 385 was never true

386 raise RuntimeError(f"{token=}, {auth_info.token_expiry=}") 

387 

388 # check the two cookies are in sync & that language preference cookie is correct 

389 if parse_user_id_cookie(headers) != str(auth_info.user_id): 389 ↛ 393line 389 didn't jump to line 393 because the condition on line 389 was always true

390 couchers_context.set_cookies( 

391 create_session_cookies(token, auth_info.user_id, auth_info.token_expiry) 

392 ) 

393 if auth_info.ui_language_preference and auth_info.ui_language_preference != parse_ui_lang_cookie( 

394 headers 

395 ): 

396 couchers_context.set_cookies(create_lang_cookie(auth_info.ui_language_preference)) 

397 

398 if not grpc_context.is_active(): 398 ↛ 399line 398 didn't jump to line 399 because the condition on line 398 was never true

399 grpc_context.abort(grpc.StatusCode.INTERNAL, CALL_CANCELLED_ERROR_MESSAGE) 

400 

401 couchers_context._send_cookies() 

402 

403 return res 

404 

405 return grpc.unary_unary_rpc_method_handler( 

406 function_without_couchers_stuff, 

407 request_deserializer=handler.request_deserializer, 

408 response_serializer=handler.response_serializer, 

409 ) 

410 

411 

412class MediaInterceptor(grpc.ServerInterceptor): 

413 """ 

414 Extracts an "Authorization: Bearer <hex>" header and calls the 

415 is_authorized function. Terminates the call with an HTTP error 

416 code if not authorized. 

417 

418 Also adds a session to called APIs. 

419 """ 

420 

421 def __init__(self, is_authorized: Callable[[str], bool]): 

422 self._is_authorized = is_authorized 

423 

424 def intercept_service[T, R]( 

425 self, 

426 continuation: Cont[T, R], 

427 handler_call_details: grpc.HandlerCallDetails, 

428 ) -> grpc.RpcMethodHandler[T, R]: 

429 handler = continuation(handler_call_details) 

430 if not handler: 430 ↛ 431line 430 didn't jump to line 431 because the condition on line 430 was never true

431 raise RuntimeError("No handler") 

432 

433 prev_func = handler.unary_unary 

434 if not prev_func: 434 ↛ 435line 434 didn't jump to line 435 because the condition on line 434 was never true

435 raise RuntimeError(f"No prev_function, {handler}") 

436 

437 metadata = dict(handler_call_details.invocation_metadata) 

438 

439 token = parse_api_key(metadata) 

440 

441 if not token or not self._is_authorized(token): 441 ↛ 442line 441 didn't jump to line 442 because the condition on line 441 was never true

442 return unauthenticated_handler() 

443 

444 def function_without_session(request: T, grpc_context: grpc.ServicerContext) -> R: 

445 with session_scope() as session: 

446 return prev_func(request, make_media_context(grpc_context), session) # type: ignore[call-arg, arg-type] 

447 

448 return grpc.unary_unary_rpc_method_handler( 

449 function_without_session, 

450 request_deserializer=handler.request_deserializer, 

451 response_serializer=handler.response_serializer, 

452 ) 

453 

454 

455class OTelInterceptor(grpc.ServerInterceptor): 

456 """ 

457 OpenTelemetry tracing 

458 """ 

459 

460 def __init__(self) -> None: 

461 self.tracer = trace.get_tracer(__name__) 

462 

463 def intercept_service[T, R]( 

464 self, 

465 continuation: Cont[T, R], 

466 handler_call_details: grpc.HandlerCallDetails, 

467 ) -> grpc.RpcMethodHandler[T, R]: 

468 handler = continuation(handler_call_details) 

469 if not handler: 

470 raise RuntimeError("No handler") 

471 

472 prev_func = handler.unary_unary 

473 if not prev_func: 

474 raise RuntimeError(f"No prev_function, {handler}") 

475 

476 method = handler_call_details.method 

477 

478 # method is of the form "/org.couchers.api.core.API/GetUser" 

479 _, service_name, method_name = method.split("/") 

480 

481 headers = dict(handler_call_details.invocation_metadata) 

482 

483 def tracing_function(request: T, context: grpc.ServicerContext) -> R: 

484 with self.tracer.start_as_current_span("handler") as rollspan: 

485 rollspan.set_attribute("rpc.method_full", method) 

486 rollspan.set_attribute("rpc.service", service_name) 

487 rollspan.set_attribute("rpc.method", method_name) 

488 

489 rollspan.set_attribute("rpc.thread", get_ident()) 

490 rollspan.set_attribute("rpc.pid", getpid()) 

491 

492 res = prev_func(request, context) 

493 

494 rollspan.set_attribute("web.user_agent", headers.get("user-agent") or "") 

495 rollspan.set_attribute("web.ip_address", headers.get("x-couchers-real-ip") or "") 

496 

497 return res 

498 

499 return grpc.unary_unary_rpc_method_handler( 

500 tracing_function, 

501 request_deserializer=handler.request_deserializer, 

502 response_serializer=handler.response_serializer, 

503 ) 

504 

505 

506class ErrorSanitizationInterceptor(grpc.ServerInterceptor): 

507 """ 

508 If the call resulted in a non-gRPC error, this strips away the error details. 

509 

510 It's important to put this first, so that it does not interfere with other interceptors. 

511 """ 

512 

513 def intercept_service[T, R]( 

514 self, 

515 continuation: Cont[T, R], 

516 handler_call_details: grpc.HandlerCallDetails, 

517 ) -> grpc.RpcMethodHandler[T, R]: 

518 handler = continuation(handler_call_details) 

519 if not handler: 519 ↛ 520line 519 didn't jump to line 520 because the condition on line 519 was never true

520 raise RuntimeError("No handler") 

521 

522 prev_func = handler.unary_unary 

523 if not prev_func: 523 ↛ 524line 523 didn't jump to line 524 because the condition on line 523 was never true

524 raise RuntimeError(f"No prev_function, {handler}") 

525 

526 def sanitizing_function(req: T, context: grpc.ServicerContext) -> R: 

527 try: 

528 res = prev_func(req, context) 

529 except Exception as e: 

530 code = context.code() # type: ignore[attr-defined] 

531 # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None 

532 if not code: 

533 logger.exception(e) 

534 logger.info("Probably an unknown error! Sanitizing...") 

535 context.abort(grpc.StatusCode.INTERNAL, UNKNOWN_ERROR_MESSAGE) 

536 else: 

537 logger.warning(f"RPC error: {code} in method {handler_call_details.method}") 

538 raise e 

539 return res 

540 

541 return grpc.unary_unary_rpc_method_handler( 

542 sanitizing_function, 

543 request_deserializer=handler.request_deserializer, 

544 response_serializer=handler.response_serializer, 

545 )