Coverage for src/couchers/interceptors.py: 86%

232 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-02 11:18 +0000

1import logging 

2from collections.abc import Callable 

3from copy import deepcopy 

4from dataclasses import dataclass 

5from datetime import datetime, timedelta 

6from os import getpid 

7from threading import get_ident 

8from time import perf_counter_ns 

9from traceback import format_exception 

10from typing import Any, Never, NoReturn, cast 

11 

12import grpc 

13import sentry_sdk 

14from google.protobuf.descriptor import ServiceDescriptor 

15from google.protobuf.message import Message 

16from opentelemetry import trace 

17from sqlalchemy import Function 

18from sqlalchemy.sql import and_, func 

19 

20from couchers.constants import ( 

21 CALL_CANCELLED_ERROR_MESSAGE, 

22 COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE, 

23 MISSING_AUTH_LEVEL_ERROR_MESSAGE, 

24 NONEXISTENT_API_CALL_ERROR_MESSAGE, 

25 PERMISSION_DENIED_ERROR_MESSAGE, 

26 UNAUTHORIZED_ERROR_MESSAGE, 

27 UNKNOWN_ERROR_MESSAGE, 

28) 

29from couchers.context import CouchersContext, make_interactive_context, make_media_context 

30from couchers.db import session_scope 

31from couchers.descriptor_pool import get_descriptor_pool 

32from couchers.metrics import observe_in_servicer_duration_histogram 

33from couchers.models import APICall, User, UserActivity, UserSession 

34from couchers.proto import annotations_pb2 

35from couchers.sql import couchers_select as select 

36from couchers.utils import ( 

37 create_lang_cookie, 

38 create_session_cookies, 

39 now, 

40 parse_api_key, 

41 parse_session_cookie, 

42 parse_ui_lang_cookie, 

43 parse_user_id_cookie, 

44) 

45 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

47 

48 

49@dataclass(frozen=True, slots=True) 

50class UserAuthInfo: 

51 """ 

52 Information about an authenticated user session. 

53 

54 Returned by _try_get_and_update_user_details when a valid session is found. 

55 """ 

56 

57 user_id: int 

58 is_jailed: bool 

59 is_editor: bool 

60 is_superuser: bool 

61 token_expiry: datetime 

62 ui_language_preference: str | None 

63 

64 

def _binned_now() -> Function[Any]:
    """
    Build a SQL expression for the current database time, bucketed into 1-hour bins.

    The bins are anchored at 2000-01-01 and evaluated by the database's date_bin
    function (assumes a backend that provides it, e.g. PostgreSQL 14+ — confirm).
    """
    bin_width, bin_origin = "1 hour", "2000-01-01"
    return func.date_bin(bin_width, func.now(), bin_origin)

67 

68 

def _try_get_and_update_user_details(
    token: str | None, is_api_key: bool, ip_address: str | None, user_agent: str | None
) -> UserAuthInfo | None:
    """
    Resolve ``token`` to a session and record activity for it.

    On a match this also:
      * refreshes the user's ``last_active`` when the stored value is more than five minutes old,
      * stamps the session's ``last_seen`` and increments its ``api_calls``,
      * increments ``api_calls`` on the UserActivity row for the current hour bucket,
        ip address and user agent, creating that row if it does not exist yet,
    and commits these changes before returning.

    Args:
        token: session token or API key; a falsy value returns None without touching the DB.
        is_api_key: whether the session must be an API-key session (True) or a cookie session (False).
        ip_address: client IP address, part of the UserActivity key.
        user_agent: client user agent, part of the UserActivity key.

    Returns:
        UserAuthInfo for a visible user with a valid matching session, otherwise None.
    """
    if not token:
        return None

    with session_scope() as session:
        # the activity row for this hour/ip/user agent may not exist yet, hence the outer join
        activity_condition = and_(
            UserActivity.user_id == User.id,
            UserActivity.period == _binned_now(),
            UserActivity.ip_address == ip_address,
            UserActivity.user_agent == user_agent,
        )
        query = (
            select(User, UserSession, UserActivity)
            .join(User, User.id == UserSession.user_id)
            .outerjoin(UserActivity, activity_condition)
            .where(User.is_visible)
            .where(UserSession.token == token)
            .where(UserSession.is_valid)
            .where(UserSession.is_api_key == is_api_key)
        )
        row = session.execute(query).one_or_none()

        if not row:
            return None

        user, user_session, activity = row

        # avoid writing last_active on every request; only refresh it once it's stale
        if now() - user.last_active > timedelta(minutes=5):
            user.last_active = func.now()

        user_session.last_seen = func.now()
        user_session.api_calls += 1

        if not activity:
            session.add(
                UserActivity(
                    user_id=user.id,
                    period=_binned_now(),
                    ip_address=ip_address,
                    user_agent=user_agent,
                    api_calls=1,
                )
            )
        else:
            activity.api_calls += 1

        session.commit()

        return UserAuthInfo(
            user_id=user.id,
            is_jailed=user.is_jailed,
            is_editor=user.is_editor,
            is_superuser=user.is_superuser,
            token_expiry=user_session.expiry,
            ui_language_preference=user.ui_language_preference,
        )

137 

138 

139# We have to lie with R | NoReturn to please mypy. It should be NoReturn. 

140def abort_handler[T, R]( 

141 message: str, 

142 status_code: grpc.StatusCode, 

143) -> "grpc.RpcMethodHandler[T, R | NoReturn]": 

144 def f(request: Any, context: CouchersContext) -> NoReturn: 

145 context.abort(status_code, message) 

146 

147 return grpc.unary_unary_rpc_method_handler(f) 

148 

149 

150def unauthenticated_handler[T, R]( 

151 message: str = UNAUTHORIZED_ERROR_MESSAGE, 

152 status_code: grpc.StatusCode = grpc.StatusCode.UNAUTHENTICATED, 

153) -> "grpc.RpcMethodHandler[T, R | NoReturn]": 

154 return abort_handler(message, status_code) 

155 

156 

def _sanitized_bytes(proto: Message | None) -> bytes | None:
    """
    Serialize ``proto`` with every field annotated ``sensitive`` cleared.

    The message is deep-copied first, so the caller's object is not modified.
    Sensitive fields are cleared at any depth: inside singular sub-messages,
    repeated message fields, and message-valued map entries.

    Args:
        proto: the message to sanitize; a falsy value (e.g. None) yields None.

    Returns:
        The serialized, sanitized copy, or None when ``proto`` is falsy.
    """
    if not proto:
        return None

    sanitized = deepcopy(proto)

    def _sanitize_message(message: Message) -> None:
        for name, field in message.DESCRIPTOR.fields_by_name.items():
            if field.GetOptions().Extensions[annotations_pb2.sensitive]:
                # the whole field is gone, so there is nothing left to recurse into
                message.ClearField(name)
                continue
            if field.message_type is None:
                # scalar field: nothing nested to inspect
                continue
            if field.label == field.LABEL_REPEATED:
                container = getattr(message, name)
                if field.message_type.GetOptions().map_entry:
                    # Map fields are repeated "entry" messages on the wire, but the Python
                    # container iterates over keys (not messages). Recurse into the values,
                    # and only when the values are themselves messages.
                    if field.message_type.fields_by_name["value"].message_type is not None:
                        for value in container.values():
                            _sanitize_message(value)
                else:
                    for item in container:
                        _sanitize_message(item)
            elif message.HasField(name):
                # an unset sub-message only holds defaults, so there is nothing to clear in it
                _sanitize_message(getattr(message, name))

    _sanitize_message(sanitized)

    return sanitized.SerializeToString()

183 

184 

def _store_log(
    *,
    method: str,
    status_code: grpc.StatusCode | None,
    duration: float,
    user_id: int | None,
    is_api_key: bool,
    request: Message,
    response: Message | None,
    traceback: str | None,
    perf_report: str | None,
    ip_address: str | None,
    user_agent: str | None,
) -> None:
    """
    Persist one APICall row describing a serviced call, then emit a debug log line.

    Request and response are stored as serialized bytes with sensitive fields
    stripped (via _sanitized_bytes). The stored response is capped at 16 kB; when
    the cap applies, the bytes are cut and ``response_truncated`` is set on the row.
    ``duration`` is logged with an "ms" suffix, so callers pass milliseconds.
    """
    max_response_bytes = 16 * 1024  # 16 kB

    request_payload = _sanitized_bytes(request)
    response_payload = _sanitized_bytes(response)

    was_truncated = False
    if response_payload and len(response_payload) > max_response_bytes:
        response_payload = response_payload[:max_response_bytes]
        was_truncated = True

    with session_scope() as session:
        api_call = APICall(
            is_api_key=is_api_key,
            method=method,
            status_code=status_code,
            duration=duration,
            user_id=user_id,
            request=request_payload,
            response=response_payload,
            response_truncated=was_truncated,
            traceback=traceback,
            perf_report=perf_report,
            ip_address=ip_address,
            user_agent=user_agent,
        )
        session.add(api_call)
    logger.debug(f"{user_id=}, {method=}, {duration=} ms")

224 

225 

# Type of the `continuation` callable gRPC passes to ServerInterceptor.intercept_service:
# given the call details it returns the next handler in the chain, or None (which the
# interceptors in this module treat as "no handler").
type Cont[T, R] = Callable[[grpc.HandlerCallDetails], grpc.RpcMethodHandler[T, R] | None]

227 

228 

229class CouchersMiddlewareInterceptor(grpc.ServerInterceptor): 

230 """ 

231 1. Does auth: extracts a session token from a cookie, and authenticates a user with that. 

232 

233 Sets context.user_id and context.token if authenticated, otherwise 

234 terminates the call with an UNAUTHENTICATED error code. 

235 

236 2. Makes sure cookies are in sync. 

237 

238 3. Injects a session to get a database transaction. 

239 

240 4. Measures and logs the time it takes to service each incoming call. 

241 """ 

242 

243 def __init__(self) -> None: 

244 self._pool = get_descriptor_pool() 

245 

246 def intercept_service[T = Message, R = Message]( 

247 self, 

248 continuation: Cont[T, R], 

249 handler_call_details: grpc.HandlerCallDetails, 

250 ) -> "grpc.RpcMethodHandler[T, R | Never]": 

251 start = perf_counter_ns() 

252 

253 method = handler_call_details.method 

254 # method is of the form "/org.couchers.api.core.API/GetUser" 

255 _, service_name, method_name = method.split("/") 

256 

257 try: 

258 service: ServiceDescriptor = self._pool.FindServiceByName(service_name) # type: ignore[no-untyped-call] 

259 service_options = service.GetOptions() 

260 except KeyError: 

261 return abort_handler(NONEXISTENT_API_CALL_ERROR_MESSAGE, grpc.StatusCode.UNIMPLEMENTED) 

262 

263 auth_level = service_options.Extensions[annotations_pb2.auth_level] 

264 

265 # if unknown auth level, then it wasn't set and something's wrong 

266 if auth_level == annotations_pb2.AUTH_LEVEL_UNKNOWN: 

267 return abort_handler(MISSING_AUTH_LEVEL_ERROR_MESSAGE, grpc.StatusCode.INTERNAL) 

268 

269 assert auth_level in [ 

270 annotations_pb2.AUTH_LEVEL_OPEN, 

271 annotations_pb2.AUTH_LEVEL_JAILED, 

272 annotations_pb2.AUTH_LEVEL_SECURE, 

273 annotations_pb2.AUTH_LEVEL_EDITOR, 

274 annotations_pb2.AUTH_LEVEL_ADMIN, 

275 ] 

276 

277 headers = dict(handler_call_details.invocation_metadata) 

278 

279 if "cookie" in headers and "authorization" in headers: 

280 # for security reasons, only one of "cookie" or "authorization" can be present 

281 return unauthenticated_handler(COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE) 

282 elif "cookie" in headers: 

283 # the session token is passed in cookies, i.e. in the `cookie` header 

284 token, is_api_key = parse_session_cookie(headers), False 

285 elif "authorization" in headers: 

286 # the session token is passed in the `authorization` header 

287 token, is_api_key = parse_api_key(headers), True 

288 else: 

289 # no session found 

290 token, is_api_key = None, False 

291 

292 ip_address = cast(str | None, headers.get("x-couchers-real-ip")) 

293 user_agent = cast(str | None, headers.get("user-agent")) 

294 

295 auth_info = _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent) 

296 

297 if not auth_info: 

298 # Invalid or no session - clear credentials 

299 token = None 

300 is_api_key = False 

301 

302 # if this isn't an open service, fail 

303 if auth_level != annotations_pb2.AUTH_LEVEL_OPEN: 

304 return unauthenticated_handler(UNAUTHORIZED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED) 

305 else: 

306 # a valid user session was found - check permissions 

307 if auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not auth_info.is_superuser: 

308 return unauthenticated_handler(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.PERMISSION_DENIED) 

309 

310 if auth_level == annotations_pb2.AUTH_LEVEL_EDITOR and not auth_info.is_editor: 

311 return unauthenticated_handler(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.PERMISSION_DENIED) 

312 

313 # if the user is jailed and this is isn't an open or jailed service, fail 

314 if auth_info.is_jailed and auth_level not in [ 

315 annotations_pb2.AUTH_LEVEL_OPEN, 

316 annotations_pb2.AUTH_LEVEL_JAILED, 

317 ]: 

318 return unauthenticated_handler(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED) 

319 

320 handler = continuation(handler_call_details) 

321 if not handler: 

322 raise RuntimeError(f"No handler in '{method}'") 

323 

324 prev_function = handler.unary_unary 

325 if not prev_function: 

326 raise RuntimeError(f"No prev_function in '{method}', {handler}") 

327 

328 def function_without_couchers_stuff(req: Message, grpc_context: grpc.ServicerContext) -> Message | None: 

329 couchers_context: CouchersContext = make_interactive_context( 

330 grpc_context=grpc_context, 

331 user_id=auth_info.user_id if auth_info else None, 

332 is_api_key=is_api_key, 

333 token=token, 

334 ui_language_preference=auth_info.ui_language_preference if auth_info else None, 

335 ) 

336 with session_scope() as session: 

337 try: 

338 _res = prev_function(req, couchers_context, session) # type: ignore[call-arg, arg-type] 

339 res = cast(Message, _res) 

340 finished = perf_counter_ns() 

341 duration = (finished - start) / 1e6 # ms 

342 _store_log( 

343 method=method, 

344 status_code=None, 

345 duration=duration, 

346 user_id=couchers_context._user_id, 

347 is_api_key=cast(bool, couchers_context._is_api_key), 

348 request=req, 

349 response=res, 

350 traceback=None, 

351 perf_report=None, 

352 ip_address=ip_address, 

353 user_agent=user_agent, 

354 ) 

355 observe_in_servicer_duration_histogram(method, couchers_context._user_id, "", "", duration / 1000) 

356 except Exception as e: 

357 finished = perf_counter_ns() 

358 duration = (finished - start) / 1e6 # ms 

359 code = getattr(couchers_context._grpc_context.code(), "name", None) # type: ignore[union-attr] 

360 traceback = "".join(format_exception(type(e), e, e.__traceback__)) 

361 _store_log( 

362 method=method, 

363 status_code=code, 

364 duration=duration, 

365 user_id=couchers_context._user_id, 

366 is_api_key=cast(bool, couchers_context._is_api_key), 

367 request=req, 

368 response=None, 

369 traceback=traceback, 

370 perf_report=None, 

371 ip_address=ip_address, 

372 user_agent=user_agent, 

373 ) 

374 observe_in_servicer_duration_histogram( 

375 method, couchers_context._user_id, code or "", type(e).__name__, duration / 1000 

376 ) 

377 

378 if not code: 

379 sentry_sdk.set_tag("context", "servicer") 

380 sentry_sdk.set_tag("method", method) 

381 sentry_sdk.capture_exception(e) 

382 

383 raise e 

384 

385 if auth_info and not is_api_key: 

386 # Sanity check. If auth_info is present, then we should have a token. 

387 if token is None: 

388 raise RuntimeError(f"{token=}, {auth_info.token_expiry=}") 

389 

390 # check the two cookies are in sync & that language preference cookie is correct 

391 if parse_user_id_cookie(headers) != str(auth_info.user_id): 

392 couchers_context.set_cookies( 

393 create_session_cookies(token, auth_info.user_id, auth_info.token_expiry) 

394 ) 

395 if auth_info.ui_language_preference and auth_info.ui_language_preference != parse_ui_lang_cookie( 

396 headers 

397 ): 

398 couchers_context.set_cookies(create_lang_cookie(auth_info.ui_language_preference)) 

399 

400 if not grpc_context.is_active(): 

401 grpc_context.abort(grpc.StatusCode.INTERNAL, CALL_CANCELLED_ERROR_MESSAGE) 

402 

403 couchers_context._send_cookies() 

404 

405 return res 

406 

407 return grpc.unary_unary_rpc_method_handler( 

408 function_without_couchers_stuff, 

409 request_deserializer=handler.request_deserializer, 

410 response_serializer=handler.response_serializer, 

411 ) 

412 

413 

414class MediaInterceptor(grpc.ServerInterceptor): 

415 """ 

416 Extracts an "Authorization: Bearer <hex>" header and calls the 

417 is_authorized function. Terminates the call with an HTTP error 

418 code if not authorized. 

419 

420 Also adds a session to called APIs. 

421 """ 

422 

423 def __init__(self, is_authorized: Callable[[str], bool]): 

424 self._is_authorized = is_authorized 

425 

426 def intercept_service[T, R]( 

427 self, 

428 continuation: Cont[T, R], 

429 handler_call_details: grpc.HandlerCallDetails, 

430 ) -> "grpc.RpcMethodHandler[T, R | Never]": 

431 handler = continuation(handler_call_details) 

432 if not handler: 

433 raise RuntimeError("No handler") 

434 

435 prev_func = handler.unary_unary 

436 if not prev_func: 

437 raise RuntimeError(f"No prev_function, {handler}") 

438 

439 metadata = dict(handler_call_details.invocation_metadata) 

440 

441 token = parse_api_key(metadata) 

442 

443 if not token or not self._is_authorized(token): 

444 return unauthenticated_handler() 

445 

446 def function_without_session(request: T, grpc_context: grpc.ServicerContext) -> R: 

447 with session_scope() as session: 

448 return prev_func(request, make_media_context(grpc_context), session) # type: ignore[call-arg, arg-type] 

449 

450 return grpc.unary_unary_rpc_method_handler( 

451 function_without_session, 

452 request_deserializer=handler.request_deserializer, 

453 response_serializer=handler.response_serializer, 

454 ) 

455 

456 

457class OTelInterceptor(grpc.ServerInterceptor): 

458 """ 

459 OpenTelemetry tracing 

460 """ 

461 

462 def __init__(self) -> None: 

463 self.tracer = trace.get_tracer(__name__) 

464 

465 def intercept_service[T, R]( 

466 self, 

467 continuation: Cont[T, R], 

468 handler_call_details: grpc.HandlerCallDetails, 

469 ) -> "grpc.RpcMethodHandler[T, R | Never]": 

470 handler = continuation(handler_call_details) 

471 if not handler: 

472 raise RuntimeError("No handler") 

473 

474 prev_func = handler.unary_unary 

475 if not prev_func: 

476 raise RuntimeError(f"No prev_function, {handler}") 

477 

478 method = handler_call_details.method 

479 

480 # method is of the form "/org.couchers.api.core.API/GetUser" 

481 _, service_name, method_name = method.split("/") 

482 

483 headers = dict(handler_call_details.invocation_metadata) 

484 

485 def tracing_function(request: T, context: grpc.ServicerContext) -> R: 

486 with self.tracer.start_as_current_span("handler") as rollspan: 

487 rollspan.set_attribute("rpc.method_full", method) 

488 rollspan.set_attribute("rpc.service", service_name) 

489 rollspan.set_attribute("rpc.method", method_name) 

490 

491 rollspan.set_attribute("rpc.thread", get_ident()) 

492 rollspan.set_attribute("rpc.pid", getpid()) 

493 

494 res = prev_func(request, context) 

495 

496 rollspan.set_attribute("web.user_agent", headers.get("user-agent") or "") 

497 rollspan.set_attribute("web.ip_address", headers.get("x-couchers-real-ip") or "") 

498 

499 return res 

500 

501 return grpc.unary_unary_rpc_method_handler( 

502 tracing_function, 

503 request_deserializer=handler.request_deserializer, 

504 response_serializer=handler.response_serializer, 

505 ) 

506 

507 

508class ErrorSanitizationInterceptor(grpc.ServerInterceptor): 

509 """ 

510 If the call resulted in a non-gRPC error, this strips away the error details. 

511 

512 It's important to put this first, so that it does not interfere with other interceptors. 

513 """ 

514 

515 def intercept_service[T, R]( 

516 self, 

517 continuation: Cont[T, R], 

518 handler_call_details: grpc.HandlerCallDetails, 

519 ) -> "grpc.RpcMethodHandler[T, R | Never]": 

520 handler = continuation(handler_call_details) 

521 if not handler: 

522 raise RuntimeError("No handler") 

523 

524 prev_func = handler.unary_unary 

525 if not prev_func: 

526 raise RuntimeError(f"No prev_function, {handler}") 

527 

528 def sanitizing_function(req: T, context: grpc.ServicerContext) -> R: 

529 try: 

530 res = prev_func(req, context) 

531 except Exception as e: 

532 code = context.code() # type: ignore[attr-defined] 

533 # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None 

534 if not code: 

535 logger.exception(e) 

536 logger.info("Probably an unknown error! Sanitizing...") 

537 context.abort(grpc.StatusCode.INTERNAL, UNKNOWN_ERROR_MESSAGE) 

538 else: 

539 logger.warning(f"RPC error: {code} in method {handler_call_details.method}") 

540 raise e 

541 return res 

542 

543 return grpc.unary_unary_rpc_method_handler( 

544 sanitizing_function, 

545 request_deserializer=handler.request_deserializer, 

546 response_serializer=handler.response_serializer, 

547 )