Coverage for src / couchers / interceptors.py: 84%
225 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-02 11:17 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-02 11:17 +0000
1import logging
2from collections.abc import Callable
3from copy import deepcopy
4from dataclasses import dataclass
5from datetime import datetime, timedelta
6from os import getpid
7from threading import get_ident
8from time import perf_counter_ns
9from traceback import format_exception
10from typing import Any, NoReturn, cast
12import grpc
13import sentry_sdk
14from google.protobuf.descriptor import ServiceDescriptor
15from google.protobuf.message import Message
16from opentelemetry import trace
17from sqlalchemy import Function, select
18from sqlalchemy.sql import and_, func
20from couchers.constants import (
21 CALL_CANCELLED_ERROR_MESSAGE,
22 COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE,
23 MISSING_AUTH_LEVEL_ERROR_MESSAGE,
24 NONEXISTENT_API_CALL_ERROR_MESSAGE,
25 PERMISSION_DENIED_ERROR_MESSAGE,
26 UNAUTHORIZED_ERROR_MESSAGE,
27 UNKNOWN_ERROR_MESSAGE,
28)
29from couchers.context import CouchersContext, make_interactive_context, make_media_context
30from couchers.db import session_scope
31from couchers.descriptor_pool import get_descriptor_pool
32from couchers.metrics import observe_in_servicer_duration_histogram
33from couchers.models import APICall, User, UserActivity, UserSession
34from couchers.proto import annotations_pb2
35from couchers.utils import (
36 create_lang_cookie,
37 create_session_cookies,
38 now,
39 parse_api_key,
40 parse_session_cookie,
41 parse_ui_lang_cookie,
42 parse_user_id_cookie,
43)
45logger = logging.getLogger(__name__)
48@dataclass(frozen=True, slots=True)
49class UserAuthInfo:
50 """
51 Information about an authenticated user session.
53 Returned by _try_get_and_update_user_details when a valid session is found.
54 """
56 user_id: int
57 is_jailed: bool
58 is_editor: bool
59 is_superuser: bool
60 token_expiry: datetime
61 ui_language_preference: str | None
def _binned_now() -> Function[Any]:
    """
    Build the SQL expression date_bin('1 hour', now(), '2000-01-01').

    Used as the period value when matching and creating UserActivity rows.
    """
    bin_width = "1 hour"
    bin_origin = "2000-01-01"
    return func.date_bin(bin_width, func.now(), bin_origin)
def _try_get_and_update_user_details(
    token: str | None, is_api_key: bool, ip_address: str | None, user_agent: str | None
) -> UserAuthInfo | None:
    """
    Resolve a token to its user and session, recording the activity as a side effect.

    When a visible user with a valid session of the requested kind (API key or not) is found, this:
    - refreshes the user's last active time if it is more than 5 minutes old,
    - updates the session's last seen time and bumps its API call count,
    - bumps (or creates) the UserActivity row for the current period, IP address and user agent,
    - commits.

    Returns UserAuthInfo for the session, or None if there is no token or no matching session.
    """
    if not token:
        return None

    with session_scope() as session:
        # the activity row (if any) for this user in the current period from the same IP and user agent
        same_activity = and_(
            UserActivity.user_id == User.id,
            UserActivity.period == _binned_now(),
            UserActivity.ip_address == ip_address,
            UserActivity.user_agent == user_agent,
        )
        query = (
            select(User, UserSession, UserActivity)
            .join(User, User.id == UserSession.user_id)
            .outerjoin(UserActivity, same_activity)
            .where(User.is_visible)
            .where(UserSession.token == token)
            .where(UserSession.is_valid)
            .where(UserSession.is_api_key == is_api_key)
        )
        row = session.execute(query).one_or_none()

        if not row:
            return None

        user, user_session, user_activity = row

        # only touch the user row when the stored value is stale, to avoid a write on every call
        stale_after = timedelta(minutes=5)
        if now() - user.last_active > stale_after:
            user.last_active = func.now()

        # the session itself is updated on every call
        user_session.last_seen = func.now()
        user_session.api_calls += 1

        if not user_activity:
            # first call in this period from this IP and user agent
            session.add(
                UserActivity(
                    user_id=user.id,
                    period=_binned_now(),
                    ip_address=ip_address,
                    user_agent=user_agent,
                    api_calls=1,
                )
            )
        else:
            user_activity.api_calls += 1

        session.commit()

        return UserAuthInfo(
            user_id=user.id,
            is_jailed=user.is_jailed,
            is_editor=user.is_editor,
            is_superuser=user.is_superuser,
            token_expiry=user_session.expiry,
            ui_language_preference=user.ui_language_preference,
        )
138def abort_handler[T, R](
139 message: str,
140 status_code: grpc.StatusCode,
141) -> grpc.RpcMethodHandler[T, R]:
142 def f(request: Any, context: CouchersContext) -> NoReturn:
143 context.abort(status_code, message)
145 return grpc.unary_unary_rpc_method_handler(f)
148def unauthenticated_handler[T, R](
149 message: str = UNAUTHORIZED_ERROR_MESSAGE,
150 status_code: grpc.StatusCode = grpc.StatusCode.UNAUTHENTICATED,
151) -> grpc.RpcMethodHandler[T, R]:
152 return abort_handler(message, status_code)
def _sanitized_bytes(proto: Message | None) -> bytes | None:
    """
    Remove fields marked sensitive and return serialized bytes

    The given message is left untouched: sanitization is done on a deep copy. Singular and repeated
    submessages, as well as message-typed map values, are sanitized recursively.

    Returns None if no message was given.
    """
    if not proto:
        return None

    new_proto = deepcopy(proto)

    def _sanitize_message(message: Message) -> None:
        for name, descriptor in message.DESCRIPTOR.fields_by_name.items():
            if descriptor.GetOptions().Extensions[annotations_pb2.sensitive]:
                message.ClearField(name)
                # the field is now empty, there is nothing left in it to sanitize
                continue

            if not descriptor.message_type:
                continue

            submessage = getattr(message, name)
            if not submessage:
                continue

            if descriptor.message_type.GetOptions().map_entry:
                # Map fields are repeated fields of a synthesized entry message, but the Python container
                # iterates over *keys*, which are never messages. Only the values can need sanitizing,
                # and only when they are messages themselves.
                value_descriptor = descriptor.message_type.fields_by_name["value"]
                if value_descriptor.message_type:
                    for msg in submessage.values():
                        _sanitize_message(msg)
            elif descriptor.is_repeated:
                for msg in submessage:
                    _sanitize_message(msg)
            else:
                _sanitize_message(submessage)

    _sanitize_message(new_proto)

    return new_proto.SerializeToString()
def _store_log(
    *,
    method: str,
    status_code: grpc.StatusCode | None,
    duration: float,
    user_id: int | None,
    is_api_key: bool,
    request: Message,
    response: Message | None,
    traceback: str | None,
    perf_report: str | None,
    ip_address: str | None,
    user_agent: str | None,
) -> None:
    """
    Persist one APICall row describing a finished call and emit a debug log line.

    Request and response are stored as sanitized serialized bytes (see _sanitized_bytes). A serialized
    response longer than 16 kB is cut to that length and the row is flagged as truncated.
    """
    req_bytes = _sanitized_bytes(request)
    res_bytes = _sanitized_bytes(response)

    max_response_bytes = 16 * 1024  # 16 kB
    response_truncated = False
    if res_bytes and len(res_bytes) > max_response_bytes:
        res_bytes, response_truncated = res_bytes[:max_response_bytes], True

    api_call = APICall(
        is_api_key=is_api_key,
        method=method,
        status_code=status_code,
        duration=duration,
        user_id=user_id,
        request=req_bytes,
        response=res_bytes,
        response_truncated=response_truncated,
        traceback=traceback,
        perf_report=perf_report,
        ip_address=ip_address,
        user_agent=user_agent,
    )

    with session_scope() as session:
        session.add(api_call)

    logger.debug(f"{user_id=}, {method=}, {duration=} ms")
# The continuation passed to an interceptor's intercept_service: given the call details it returns the next
# handler in the chain, or None.
type Cont[T, R] = Callable[[grpc.HandlerCallDetails], grpc.RpcMethodHandler[T, R] | None]
227class CouchersMiddlewareInterceptor(grpc.ServerInterceptor):
228 """
229 1. Does auth: extracts a session token from a cookie, and authenticates a user with that.
231 Sets context.user_id and context.token if authenticated, otherwise
232 terminates the call with an UNAUTHENTICATED error code.
234 2. Makes sure cookies are in sync.
236 3. Injects a session to get a database transaction.
238 4. Measures and logs the time it takes to service each incoming call.
239 """
241 def __init__(self) -> None:
242 self._pool = get_descriptor_pool()
244 def intercept_service[T = Message, R = Message](
245 self,
246 continuation: Cont[T, R],
247 handler_call_details: grpc.HandlerCallDetails,
248 ) -> grpc.RpcMethodHandler[T, R]:
249 start = perf_counter_ns()
251 method = handler_call_details.method
252 # method is of the form "/org.couchers.api.core.API/GetUser"
253 _, service_name, method_name = method.split("/")
255 try:
256 service: ServiceDescriptor = self._pool.FindServiceByName(service_name) # type: ignore[no-untyped-call]
257 service_options = service.GetOptions()
258 except KeyError:
259 return abort_handler(NONEXISTENT_API_CALL_ERROR_MESSAGE, grpc.StatusCode.UNIMPLEMENTED)
261 auth_level = service_options.Extensions[annotations_pb2.auth_level]
263 # if unknown auth level, then it wasn't set and something's wrong
264 if auth_level == annotations_pb2.AUTH_LEVEL_UNKNOWN:
265 return abort_handler(MISSING_AUTH_LEVEL_ERROR_MESSAGE, grpc.StatusCode.INTERNAL)
267 assert auth_level in [
268 annotations_pb2.AUTH_LEVEL_OPEN,
269 annotations_pb2.AUTH_LEVEL_JAILED,
270 annotations_pb2.AUTH_LEVEL_SECURE,
271 annotations_pb2.AUTH_LEVEL_EDITOR,
272 annotations_pb2.AUTH_LEVEL_ADMIN,
273 ]
275 headers = dict(handler_call_details.invocation_metadata)
277 if "cookie" in headers and "authorization" in headers:
278 # for security reasons, only one of "cookie" or "authorization" can be present
279 return unauthenticated_handler(COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE)
280 elif "cookie" in headers:
281 # the session token is passed in cookies, i.e. in the `cookie` header
282 token, is_api_key = parse_session_cookie(headers), False
283 elif "authorization" in headers:
284 # the session token is passed in the `authorization` header
285 token, is_api_key = parse_api_key(headers), True
286 else:
287 # no session found
288 token, is_api_key = None, False
290 ip_address = cast(str | None, headers.get("x-couchers-real-ip"))
291 user_agent = cast(str | None, headers.get("user-agent"))
293 auth_info = _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent)
295 if not auth_info:
296 # Invalid or no session - clear credentials
297 token = None
298 is_api_key = False
300 # if this isn't an open service, fail
301 if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
302 return unauthenticated_handler(UNAUTHORIZED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED)
303 else:
304 # a valid user session was found - check permissions
305 if auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not auth_info.is_superuser:
306 return unauthenticated_handler(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.PERMISSION_DENIED)
308 if auth_level == annotations_pb2.AUTH_LEVEL_EDITOR and not auth_info.is_editor:
309 return unauthenticated_handler(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.PERMISSION_DENIED)
311 # if the user is jailed and this is isn't an open or jailed service, fail
312 if auth_info.is_jailed and auth_level not in [
313 annotations_pb2.AUTH_LEVEL_OPEN,
314 annotations_pb2.AUTH_LEVEL_JAILED,
315 ]:
316 return unauthenticated_handler(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED)
318 handler = continuation(handler_call_details)
319 if not handler: 319 ↛ 320line 319 didn't jump to line 320 because the condition on line 319 was never true
320 raise RuntimeError(f"No handler in '{method}'")
322 prev_function = handler.unary_unary
323 if not prev_function: 323 ↛ 324line 323 didn't jump to line 324 because the condition on line 323 was never true
324 raise RuntimeError(f"No prev_function in '{method}', {handler}")
326 def function_without_couchers_stuff(req: Message, grpc_context: grpc.ServicerContext) -> Message | None:
327 couchers_context: CouchersContext = make_interactive_context(
328 grpc_context=grpc_context,
329 user_id=auth_info.user_id if auth_info else None,
330 is_api_key=is_api_key,
331 token=token,
332 ui_language_preference=auth_info.ui_language_preference if auth_info else None,
333 )
334 with session_scope() as session:
335 try:
336 _res = prev_function(req, couchers_context, session) # type: ignore[call-arg, arg-type]
337 res = cast(Message, _res)
338 finished = perf_counter_ns()
339 duration = (finished - start) / 1e6 # ms
340 _store_log(
341 method=method,
342 status_code=None,
343 duration=duration,
344 user_id=couchers_context._user_id,
345 is_api_key=cast(bool, couchers_context._is_api_key),
346 request=req,
347 response=res,
348 traceback=None,
349 perf_report=None,
350 ip_address=ip_address,
351 user_agent=user_agent,
352 )
353 observe_in_servicer_duration_histogram(method, couchers_context._user_id, "", "", duration / 1000)
354 except Exception as e:
355 finished = perf_counter_ns()
356 duration = (finished - start) / 1e6 # ms
357 code = getattr(couchers_context._grpc_context.code(), "name", None) # type: ignore[union-attr]
358 traceback = "".join(format_exception(type(e), e, e.__traceback__))
359 _store_log(
360 method=method,
361 status_code=code,
362 duration=duration,
363 user_id=couchers_context._user_id,
364 is_api_key=cast(bool, couchers_context._is_api_key),
365 request=req,
366 response=None,
367 traceback=traceback,
368 perf_report=None,
369 ip_address=ip_address,
370 user_agent=user_agent,
371 )
372 observe_in_servicer_duration_histogram(
373 method, couchers_context._user_id, code or "", type(e).__name__, duration / 1000
374 )
376 if not code:
377 sentry_sdk.set_tag("context", "servicer")
378 sentry_sdk.set_tag("method", method)
379 sentry_sdk.capture_exception(e)
381 raise e
383 if auth_info and not is_api_key:
384 # Sanity check. If auth_info is present, then we should have a token.
385 if token is None: 385 ↛ 386line 385 didn't jump to line 386 because the condition on line 385 was never true
386 raise RuntimeError(f"{token=}, {auth_info.token_expiry=}")
388 # check the two cookies are in sync & that language preference cookie is correct
389 if parse_user_id_cookie(headers) != str(auth_info.user_id): 389 ↛ 393line 389 didn't jump to line 393 because the condition on line 389 was always true
390 couchers_context.set_cookies(
391 create_session_cookies(token, auth_info.user_id, auth_info.token_expiry)
392 )
393 if auth_info.ui_language_preference and auth_info.ui_language_preference != parse_ui_lang_cookie(
394 headers
395 ):
396 couchers_context.set_cookies(create_lang_cookie(auth_info.ui_language_preference))
398 if not grpc_context.is_active(): 398 ↛ 399line 398 didn't jump to line 399 because the condition on line 398 was never true
399 grpc_context.abort(grpc.StatusCode.INTERNAL, CALL_CANCELLED_ERROR_MESSAGE)
401 couchers_context._send_cookies()
403 return res
405 return grpc.unary_unary_rpc_method_handler(
406 function_without_couchers_stuff,
407 request_deserializer=handler.request_deserializer,
408 response_serializer=handler.response_serializer,
409 )
412class MediaInterceptor(grpc.ServerInterceptor):
413 """
414 Extracts an "Authorization: Bearer <hex>" header and calls the
415 is_authorized function. Terminates the call with an HTTP error
416 code if not authorized.
418 Also adds a session to called APIs.
419 """
421 def __init__(self, is_authorized: Callable[[str], bool]):
422 self._is_authorized = is_authorized
424 def intercept_service[T, R](
425 self,
426 continuation: Cont[T, R],
427 handler_call_details: grpc.HandlerCallDetails,
428 ) -> grpc.RpcMethodHandler[T, R]:
429 handler = continuation(handler_call_details)
430 if not handler: 430 ↛ 431line 430 didn't jump to line 431 because the condition on line 430 was never true
431 raise RuntimeError("No handler")
433 prev_func = handler.unary_unary
434 if not prev_func: 434 ↛ 435line 434 didn't jump to line 435 because the condition on line 434 was never true
435 raise RuntimeError(f"No prev_function, {handler}")
437 metadata = dict(handler_call_details.invocation_metadata)
439 token = parse_api_key(metadata)
441 if not token or not self._is_authorized(token): 441 ↛ 442line 441 didn't jump to line 442 because the condition on line 441 was never true
442 return unauthenticated_handler()
444 def function_without_session(request: T, grpc_context: grpc.ServicerContext) -> R:
445 with session_scope() as session:
446 return prev_func(request, make_media_context(grpc_context), session) # type: ignore[call-arg, arg-type]
448 return grpc.unary_unary_rpc_method_handler(
449 function_without_session,
450 request_deserializer=handler.request_deserializer,
451 response_serializer=handler.response_serializer,
452 )
455class OTelInterceptor(grpc.ServerInterceptor):
456 """
457 OpenTelemetry tracing
458 """
460 def __init__(self) -> None:
461 self.tracer = trace.get_tracer(__name__)
463 def intercept_service[T, R](
464 self,
465 continuation: Cont[T, R],
466 handler_call_details: grpc.HandlerCallDetails,
467 ) -> grpc.RpcMethodHandler[T, R]:
468 handler = continuation(handler_call_details)
469 if not handler:
470 raise RuntimeError("No handler")
472 prev_func = handler.unary_unary
473 if not prev_func:
474 raise RuntimeError(f"No prev_function, {handler}")
476 method = handler_call_details.method
478 # method is of the form "/org.couchers.api.core.API/GetUser"
479 _, service_name, method_name = method.split("/")
481 headers = dict(handler_call_details.invocation_metadata)
483 def tracing_function(request: T, context: grpc.ServicerContext) -> R:
484 with self.tracer.start_as_current_span("handler") as rollspan:
485 rollspan.set_attribute("rpc.method_full", method)
486 rollspan.set_attribute("rpc.service", service_name)
487 rollspan.set_attribute("rpc.method", method_name)
489 rollspan.set_attribute("rpc.thread", get_ident())
490 rollspan.set_attribute("rpc.pid", getpid())
492 res = prev_func(request, context)
494 rollspan.set_attribute("web.user_agent", headers.get("user-agent") or "")
495 rollspan.set_attribute("web.ip_address", headers.get("x-couchers-real-ip") or "")
497 return res
499 return grpc.unary_unary_rpc_method_handler(
500 tracing_function,
501 request_deserializer=handler.request_deserializer,
502 response_serializer=handler.response_serializer,
503 )
506class ErrorSanitizationInterceptor(grpc.ServerInterceptor):
507 """
508 If the call resulted in a non-gRPC error, this strips away the error details.
510 It's important to put this first, so that it does not interfere with other interceptors.
511 """
513 def intercept_service[T, R](
514 self,
515 continuation: Cont[T, R],
516 handler_call_details: grpc.HandlerCallDetails,
517 ) -> grpc.RpcMethodHandler[T, R]:
518 handler = continuation(handler_call_details)
519 if not handler: 519 ↛ 520line 519 didn't jump to line 520 because the condition on line 519 was never true
520 raise RuntimeError("No handler")
522 prev_func = handler.unary_unary
523 if not prev_func: 523 ↛ 524line 523 didn't jump to line 524 because the condition on line 523 was never true
524 raise RuntimeError(f"No prev_function, {handler}")
526 def sanitizing_function(req: T, context: grpc.ServicerContext) -> R:
527 try:
528 res = prev_func(req, context)
529 except Exception as e:
530 code = context.code() # type: ignore[attr-defined]
531 # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None
532 if not code:
533 logger.exception(e)
534 logger.info("Probably an unknown error! Sanitizing...")
535 context.abort(grpc.StatusCode.INTERNAL, UNKNOWN_ERROR_MESSAGE)
536 else:
537 logger.warning(f"RPC error: {code} in method {handler_call_details.method}")
538 raise e
539 return res
541 return grpc.unary_unary_rpc_method_handler(
542 sanitizing_function,
543 request_deserializer=handler.request_deserializer,
544 response_serializer=handler.response_serializer,
545 )