Coverage for src/couchers/interceptors.py: 86%

226 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-18 14:01 +0000

1import logging 

2from collections.abc import Callable 

3from copy import deepcopy 

4from datetime import datetime, timedelta 

5from os import getpid 

6from threading import get_ident 

7from time import perf_counter_ns 

8from traceback import format_exception 

9from typing import Any, Never, NoReturn, cast 

10 

11import grpc 

12import sentry_sdk 

13from google.protobuf.descriptor import ServiceDescriptor 

14from google.protobuf.message import Message 

15from opentelemetry import trace 

16from sqlalchemy import Function 

17from sqlalchemy.sql import and_, func 

18 

19from couchers.constants import UNKNOWN_ERROR_MESSAGE 

20from couchers.context import CouchersContext, make_interactive_context, make_media_context 

21from couchers.db import session_scope 

22from couchers.descriptor_pool import get_descriptor_pool 

23from couchers.metrics import observe_in_servicer_duration_histogram 

24from couchers.models import APICall, User, UserActivity, UserSession 

25from couchers.proto import annotations_pb2 

26from couchers.sql import couchers_select as select 

27from couchers.utils import ( 

28 create_lang_cookie, 

29 create_session_cookies, 

30 now, 

31 parse_api_key, 

32 parse_session_cookie, 

33 parse_ui_lang_cookie, 

34 parse_user_id_cookie, 

35) 

36 

# Module-level logger for the interceptors, named after this module.
logger = logging.getLogger(__name__)

38 

39 

def _binned_now() -> Function[Any]:
    """
    Build a SQL expression for the current DB time truncated to a one-hour bin.

    Renders as `date_bin('1 hour', now(), '2000-01-01')`: bins are one hour wide and aligned to
    the origin 2000-01-01, so every timestamp within the same hour maps to the same value.
    """
    bin_width = "1 hour"
    bin_origin = "2000-01-01"
    return func.date_bin(bin_width, func.now(), bin_origin)

42 

43 

def _try_get_and_update_user_details(
    token: str | None, is_api_key: bool, ip_address: str | None, user_agent: str | None
) -> tuple[int, bool, bool, datetime, str | None] | None:
    """
    Tries to get session and user info corresponding to this token.

    Also updates the user's last active time, token last active time, and increments API call count.

    A token is accepted only if it matches a valid session (`UserSession.is_valid`) whose
    `is_api_key` flag equals `is_api_key` and whose user is visible (`User.is_visible`).

    Args:
        token: session token or API key; a falsy value returns None without touching the DB.
        is_api_key: True to match API-key sessions, False to match cookie sessions.
        ip_address: client IP; together with `user_agent` and the current hour bin it selects the
            UserActivity row whose `api_calls` counter is incremented (created if missing).
        user_agent: client user agent (see `ip_address`).

    Returns:
        `(user_id, is_jailed, is_superuser, session_expiry, ui_language_preference)` for a valid
        session, otherwise None.
    """
    if not token:
        return None

    with session_scope() as session:
        result = session.execute(
            select(User, UserSession, UserActivity)
            .join(User, User.id == UserSession.user_id)
            .outerjoin(
                UserActivity,
                and_(
                    UserActivity.user_id == User.id,
                    UserActivity.period == _binned_now(),
                    UserActivity.ip_address == ip_address,
                    UserActivity.user_agent == user_agent,
                ),
            )
            .where(User.is_visible)
            .where(UserSession.token == token)
            .where(UserSession.is_valid)
            .where(UserSession.is_api_key == is_api_key)
        ).one_or_none()

        if not result:
            return None

        user, user_session, user_activity = result

        # Snapshot the returned values *before* committing. None of these attributes are modified
        # below, so the values are unchanged. Reading them after commit() would, if the session
        # uses SQLAlchemy's default expire_on_commit=True, reload both rows with extra SELECTs on
        # every authenticated API call.
        details = (
            user.id,
            user.is_jailed,
            user.is_superuser,
            user_session.expiry,
            user.ui_language_preference,
        )

        # update user last active time if it's been a while
        if now() - user.last_active > timedelta(minutes=5):
            user.last_active = func.now()

        # let's update the token
        user_session.last_seen = func.now()
        user_session.api_calls += 1

        # per-hour activity counter keyed on (user, hour bin, ip address, user agent)
        if user_activity:
            user_activity.api_calls += 1
        else:
            session.add(
                UserActivity(
                    user_id=user.id,
                    period=_binned_now(),
                    ip_address=ip_address,
                    user_agent=user_agent,
                    api_calls=1,
                )
            )

        session.commit()

        return details

103 

104 

105# We have to lie with R | NoReturn to please mypy. It should be NoReturn. 

106def abort_handler[T, R]( 

107 message: str, 

108 status_code: grpc.StatusCode, 

109) -> "grpc.RpcMethodHandler[T, R | NoReturn]": 

110 def f(request: Any, context: CouchersContext) -> NoReturn: 

111 context.abort(status_code, message) 

112 

113 return grpc.unary_unary_rpc_method_handler(f) 

114 

115 

116def unauthenticated_handler[T, R]( 

117 message: str = "Unauthorized", 

118 status_code: grpc.StatusCode = grpc.StatusCode.UNAUTHENTICATED, 

119) -> "grpc.RpcMethodHandler[T, R | NoReturn]": 

120 return abort_handler(message, status_code) 

121 

122 

def _sanitized_bytes(proto: Message | None) -> bytes | None:
    """
    Remove fields marked sensitive and return serialized bytes.

    The caller's message is not modified: a deep copy is sanitized and then serialized. Every field
    whose options carry the `sensitive` annotation is cleared, at any depth. Recursion goes into
    singular submessages, repeated submessages and message-valued map entries.

    Returns:
        The serialized sanitized message, or None if `proto` is falsy (e.g. None).
    """
    if not proto:
        return None

    new_proto = deepcopy(proto)

    def _sanitize_message(message: Message) -> None:
        for name, descriptor in message.DESCRIPTOR.fields_by_name.items():
            if descriptor.GetOptions().Extensions[annotations_pb2.sensitive]:
                message.ClearField(name)
                # the field is now empty, so there is nothing left to recurse into
                continue
            if not descriptor.message_type:
                # scalar field: nothing nested to sanitize
                continue
            submessage = getattr(message, name)
            if not submessage:
                continue
            if descriptor.message_type.GetOptions().map_entry:
                # Map fields are REPEATED fields of a synthetic entry message, but iterating the map
                # container yields its keys, not messages. Only message-typed values can contain
                # sensitive fields.
                if descriptor.message_type.fields_by_name["value"].message_type:
                    for value in submessage.values():
                        _sanitize_message(value)
            elif descriptor.label == descriptor.LABEL_REPEATED:
                for msg in submessage:
                    _sanitize_message(msg)
            else:
                _sanitize_message(submessage)

    _sanitize_message(new_proto)

    return new_proto.SerializeToString()

149 

150 

def _store_log(
    *,
    method: str,
    status_code: grpc.StatusCode | str | None,
    duration: float,
    user_id: int | None,
    is_api_key: bool,
    request: Message,
    response: Message | None,
    traceback: str | None,
    perf_report: str | None,
    ip_address: str | None,
    user_agent: str | None,
) -> None:
    """
    Record one API call as an APICall row.

    Request and response are stored after `_sanitized_bytes` strips fields marked sensitive.
    The serialized response is cut to at most 16 kB, and `response_truncated` records whether
    that happened.

    Args:
        method: full gRPC method path, e.g. "/org.couchers.api.core.API/GetUser".
        status_code: status of a failed call; the servicer interceptor passes the status code's
            name (a str), or None when the handler did not set one. None for successful calls.
        duration: time taken to service the call, in milliseconds.
        user_id: authenticated user, if any.
        is_api_key: whether the call was authenticated with an API key.
        request: the request message.
        response: the response message, or None if the call raised.
        traceback: formatted traceback if the call raised, else None.
        perf_report: optional profiling output.
        ip_address: client IP address, if known.
        user_agent: client user agent, if known.
    """
    truncate_res_bytes_length = 16 * 1024  # 16 kB

    # Serialize and truncate before opening the DB session, so the session only covers the insert.
    req_bytes = _sanitized_bytes(request)
    res_bytes = _sanitized_bytes(response)
    response_truncated = False
    if res_bytes and len(res_bytes) > truncate_res_bytes_length:
        res_bytes = res_bytes[:truncate_res_bytes_length]
        response_truncated = True

    with session_scope() as session:
        session.add(
            APICall(
                is_api_key=is_api_key,
                method=method,
                status_code=status_code,
                duration=duration,
                user_id=user_id,
                request=req_bytes,
                response=res_bytes,
                response_truncated=response_truncated,
                traceback=traceback,
                perf_report=perf_report,
                ip_address=ip_address,
                user_agent=user_agent,
            )
        )
    # %r reproduces the repr() formatting of the f"{x=}" form while deferring formatting
    logger.debug("user_id=%r, method=%r, duration=%r ms", user_id, method, duration)

190 

191 

# Type of the `continuation` argument that gRPC passes to `intercept_service`: given the call
# details, it returns the next RpcMethodHandler in the chain, or None.
type Cont[T, R] = Callable[[grpc.HandlerCallDetails], grpc.RpcMethodHandler[T, R] | None]

193 

194 

class CouchersMiddlewareInterceptor(grpc.ServerInterceptor):
    """
    1. Does auth: extracts a session token from a cookie, and authenticates a user with that.

    Sets context.user_id and context.token if authenticated, otherwise
    terminates the call with an UNAUTHENTICATED error code.

    2. Makes sure cookies are in sync.

    3. Injects a session to get a database transaction.

    4. Measures and logs the time it takes to service each incoming call.

    The required auth level comes from the `auth_level` option on the protobuf service
    descriptor. Services with AUTH_LEVEL_OPEN accept anonymous calls. Services with
    AUTH_LEVEL_ADMIN additionally require a superuser. Jailed users may only call OPEN or
    JAILED services.
    """

    def __init__(self) -> None:
        # descriptor pool used to look up each called service and read its options
        self._pool = get_descriptor_pool()

    def intercept_service[T = Message, R = Message](
        self,
        continuation: Cont[T, R],
        handler_call_details: grpc.HandlerCallDetails,
    ) -> "grpc.RpcMethodHandler[T, R | Never]":
        """
        Authenticate the call and wrap the downstream handler.

        Returns an aborting handler if the service is unknown (UNIMPLEMENTED), has no auth level
        (INTERNAL), or the caller is not allowed to call it. Otherwise returns a handler that
        runs the servicer with a CouchersContext and a DB session, logs the call, and syncs
        cookies.

        Raises:
            RuntimeError: if the downstream continuation yields no unary-unary handler.
        """
        # start time for the duration metric; it also covers the auth lookup below
        start = perf_counter_ns()

        method = handler_call_details.method
        # method is of the form "/org.couchers.api.core.API/GetUser"
        _, service_name, method_name = method.split("/")

        try:
            service: ServiceDescriptor = self._pool.FindServiceByName(service_name)  # type: ignore[no-untyped-call]
            service_options = service.GetOptions()
        except KeyError:
            return abort_handler(
                "API call does not exist. Please refresh and try again.", grpc.StatusCode.UNIMPLEMENTED
            )

        auth_level: Any = service_options.Extensions[annotations_pb2.auth_level]  # type: ignore[index]

        # if unknown auth level, then it wasn't set and something's wrong
        if auth_level == annotations_pb2.AUTH_LEVEL_UNKNOWN:
            return abort_handler("Internal authentication error.", grpc.StatusCode.INTERNAL)

        assert auth_level in [
            annotations_pb2.AUTH_LEVEL_OPEN,
            annotations_pb2.AUTH_LEVEL_JAILED,
            annotations_pb2.AUTH_LEVEL_SECURE,
            annotations_pb2.AUTH_LEVEL_ADMIN,
        ]

        headers = dict(handler_call_details.invocation_metadata)

        if "cookie" in headers and "authorization" in headers:
            # for security reasons, only one of "cookie" or "authorization" can be present
            return unauthenticated_handler('Both "cookie" and "authorization" in request')
        elif "cookie" in headers:
            # the session token is passed in cookies, i.e. in the `cookie` header
            token, is_api_key = parse_session_cookie(headers), False
        elif "authorization" in headers:
            # the session token is passed in the `authorization` header
            token, is_api_key = parse_api_key(headers), True
        else:
            # no session found
            token, is_api_key = None, False

        ip_address = cast(str | None, headers.get("x-couchers-real-ip"))
        user_agent = cast(str | None, headers.get("user-agent"))

        auth_info = _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent)
        # auth_info is now filled if and only if this is a valid session
        if not auth_info:
            # treat the call as anonymous: drop any token that did not resolve to a session
            token = None
            is_api_key = False
            token_expiry = None
            user_id = None
            ui_language_preference = None

        # if no session was found and this isn't an open service, fail
        if not auth_info:
            if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
                # NOTE: do not translate this string; it's used in a hacky way in the frontend
                return unauthenticated_handler("Unauthorized")
        else:
            # a valid user session was found
            user_id, is_jailed, is_superuser, token_expiry, ui_language_preference = auth_info

            if auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not is_superuser:
                # NOTE: do not translate this string; it's used in a hacky way in the frontend
                return unauthenticated_handler("Permission denied", grpc.StatusCode.PERMISSION_DENIED)

            # if the user is jailed and this is isn't an open or jailed service, fail
            if is_jailed and auth_level not in [annotations_pb2.AUTH_LEVEL_OPEN, annotations_pb2.AUTH_LEVEL_JAILED]:
                # NOTE: do not translate this string; it's used in a hacky way in the frontend
                return unauthenticated_handler("Permission denied")

        handler = continuation(handler_call_details)
        if not handler:
            raise RuntimeError(f"No handler in '{method}'")

        prev_function = handler.unary_unary
        if not prev_function:
            raise RuntimeError(f"No prev_function in '{method}', {handler}")

        def function_without_couchers_stuff(req: Message, grpc_context: grpc.ServicerContext) -> Message | None:
            """
            Run the servicer as `prev_function(req, couchers_context, session)` and log the call.

            Both outcomes are recorded via `_store_log` and the duration histogram (in seconds).
            On failure, the exception is captured to Sentry only if no gRPC status code was set,
            i.e. it was not a deliberate abort, and is then re-raised.
            """
            couchers_context: CouchersContext = make_interactive_context(
                grpc_context=grpc_context,
                user_id=user_id,
                is_api_key=is_api_key,
                token=token,
                ui_language_preference=ui_language_preference,
            )
            with session_scope() as session:
                try:
                    _res = prev_function(req, couchers_context, session)  # type: ignore[call-arg, arg-type]
                    res = cast(Message, _res)
                    finished = perf_counter_ns()
                    duration = (finished - start) / 1e6  # ms
                    # log with the context's user id/api-key flag, which the servicer may have changed
                    _store_log(
                        method=method,
                        status_code=None,
                        duration=duration,
                        user_id=couchers_context._user_id,
                        is_api_key=cast(bool, couchers_context._is_api_key),
                        request=req,
                        response=res,
                        traceback=None,
                        perf_report=None,
                        ip_address=ip_address,
                        user_agent=user_agent,
                    )
                    observe_in_servicer_duration_histogram(method, couchers_context._user_id, "", "", duration / 1000)
                except Exception as e:
                    finished = perf_counter_ns()
                    duration = (finished - start) / 1e6  # ms
                    # name of the status code set on the context (e.g. via abort), or None if unset
                    code = getattr(couchers_context._grpc_context.code(), "name", None)  # type: ignore[union-attr]
                    traceback = "".join(format_exception(type(e), e, e.__traceback__))
                    _store_log(
                        method=method,
                        status_code=code,
                        duration=duration,
                        user_id=couchers_context._user_id,
                        is_api_key=cast(bool, couchers_context._is_api_key),
                        request=req,
                        response=None,
                        traceback=traceback,
                        perf_report=None,
                        ip_address=ip_address,
                        user_agent=user_agent,
                    )
                    observe_in_servicer_duration_histogram(
                        method, couchers_context._user_id, code or "", type(e).__name__, duration / 1000
                    )

                    # no status code means an unexpected error rather than a deliberate abort
                    if not code:
                        sentry_sdk.set_tag("context", "servicer")
                        sentry_sdk.set_tag("method", method)
                        sentry_sdk.capture_exception(e)

                    raise e

            # cookie sync only applies to cookie-authenticated (not API key) sessions
            if user_id and not is_api_key:
                # Sanity check. If user_id is present, then we should have a token.
                if token is None or token_expiry is None:
                    raise RuntimeError(f"{token=}, {token_expiry=}")

                # check the two cookies are in sync & that language preference cookie is correct
                if parse_user_id_cookie(headers) != str(user_id):
                    couchers_context.set_cookies(create_session_cookies(token, user_id, token_expiry))
                if ui_language_preference and ui_language_preference != parse_ui_lang_cookie(headers):
                    couchers_context.set_cookies(create_lang_cookie(ui_language_preference))

            # the client may have gone away while the servicer ran
            if not grpc_context.is_active():
                grpc_context.abort(grpc.StatusCode.INTERNAL, "Call cancelled.")

            couchers_context._send_cookies()

            return res

        return grpc.unary_unary_rpc_method_handler(
            function_without_couchers_stuff,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

377 

378 

379class MediaInterceptor(grpc.ServerInterceptor): 

380 """ 

381 Extracts an "Authorization: Bearer <hex>" header and calls the 

382 is_authorized function. Terminates the call with an HTTP error 

383 code if not authorized. 

384 

385 Also adds a session to called APIs. 

386 """ 

387 

388 def __init__(self, is_authorized: Callable[[str], bool]): 

389 self._is_authorized = is_authorized 

390 

391 def intercept_service[T, R]( 

392 self, 

393 continuation: Cont[T, R], 

394 handler_call_details: grpc.HandlerCallDetails, 

395 ) -> "grpc.RpcMethodHandler[T, R | Never]": 

396 handler = continuation(handler_call_details) 

397 if not handler: 

398 raise RuntimeError("No handler") 

399 

400 prev_func = handler.unary_unary 

401 if not prev_func: 

402 raise RuntimeError(f"No prev_function, {handler}") 

403 

404 metadata = dict(handler_call_details.invocation_metadata) 

405 

406 token = parse_api_key(metadata) 

407 

408 if not token or not self._is_authorized(token): 

409 return unauthenticated_handler() 

410 

411 def function_without_session(request: T, grpc_context: grpc.ServicerContext) -> R: 

412 with session_scope() as session: 

413 return prev_func(request, make_media_context(grpc_context), session) # type: ignore[call-arg, arg-type] 

414 

415 return grpc.unary_unary_rpc_method_handler( 

416 function_without_session, 

417 request_deserializer=handler.request_deserializer, 

418 response_serializer=handler.response_serializer, 

419 ) 

420 

421 

422class OTelInterceptor(grpc.ServerInterceptor): 

423 """ 

424 OpenTelemetry tracing 

425 """ 

426 

427 def __init__(self) -> None: 

428 self.tracer = trace.get_tracer(__name__) 

429 

430 def intercept_service[T, R]( 

431 self, 

432 continuation: Cont[T, R], 

433 handler_call_details: grpc.HandlerCallDetails, 

434 ) -> "grpc.RpcMethodHandler[T, R | Never]": 

435 handler = continuation(handler_call_details) 

436 if not handler: 

437 raise RuntimeError("No handler") 

438 

439 prev_func = handler.unary_unary 

440 if not prev_func: 

441 raise RuntimeError(f"No prev_function, {handler}") 

442 

443 method = handler_call_details.method 

444 

445 # method is of the form "/org.couchers.api.core.API/GetUser" 

446 _, service_name, method_name = method.split("/") 

447 

448 headers = dict(handler_call_details.invocation_metadata) 

449 

450 def tracing_function(request: T, context: grpc.ServicerContext) -> R: 

451 with self.tracer.start_as_current_span("handler") as rollspan: 

452 rollspan.set_attribute("rpc.method_full", method) 

453 rollspan.set_attribute("rpc.service", service_name) 

454 rollspan.set_attribute("rpc.method", method_name) 

455 

456 rollspan.set_attribute("rpc.thread", get_ident()) 

457 rollspan.set_attribute("rpc.pid", getpid()) 

458 

459 res = prev_func(request, context) 

460 

461 rollspan.set_attribute("web.user_agent", headers.get("user-agent") or "") 

462 rollspan.set_attribute("web.ip_address", headers.get("x-couchers-real-ip") or "") 

463 

464 return res 

465 

466 return grpc.unary_unary_rpc_method_handler( 

467 tracing_function, 

468 request_deserializer=handler.request_deserializer, 

469 response_serializer=handler.response_serializer, 

470 ) 

471 

472 

473class ErrorSanitizationInterceptor(grpc.ServerInterceptor): 

474 """ 

475 If the call resulted in a non-gRPC error, this strips away the error details. 

476 

477 It's important to put this first, so that it does not interfere with other interceptors. 

478 """ 

479 

480 def intercept_service[T, R]( 

481 self, 

482 continuation: Cont[T, R], 

483 handler_call_details: grpc.HandlerCallDetails, 

484 ) -> "grpc.RpcMethodHandler[T, R | Never]": 

485 handler = continuation(handler_call_details) 

486 if not handler: 

487 raise RuntimeError("No handler") 

488 

489 prev_func = handler.unary_unary 

490 if not prev_func: 

491 raise RuntimeError(f"No prev_function, {handler}") 

492 

493 def sanitizing_function(req: T, context: grpc.ServicerContext) -> R: 

494 try: 

495 res = prev_func(req, context) 

496 except Exception as e: 

497 code = context.code() # type: ignore[attr-defined] 

498 # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None 

499 if not code: 

500 logger.exception(e) 

501 logger.info("Probably an unknown error! Sanitizing...") 

502 context.abort(grpc.StatusCode.INTERNAL, UNKNOWN_ERROR_MESSAGE) 

503 else: 

504 logger.warning(f"RPC error: {code} in method {handler_call_details.method}") 

505 raise e 

506 return res 

507 

508 return grpc.unary_unary_rpc_method_handler( 

509 sanitizing_function, 

510 request_deserializer=handler.request_deserializer, 

511 response_serializer=handler.response_serializer, 

512 )