Coverage for app / backend / src / couchers / interceptors.py: 86%
266 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-19 14:14 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-19 14:14 +0000
1import logging
2from collections.abc import Callable, Mapping
3from copy import deepcopy
4from dataclasses import dataclass, field
5from datetime import datetime, timedelta
6from os import getpid
7from threading import get_ident
8from time import perf_counter_ns
9from traceback import format_exception
10from typing import Any, NoReturn, cast, overload
11from zoneinfo import ZoneInfo
13import grpc
14import sentry_sdk
15from google.protobuf.descriptor import ServiceDescriptor
16from google.protobuf.descriptor_pool import DescriptorPool
17from google.protobuf.message import Message
18from opentelemetry import trace
19from sqlalchemy import Function, select
20from sqlalchemy.sql import and_, func
22from couchers.constants import (
23 CALL_CANCELLED_ERROR_MESSAGE,
24 COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE,
25 MISSING_AUTH_LEVEL_ERROR_MESSAGE,
26 NONEXISTENT_API_CALL_ERROR_MESSAGE,
27 PERMISSION_DENIED_ERROR_MESSAGE,
28 UNAUTHORIZED_ERROR_MESSAGE,
29 UNKNOWN_ERROR_MESSAGE,
30)
31from couchers.context import CouchersContext, make_interactive_context, make_media_context
32from couchers.db import session_scope
33from couchers.descriptor_pool import get_descriptor_pool
34from couchers.i18n import LocalizationContext
35from couchers.i18n.locales import DEFAULT_LOCALE
36from couchers.metrics import observe_in_servicer_duration_histogram
37from couchers.models import APICall, User, UserActivity, UserSession
38from couchers.proto import annotations_pb2
39from couchers.proto.annotations_pb2 import AuthLevel
40from couchers.utils import (
41 create_lang_cookie,
42 create_session_cookies,
43 generate_sofa_cookie,
44 now,
45 parse_api_key,
46 parse_session_cookie,
47 parse_sofa_cookie,
48 parse_ui_lang_cookie,
49 parse_user_id_cookie,
50)
52logger = logging.getLogger(__name__)
@dataclass(frozen=True, slots=True, kw_only=True)
class UserAuthInfo:
    """Information about an authenticated user session.

    Built by ``_try_get_and_update_user_details`` from the matching ``User`` and
    ``UserSession`` rows. Instances are immutable (``frozen=True``) and must be
    constructed with keyword arguments (``kw_only=True``).
    """

    # User.id of the authenticated user
    user_id: int
    # User.is_jailed; check_permissions only lets jailed users reach OPEN/JAILED services
    is_jailed: bool
    # User.is_editor; required for AUTH_LEVEL_EDITOR services
    is_editor: bool
    # User.is_superuser; required for AUTH_LEVEL_ADMIN services
    is_superuser: bool
    # UserSession.expiry; used as the expiry when session cookies are re-issued
    token_expiry: datetime
    # User.ui_language_preference; preferred over the ui-lang cookie when choosing the locale
    ui_language_preference: str | None
    # User.timezone (an IANA zone name passed to ZoneInfo); None falls back to "Etc/UTC"
    timezone: str | None
    # the session token or API key; excluded from repr so it does not leak into logs
    token: str = field(repr=False)
    # True if the token came from the `authorization` header (API key), False for a session cookie
    is_api_key: bool
def _binned_now() -> Function[Any]:
    """
    Return a SQL expression for the current database time, truncated to its 1-hour bin.

    Bins are aligned to 2000-01-01 via the database's ``date_bin`` function. This is used as
    the ``UserActivity.period`` bucket key.
    """
    bin_width = "1 hour"
    bin_origin = "2000-01-01"
    return func.date_bin(bin_width, func.now(), bin_origin)
def _try_get_and_update_user_details(
    token: str | None, is_api_key: bool, ip_address: str | None, user_agent: str | None
) -> UserAuthInfo | None:
    """
    Tries to get session and user info corresponding to this token.

    Only sessions that pass ``UserSession.is_valid``, belong to a user passing ``User.is_visible``,
    and whose ``is_api_key`` flag equals ``is_api_key`` are accepted.

    Also updates the user's last active time (at most every 5 minutes), the token's last seen
    time, and increments API call counts. Counts are kept both on the session and in an hourly
    ``UserActivity`` bucket keyed by user, IP address and user agent.

    Returns UserAuthInfo if a valid session is found, None otherwise (including when ``token``
    is empty or None).
    """
    if not token:
        return None

    with session_scope() as session:
        result = session.execute(
            select(User, UserSession, UserActivity)
            .select_from(UserSession)
            .join(User, User.id == UserSession.user_id)
            .outerjoin(
                UserActivity,
                and_(
                    UserActivity.user_id == User.id,
                    UserActivity.period == _binned_now(),
                    UserActivity.ip_address == ip_address,
                    UserActivity.user_agent == user_agent,
                ),
            )
            .where(User.is_visible)
            .where(UserSession.token == token)
            .where(UserSession.is_valid)
            .where(UserSession.is_api_key == is_api_key)
        ).one_or_none()

        if not result:
            return None

        user, user_session, user_activity = result._tuple()

        # update user last active time if it's been a while
        if now() - user.last_active > timedelta(minutes=5):
            user.last_active = func.now()

        # let's update the token
        user_session.last_seen = func.now()
        # Increment in SQL ("SET api_calls = api_calls + 1") rather than read-modify-write in
        # Python, so concurrent requests on the same session don't overwrite each other's counts.
        user_session.api_calls = UserSession.api_calls + 1

        if user_activity:
            user_activity.api_calls = UserActivity.api_calls + 1
        else:
            # NOTE(review): two concurrent first requests in the same hour bucket can both reach
            # this insert; confirm how UserActivity's constraints handle that.
            session.add(
                UserActivity(
                    user_id=user.id,
                    period=_binned_now(),
                    ip_address=ip_address,
                    user_agent=user_agent,
                    api_calls=1,
                )
            )

        session.commit()

        return UserAuthInfo(
            user_id=user.id,
            is_jailed=user.is_jailed,
            is_editor=user.is_editor,
            is_superuser=user.is_superuser,
            token_expiry=user_session.expiry,
            ui_language_preference=user.ui_language_preference,
            timezone=user.timezone,
            token=token,
            is_api_key=is_api_key,
        )
148def abort_handler[T, R](
149 message: str,
150 status_code: grpc.StatusCode,
151) -> grpc.RpcMethodHandler[T, R]:
152 def f(request: Any, context: CouchersContext) -> NoReturn:
153 context.abort(status_code, message)
155 return grpc.unary_unary_rpc_method_handler(f)
158def unauthenticated_handler[T, R](
159 message: str = UNAUTHORIZED_ERROR_MESSAGE,
160 status_code: grpc.StatusCode = grpc.StatusCode.UNAUTHENTICATED,
161) -> grpc.RpcMethodHandler[T, R]:
162 return abort_handler(message, status_code)
@overload
def _sanitized_bytes(proto: Message) -> bytes: ...
@overload
def _sanitized_bytes(proto: None) -> None: ...
def _sanitized_bytes(proto: Message | None) -> bytes | None:
    """
    Remove fields marked sensitive and return serialized bytes.

    Works on a deep copy, so ``proto`` itself is never modified. Every set field whose options
    carry the ``annotations_pb2.sensitive`` extension is cleared. The walk recurses through set
    singular sub-messages, repeated message fields, and message-valued map fields.

    Returns None when ``proto`` is None.
    """
    if proto is None:
        return None

    new_proto = deepcopy(proto)

    def _sanitize_message(message: Message) -> None:
        # ListFields() yields only fields that are actually set, so unset sub-messages are never
        # touched. In the Python protobuf runtime, mutating a default sub-message (even via
        # ClearField) can mark it present in its parent, which would change the output.
        for descriptor, value in message.ListFields():
            if descriptor.is_extension:
                # extensions aren't declared fields of this message type; leave them alone
                continue
            if descriptor.GetOptions().Extensions[annotations_pb2.sensitive]:
                message.ClearField(descriptor.name)
                continue
            field_type = descriptor.message_type
            if field_type is None:
                # scalar field: nothing to recurse into
                continue
            if field_type.GetOptions().map_entry:
                # map<K, V>: iterating the container yields keys, so recurse into the values,
                # and only when V is itself a message type
                if field_type.fields_by_name["value"].message_type is not None:
                    for entry_value in value.values():
                        _sanitize_message(entry_value)
            elif descriptor.is_repeated:
                for msg in value:
                    _sanitize_message(msg)
            else:
                _sanitize_message(value)

    _sanitize_message(new_proto)

    return new_proto.SerializeToString()
def _store_log(
    *,
    method: str,
    status_code: str | None = None,
    duration: float,
    user_id: int | None,
    is_api_key: bool,
    request: Message,
    response: Message | None,
    traceback: str | None = None,
    perf_report: str | None = None,
    ip_address: str | None,
    user_agent: str | None,
    sofa: str | None,
) -> None:
    """
    Persist an ``APICall`` row describing one servicer invocation.

    Request and response are stripped of sensitive fields (see ``_sanitized_bytes``) before
    storage. The serialized response is cut to 16 kB, and ``response_truncated`` records whether
    that happened.

    Args:
        duration: wall-clock time of the call in milliseconds.
        status_code: gRPC status name for failed calls, None on success or if unknown.
        traceback: formatted exception for failed calls.
    """
    req_bytes = _sanitized_bytes(request)
    res_bytes = _sanitized_bytes(response)

    # cap stored response size so large payloads don't bloat the API call log
    truncate_res_bytes_length = 16 * 1024  # 16 kB
    response_truncated = False
    if res_bytes is not None and len(res_bytes) > truncate_res_bytes_length:
        res_bytes = res_bytes[:truncate_res_bytes_length]
        response_truncated = True

    with session_scope() as session:
        session.add(
            APICall(
                is_api_key=is_api_key,
                method=method,
                status_code=status_code,
                duration=duration,
                user_id=user_id,
                request=req_bytes,
                response=res_bytes,
                response_truncated=response_truncated,
                traceback=traceback,
                perf_report=perf_report,
                ip_address=ip_address,
                user_agent=user_agent,
                sofa=sofa,
            )
        )

    # lazy %-args: this runs on every call, so skip formatting unless debug logging is enabled
    # (%r matches the repr-based output of the former f"{x=}" form)
    logger.debug("user_id=%r, method=%r, duration=%r ms", user_id, method, duration)
# Type of the `continuation` callable that gRPC passes to ServerInterceptor.intercept_service:
# given the call details it returns the next handler in the chain, or None when there is none
# (the interceptors below raise RuntimeError("No handler") in that case).
type Cont[T, R] = Callable[[grpc.HandlerCallDetails], grpc.RpcMethodHandler[T, R] | None]
243class CouchersMiddlewareInterceptor(grpc.ServerInterceptor):
244 """
245 1. Does auth: extracts a session token from a cookie, and authenticates a user with that.
247 Sets context.user_id and context.token if authenticated, otherwise
248 terminates the call with an UNAUTHENTICATED error code.
250 2. Makes sure cookies are in sync.
252 3. Injects a session to get a database transaction.
254 4. Measures and logs the time it takes to service each incoming call.
255 """
257 def __init__(self) -> None:
258 self._pool = get_descriptor_pool()
260 def intercept_service[T = Message, R = Message](
261 self,
262 continuation: Cont[T, R],
263 handler_call_details: grpc.HandlerCallDetails,
264 ) -> grpc.RpcMethodHandler[T, R]:
265 start = perf_counter_ns()
267 method = handler_call_details.method
269 try:
270 auth_level = find_auth_level(self._pool, method)
271 except AbortError as ae:
272 return abort_handler(ae.msg, ae.code)
274 try:
275 headers = parse_headers(dict(handler_call_details.invocation_metadata))
276 except BadHeaders:
277 return unauthenticated_handler(COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE)
279 auth_info = _try_get_and_update_user_details(
280 headers.token, headers.is_api_key, headers.ip_address, headers.user_agent
281 )
283 try:
284 check_permissions(auth_info, auth_level)
285 except AbortError as ae:
286 return unauthenticated_handler(ae.msg, ae.code)
288 if not (handler := continuation(handler_call_details)): 288 ↛ 289line 288 didn't jump to line 289 because the condition on line 288 was never true
289 raise RuntimeError(f"No handler in '{method}'")
291 if not (prev_function := handler.unary_unary): 291 ↛ 292line 291 didn't jump to line 292 because the condition on line 291 was never true
292 raise RuntimeError(f"No prev_function in '{method}', {handler}")
294 if headers.sofa:
295 sofa = headers.sofa
296 new_sofa_cookie = None
297 else:
298 sofa, new_sofa_cookie = generate_sofa_cookie()
300 loc_context = LocalizationContext(
301 locale=(auth_info.ui_language_preference if auth_info else headers.ui_lang) or DEFAULT_LOCALE,
302 timezone=ZoneInfo((auth_info and auth_info.timezone) or "Etc/UTC"),
303 )
305 def function_without_couchers_stuff(req: Message, grpc_context: grpc.ServicerContext) -> Message | None:
306 couchers_context = make_interactive_context(
307 grpc_context=grpc_context,
308 user_id=auth_info.user_id if auth_info else None,
309 is_api_key=auth_info.is_api_key if auth_info else False,
310 token=auth_info.token if auth_info else None,
311 localization=loc_context,
312 sofa=sofa,
313 )
315 with session_scope() as session:
316 try:
317 _res = prev_function(req, couchers_context, session) # type: ignore[call-arg, arg-type]
318 res = cast(Message, _res)
319 finished = perf_counter_ns()
320 duration = (finished - start) / 1e6 # ms
321 _store_log(
322 method=method,
323 duration=duration,
324 user_id=couchers_context._user_id,
325 is_api_key=cast(bool, couchers_context._is_api_key),
326 request=req,
327 response=res,
328 ip_address=headers.ip_address,
329 user_agent=headers.user_agent,
330 sofa=sofa,
331 )
332 observe_in_servicer_duration_histogram(method, couchers_context._user_id, "", "", duration / 1000)
333 except Exception as e:
334 finished = perf_counter_ns()
335 duration = (finished - start) / 1e6 # ms
337 if couchers_context._grpc_context: 337 ↛ 341line 337 didn't jump to line 341 because the condition on line 337 was always true
338 context_code = couchers_context._grpc_context.code() # type: ignore[attr-defined]
339 code = getattr(context_code, "name", None)
340 else:
341 code = None
343 traceback = "".join(format_exception(type(e), e, e.__traceback__))
344 _store_log(
345 method=method,
346 status_code=code,
347 duration=duration,
348 user_id=couchers_context._user_id,
349 is_api_key=cast(bool, couchers_context._is_api_key),
350 request=req,
351 response=None,
352 traceback=traceback,
353 ip_address=headers.ip_address,
354 user_agent=headers.user_agent,
355 sofa=sofa,
356 )
357 observe_in_servicer_duration_histogram(
358 method, couchers_context._user_id, code or "", type(e).__name__, duration / 1000
359 )
361 if not code:
362 sentry_sdk.set_tag("context", "servicer")
363 sentry_sdk.set_tag("method", method)
364 sentry_sdk.capture_exception(e)
366 raise e
368 if auth_info and not auth_info.is_api_key:
369 # check the two cookies are in sync & that language preference cookie is correct
370 if headers.user_id != str(auth_info.user_id): 370 ↛ 374line 370 didn't jump to line 374 because the condition on line 370 was always true
371 couchers_context.set_cookies(
372 create_session_cookies(auth_info.token, auth_info.user_id, auth_info.token_expiry)
373 )
374 if auth_info.ui_language_preference and auth_info.ui_language_preference != headers.ui_lang:
375 couchers_context.set_cookies(create_lang_cookie(auth_info.ui_language_preference))
377 if new_sofa_cookie:
378 couchers_context.set_cookies([new_sofa_cookie])
380 if not grpc_context.is_active(): 380 ↛ 381line 380 didn't jump to line 381 because the condition on line 380 was never true
381 grpc_context.abort(grpc.StatusCode.INTERNAL, CALL_CANCELLED_ERROR_MESSAGE)
383 couchers_context._send_cookies()
385 return res
387 return grpc.unary_unary_rpc_method_handler(
388 function_without_couchers_stuff,
389 request_deserializer=handler.request_deserializer,
390 response_serializer=handler.response_serializer,
391 )
@dataclass(frozen=True, slots=True, kw_only=True)
class CouchersHeaders:
    """Values extracted from a call's invocation metadata by ``parse_headers``.

    Immutable (``frozen=True``) and constructed with keyword arguments only (``kw_only=True``).
    """

    # session token (from cookies) or API key (from `authorization`); None if neither was sent.
    # Excluded from repr so it does not leak into logs.
    token: str | None = field(repr=False)
    # True when the token came from the `authorization` header
    is_api_key: bool
    # value of the `x-couchers-real-ip` header, if present and a str
    ip_address: str | None
    # value of the `user-agent` header, if present and a str
    user_agent: str | None
    # UI language from the ui-lang cookie, as returned by parse_ui_lang_cookie
    ui_lang: str | None
    # user id cookie as a string; compared against the authenticated user's id to detect stale cookies
    user_id: str | None
    # sofa cookie value, as returned by parse_sofa_cookie; None means a new one is generated
    sofa: str | None
def parse_headers(headers: Mapping[str, str | bytes]) -> CouchersHeaders:
    """
    Extract auth and client information from call metadata.

    The credential is read from exactly one source: the session cookie (``cookie`` header) or
    an API key (``authorization`` header). If neither is present the token is None. The IP
    address and user agent are kept only when they are ``str`` values.

    Raises:
        BadHeaders: if both ``cookie`` and ``authorization`` are present.
    """
    has_cookie = "cookie" in headers
    has_authorization = "authorization" in headers

    # for security reasons, only one of "cookie" or "authorization" can be present
    if has_cookie and has_authorization:
        raise BadHeaders("Both cookies and authorization are present in headers")

    token = None
    is_api_key = False
    if has_cookie:
        token = parse_session_cookie(headers)
    elif has_authorization:
        token = parse_api_key(headers)
        is_api_key = True

    raw_ip = headers.get("x-couchers-real-ip")
    raw_user_agent = headers.get("user-agent")

    return CouchersHeaders(
        token=token,
        is_api_key=is_api_key,
        ip_address=raw_ip if isinstance(raw_ip, str) else None,
        user_agent=raw_user_agent if isinstance(raw_user_agent, str) else None,
        ui_lang=parse_ui_lang_cookie(headers),
        user_id=parse_user_id_cookie(headers),
        sofa=parse_sofa_cookie(headers),
    )
class BadHeaders(Exception):
    """Raised by ``parse_headers`` when call metadata is unacceptable (e.g. cookie and authorization both present)."""
class AbortError(Exception):
    """
    Signals that the current call should be aborted with a specific gRPC status.

    Attributes:
        msg: error message to pass to ``context.abort``.
        code: gRPC status code to abort with.
    """

    def __init__(self, msg: str, code: grpc.StatusCode):
        # Initialise the base Exception so str(err) is the message itself rather than the
        # repr of the raw (msg, code) args tuple, e.g. when the error is logged.
        super().__init__(msg)
        self.msg = msg
        self.code = code
def find_auth_level(pool: DescriptorPool, method: str) -> AuthLevel.ValueType:
    """
    Look up the ``auth_level`` option declared on the gRPC service that ``method`` belongs to.

    Args:
        pool: descriptor pool to resolve the service in.
        method: full gRPC method path, e.g. "/org.couchers.api.core.API/GetUser".

    Returns:
        The service's auth level, already checked by ``validate_auth_level``.

    Raises:
        AbortError: UNIMPLEMENTED if the path is malformed or names an unknown service; INTERNAL
            (from ``validate_auth_level``) if the level is missing or unrecognised.
    """
    # method is of the form "/org.couchers.api.core.API/GetUser". The path is supplied by the
    # caller, so reject anything else instead of letting tuple-unpacking raise ValueError.
    parts = method.split("/")
    if len(parts) != 3:
        raise AbortError(NONEXISTENT_API_CALL_ERROR_MESSAGE, grpc.StatusCode.UNIMPLEMENTED)
    service_name = parts[1]

    try:
        service: ServiceDescriptor = pool.FindServiceByName(service_name)  # type: ignore[no-untyped-call]
    except KeyError:
        raise AbortError(NONEXISTENT_API_CALL_ERROR_MESSAGE, grpc.StatusCode.UNIMPLEMENTED) from None

    level = service.GetOptions().Extensions[annotations_pb2.auth_level]

    validate_auth_level(level)

    return level
def validate_auth_level(auth_level: AuthLevel.ValueType) -> None:
    """
    Ensure a service declares one of the supported auth levels.

    Raises:
        AbortError: INTERNAL with MISSING_AUTH_LEVEL_ERROR_MESSAGE if the level is
            AUTH_LEVEL_UNKNOWN (the option was never set) or any other unrecognised value.
    """
    recognised_levels = (
        annotations_pb2.AUTH_LEVEL_OPEN,
        annotations_pb2.AUTH_LEVEL_JAILED,
        annotations_pb2.AUTH_LEVEL_SECURE,
        annotations_pb2.AUTH_LEVEL_EDITOR,
        annotations_pb2.AUTH_LEVEL_ADMIN,
    )
    # if unknown auth level, then it wasn't set and something's wrong
    level_unset = auth_level == annotations_pb2.AUTH_LEVEL_UNKNOWN
    if level_unset or auth_level not in recognised_levels:
        raise AbortError(MISSING_AUTH_LEVEL_ERROR_MESSAGE, grpc.StatusCode.INTERNAL)
def check_permissions(auth_info: UserAuthInfo | None, auth_level: AuthLevel.ValueType) -> None:
    """
    Decide whether the (possibly anonymous) caller may use a service with ``auth_level``.

    Raises:
        AbortError:
            - UNAUTHENTICATED for anonymous callers of non-OPEN services;
            - PERMISSION_DENIED for ADMIN services without superuser rights, or EDITOR
              services without editor rights;
            - UNAUTHENTICATED (with the permission-denied message) for jailed users calling
              anything other than OPEN or JAILED services.
    """
    if auth_info is None:
        # anonymous callers may only use open services
        if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
            raise AbortError(UNAUTHORIZED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED)
        return

    # a valid user session was found - check role-gated levels first
    lacks_admin = auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not auth_info.is_superuser
    if lacks_admin:
        raise AbortError(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.PERMISSION_DENIED)

    lacks_editor = auth_level == annotations_pb2.AUTH_LEVEL_EDITOR and not auth_info.is_editor
    if lacks_editor:
        raise AbortError(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.PERMISSION_DENIED)

    # jailed users are confined to open and jailed services
    jail_allowed = (annotations_pb2.AUTH_LEVEL_OPEN, annotations_pb2.AUTH_LEVEL_JAILED)
    if auth_info.is_jailed and auth_level not in jail_allowed:
        raise AbortError(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED)
500class MediaInterceptor(grpc.ServerInterceptor):
501 """
502 Extracts an "Authorization: Bearer <hex>" header and calls the
503 is_authorized function. Terminates the call with an HTTP error
504 code if not authorized.
506 Also adds a session to called APIs.
507 """
509 def __init__(self, is_authorized: Callable[[str], bool]):
510 self._is_authorized = is_authorized
512 def intercept_service[T, R](
513 self,
514 continuation: Cont[T, R],
515 handler_call_details: grpc.HandlerCallDetails,
516 ) -> grpc.RpcMethodHandler[T, R]:
517 handler = continuation(handler_call_details)
518 if not handler: 518 ↛ 519line 518 didn't jump to line 519 because the condition on line 518 was never true
519 raise RuntimeError("No handler")
521 prev_func = handler.unary_unary
522 if not prev_func: 522 ↛ 523line 522 didn't jump to line 523 because the condition on line 522 was never true
523 raise RuntimeError(f"No prev_function, {handler}")
525 metadata = dict(handler_call_details.invocation_metadata)
527 token = parse_api_key(metadata)
529 if not token or not self._is_authorized(token): 529 ↛ 530line 529 didn't jump to line 530 because the condition on line 529 was never true
530 return unauthenticated_handler()
532 def function_without_session(request: T, grpc_context: grpc.ServicerContext) -> R:
533 with session_scope() as session:
534 return prev_func(request, make_media_context(grpc_context), session) # type: ignore[call-arg, arg-type]
536 return grpc.unary_unary_rpc_method_handler(
537 function_without_session,
538 request_deserializer=handler.request_deserializer,
539 response_serializer=handler.response_serializer,
540 )
543class OTelInterceptor(grpc.ServerInterceptor):
544 """
545 OpenTelemetry tracing
546 """
548 def __init__(self) -> None:
549 self.tracer = trace.get_tracer(__name__)
551 def intercept_service[T, R](
552 self,
553 continuation: Cont[T, R],
554 handler_call_details: grpc.HandlerCallDetails,
555 ) -> grpc.RpcMethodHandler[T, R]:
556 handler = continuation(handler_call_details)
557 if not handler:
558 raise RuntimeError("No handler")
560 prev_func = handler.unary_unary
561 if not prev_func:
562 raise RuntimeError(f"No prev_function, {handler}")
564 method = handler_call_details.method
566 # method is of the form "/org.couchers.api.core.API/GetUser"
567 _, service_name, method_name = method.split("/")
569 headers = dict(handler_call_details.invocation_metadata)
571 def tracing_function(request: T, context: grpc.ServicerContext) -> R:
572 with self.tracer.start_as_current_span("handler") as rollspan:
573 rollspan.set_attribute("rpc.method_full", method)
574 rollspan.set_attribute("rpc.service", service_name)
575 rollspan.set_attribute("rpc.method", method_name)
577 rollspan.set_attribute("rpc.thread", get_ident())
578 rollspan.set_attribute("rpc.pid", getpid())
580 res = prev_func(request, context)
582 rollspan.set_attribute("web.user_agent", headers.get("user-agent") or "")
583 rollspan.set_attribute("web.ip_address", headers.get("x-couchers-real-ip") or "")
585 return res
587 return grpc.unary_unary_rpc_method_handler(
588 tracing_function,
589 request_deserializer=handler.request_deserializer,
590 response_serializer=handler.response_serializer,
591 )
594class ErrorSanitizationInterceptor(grpc.ServerInterceptor):
595 """
596 If the call resulted in a non-gRPC error, this strips away the error details.
598 It's important to put this first, so that it does not interfere with other interceptors.
599 """
601 def intercept_service[T, R](
602 self,
603 continuation: Cont[T, R],
604 handler_call_details: grpc.HandlerCallDetails,
605 ) -> grpc.RpcMethodHandler[T, R]:
606 handler = continuation(handler_call_details)
607 if not handler: 607 ↛ 608line 607 didn't jump to line 608 because the condition on line 607 was never true
608 raise RuntimeError("No handler")
610 prev_func = handler.unary_unary
611 if not prev_func: 611 ↛ 612line 611 didn't jump to line 612 because the condition on line 611 was never true
612 raise RuntimeError(f"No prev_function, {handler}")
614 def sanitizing_function(req: T, context: grpc.ServicerContext) -> R:
615 try:
616 res = prev_func(req, context)
617 except Exception as e:
618 code = context.code() # type: ignore[attr-defined]
619 # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None
620 if not code:
621 logger.exception(e)
622 logger.info("Probably an unknown error! Sanitizing...")
623 context.abort(grpc.StatusCode.INTERNAL, UNKNOWN_ERROR_MESSAGE)
624 else:
625 logger.warning(f"RPC error: {code} in method {handler_call_details.method}")
626 raise e
627 return res
629 return grpc.unary_unary_rpc_method_handler(
630 sanitizing_function,
631 request_deserializer=handler.request_deserializer,
632 response_serializer=handler.response_serializer,
633 )