Coverage for app / backend / src / couchers / interceptors.py: 86%
266 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 09:44 +0000
1import logging
2from collections.abc import Callable, Mapping
3from copy import deepcopy
4from dataclasses import dataclass, field
5from datetime import datetime, timedelta
6from os import getpid
7from threading import get_ident
8from time import perf_counter_ns
9from traceback import format_exception
10from typing import Any, NoReturn, cast, overload
11from zoneinfo import ZoneInfo
13import grpc
14import sentry_sdk
15from google.protobuf.descriptor import ServiceDescriptor
16from google.protobuf.descriptor_pool import DescriptorPool
17from google.protobuf.message import Message
18from opentelemetry import trace
19from sqlalchemy import Function, literal_column, select
20from sqlalchemy.sql import and_, func
22from couchers.constants import (
23 CALL_CANCELLED_ERROR_MESSAGE,
24 COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE,
25 MISSING_AUTH_LEVEL_ERROR_MESSAGE,
26 NONEXISTENT_API_CALL_ERROR_MESSAGE,
27 PERMISSION_DENIED_ERROR_MESSAGE,
28 UNAUTHORIZED_ERROR_MESSAGE,
29 UNKNOWN_ERROR_MESSAGE,
30)
31from couchers.context import CouchersContext, make_interactive_context, make_media_context
32from couchers.db import session_scope
33from couchers.descriptor_pool import get_descriptor_pool
34from couchers.i18n import LocalizationContext
35from couchers.i18n.locales import DEFAULT_LOCALE
36from couchers.metrics import observe_in_servicer_duration_histogram
37from couchers.models import APICall, User, UserActivity, UserSession
38from couchers.proto import annotations_pb2
39from couchers.proto.annotations_pb2 import AuthLevel
40from couchers.utils import (
41 create_lang_cookie,
42 create_session_cookies,
43 generate_sofa_cookie,
44 now,
45 parse_api_key,
46 parse_session_cookie,
47 parse_sofa_cookie,
48 parse_ui_lang_cookie,
49 parse_user_id_cookie,
50)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@dataclass(frozen=True, slots=True, kw_only=True)
class UserAuthInfo:
    """Information about an authenticated user session.

    Returned by ``_try_get_and_update_user_details`` when a valid session token
    (cookie) or API key matches. Immutable (``frozen=True``); all fields are
    keyword-only (``kw_only=True``).
    """

    user_id: int
    # copied from User.is_jailed; check_permissions limits jailed users to open/jailed services
    is_jailed: bool
    # copied from User.is_editor; required for AUTH_LEVEL_EDITOR services
    is_editor: bool
    # copied from User.is_superuser; required for AUTH_LEVEL_ADMIN services
    is_superuser: bool
    # UserSession.expiry; used when re-issuing session cookies
    token_expiry: datetime
    # User.ui_language_preference, if the user has set one
    ui_language_preference: str | None
    # User.timezone; the interceptor passes it to ZoneInfo, so presumably an IANA name
    timezone: str | None
    # the raw session token / API key; excluded from repr()
    token: str = field(repr=False)
    # True when the token was supplied as an API key rather than a session cookie
    is_api_key: bool
def _binned_now() -> Function[Any]:
    """SQL expression for the current time truncated to an hourly bucket.

    Uses ``date_bin`` with a one-hour stride aligned to the origin
    ``'2000-01-01'::timestamptz``, so all times within the same hour map to the
    same value. Used as the ``period`` key of ``UserActivity`` rows.
    """
    stride = literal_column("interval '1 hour'")
    origin = literal_column("'2000-01-01'::timestamptz")
    return func.date_bin(stride, func.now(), origin)
def _try_get_and_update_user_details(
    token: str | None, is_api_key: bool, ip_address: str | None, user_agent: str | None
) -> UserAuthInfo | None:
    """
    Look up the session belonging to ``token`` and record activity for it.

    On a match this also:
      * refreshes the user's ``last_active`` if it is more than five minutes old,
      * sets the session's ``last_seen`` and increments its ``api_calls``,
      * increments (or creates) the hourly ``UserActivity`` row for this
        user / IP address / user agent,
    and commits all of it.

    Returns a ``UserAuthInfo`` for a valid session of a visible user whose
    ``is_api_key`` flag matches, otherwise ``None`` (also for an empty token).
    """
    if not token:
        return None

    # the activity row for this user in the current hour from this client, if any
    activity_matches = and_(
        UserActivity.user_id == User.id,
        UserActivity.period == _binned_now(),
        UserActivity.ip_address == ip_address,
        UserActivity.user_agent == user_agent,
    )
    query = (
        select(User, UserSession, UserActivity)
        .select_from(UserSession)
        .join(User, User.id == UserSession.user_id)
        .outerjoin(UserActivity, activity_matches)
        .where(User.is_visible)
        .where(UserSession.token == token)
        .where(UserSession.is_valid)
        .where(UserSession.is_api_key == is_api_key)
    )

    with session_scope() as session:
        row = session.execute(query).one_or_none()
        if row is None:
            return None

        user, user_session, user_activity = row._tuple()

        # only refresh last_active once it is more than five minutes old
        if now() - user.last_active > timedelta(minutes=5):
            user.last_active = func.now()

        user_session.last_seen = func.now()
        user_session.api_calls += 1

        if user_activity is None:
            # first call in this hour bucket from this user/IP/user agent
            new_activity = UserActivity(
                user_id=user.id,
                period=_binned_now(),
                ip_address=ip_address,
                user_agent=user_agent,
                api_calls=1,
            )
            session.add(new_activity)
        else:
            user_activity.api_calls += 1

        session.commit()

        return UserAuthInfo(
            user_id=user.id,
            is_jailed=user.is_jailed,
            is_editor=user.is_editor,
            is_superuser=user.is_superuser,
            token_expiry=user_session.expiry,
            ui_language_preference=user.ui_language_preference,
            timezone=user.timezone,
            token=token,
            is_api_key=is_api_key,
        )
152def abort_handler[T, R](
153 message: str,
154 status_code: grpc.StatusCode,
155) -> grpc.RpcMethodHandler[T, R]:
156 def f(request: Any, context: CouchersContext) -> NoReturn:
157 context.abort(status_code, message)
159 return grpc.unary_unary_rpc_method_handler(f)
162def unauthenticated_handler[T, R](
163 message: str = UNAUTHORIZED_ERROR_MESSAGE,
164 status_code: grpc.StatusCode = grpc.StatusCode.UNAUTHENTICATED,
165) -> grpc.RpcMethodHandler[T, R]:
166 return abort_handler(message, status_code)
@overload
def _sanitized_bytes(proto: Message) -> bytes: ...
@overload
def _sanitized_bytes(proto: None) -> None: ...
def _sanitized_bytes(proto: Message | None) -> bytes | None:
    """
    Remove fields marked sensitive and return serialized bytes.

    Works on a deep copy, so the caller's message is not modified. A field is
    sensitive when its ``annotations_pb2.sensitive`` option is set. Such fields
    are cleared at any depth: in nested messages, in repeated message fields and
    in the message values of map fields.

    Returns ``None`` when ``proto`` is ``None``.
    """
    if proto is None:
        return None

    new_proto = deepcopy(proto)

    def _sanitize_message(message: Message) -> None:
        for name, descriptor in message.DESCRIPTOR.fields_by_name.items():
            if descriptor.GetOptions().Extensions[annotations_pb2.sensitive]:
                message.ClearField(name)
                # the field is now empty: nothing left to recurse into
                continue
            if not descriptor.message_type:
                # scalar fields cannot contain nested sensitive fields
                continue
            value = getattr(message, name)
            if descriptor.message_type.GetOptions().map_entry:
                # iterating a map yields its keys, so walk the values instead;
                # only maps with message values can hold sensitive fields
                if descriptor.message_type.fields_by_name["value"].message_type:
                    for submessage in value.values():
                        _sanitize_message(submessage)
            elif descriptor.is_repeated:
                for submessage in value:
                    _sanitize_message(submessage)
            elif message.HasField(name):
                # skip unset submessages instead of recursing into a default instance
                _sanitize_message(value)

    _sanitize_message(new_proto)

    return new_proto.SerializeToString()
def _store_log(
    *,
    method: str,
    status_code: str | None = None,
    duration: float,
    user_id: int | None,
    is_api_key: bool,
    request: Message,
    response: Message | None,
    traceback: str | None = None,
    perf_report: str | None = None,
    ip_address: str | None,
    user_agent: str | None,
    sofa: str | None,
) -> None:
    """
    Persist one ``APICall`` row describing a serviced RPC.

    Request and response are stored with sensitive fields stripped via
    ``_sanitized_bytes``. Serialized responses longer than 16 kB are cut to 16 kB
    and flagged with ``response_truncated``. The truncated bytes may no longer
    parse as a message.

    ``duration`` is logged with an "ms" suffix.
    """
    truncate_res_bytes_length = 16 * 1024  # 16 kB

    req_bytes = _sanitized_bytes(request)
    res_bytes = _sanitized_bytes(response)

    # truncation is pure byte work, so do it before opening the DB session
    response_truncated = False
    if res_bytes is not None and len(res_bytes) > truncate_res_bytes_length:
        res_bytes = res_bytes[:truncate_res_bytes_length]
        response_truncated = True

    with session_scope() as session:
        session.add(
            APICall(
                is_api_key=is_api_key,
                method=method,
                status_code=status_code,
                duration=duration,
                user_id=user_id,
                request=req_bytes,
                response=res_bytes,
                response_truncated=response_truncated,
                traceback=traceback,
                perf_report=perf_report,
                ip_address=ip_address,
                user_agent=user_agent,
                sofa=sofa,
            )
        )
        # lazy %-formatting: this runs on every API call and debug is usually disabled;
        # %r matches the repr() output of the previous f"{x=}" form
        logger.debug("user_id=%r, method=%r, duration=%r ms", user_id, method, duration)
# Type of the `continuation` callable gRPC passes to ServerInterceptor.intercept_service:
# it maps call details to the next handler, or None (the interceptors here treat None as
# "no handler").
type Cont[T, R] = Callable[[grpc.HandlerCallDetails], grpc.RpcMethodHandler[T, R] | None]
class CouchersMiddlewareInterceptor(grpc.ServerInterceptor):
    """
    Main Couchers per-call middleware.

    1. Does auth: looks up the method's ``auth_level`` on its service descriptor and
       extracts a session token from a cookie (or an API key from the
       ``authorization`` header). It then authenticates a user with that token.

       The servicer's context is built with the user's id, token and API-key flag
       if authenticated. If ``check_permissions`` fails, the call is answered with an
       aborting handler carrying the resulting status code (UNAUTHENTICATED or
       PERMISSION_DENIED).

    2. Makes sure cookies are in sync: re-issues the session/user-id cookies and the
       language cookie when they disagree with the session, and issues a ``sofa``
       cookie if the client has none.

    3. Injects a session to get a database transaction: servicers are called as
       ``fn(request, couchers_context, session)``.

    4. Measures and logs the time it takes to service each incoming call. This is
       recorded as an ``APICall`` row and a histogram observation. Errors where no
       gRPC status code was set are also reported to Sentry.
    """

    def __init__(self) -> None:
        # descriptor pool used to look up each service's auth_level option
        self._pool = get_descriptor_pool()

    def intercept_service[T = Message, R = Message](
        self,
        continuation: Cont[T, R],
        handler_call_details: grpc.HandlerCallDetails,
    ) -> grpc.RpcMethodHandler[T, R]:
        # timing starts when gRPC asks for the handler, not when the servicer starts
        start = perf_counter_ns()

        method = handler_call_details.method

        try:
            auth_level = find_auth_level(self._pool, method)
        except AbortError as ae:
            return abort_handler(ae.msg, ae.code)

        try:
            headers = parse_headers(dict(handler_call_details.invocation_metadata))
        except BadHeaders:
            return unauthenticated_handler(COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE)

        # None when there is no token or it doesn't match a valid session
        auth_info = _try_get_and_update_user_details(
            headers.token, headers.is_api_key, headers.ip_address, headers.user_agent
        )

        try:
            check_permissions(auth_info, auth_level)
        except AbortError as ae:
            return unauthenticated_handler(ae.msg, ae.code)

        if not (handler := continuation(handler_call_details)):
            raise RuntimeError(f"No handler in '{method}'")

        if not (prev_function := handler.unary_unary):
            raise RuntimeError(f"No prev_function in '{method}', {handler}")

        # reuse the client's sofa cookie, or generate one to send back with the response
        if headers.sofa:
            sofa = headers.sofa
            new_sofa_cookie = None
        else:
            sofa, new_sofa_cookie = generate_sofa_cookie()

        # locale: the user's preference when logged in (falling back to the default, not the
        # cookie), otherwise the UI language cookie, otherwise the default;
        # timezone: the user's timezone if set, otherwise UTC
        loc_context = LocalizationContext(
            locale=(auth_info.ui_language_preference if auth_info else headers.ui_lang) or DEFAULT_LOCALE,
            timezone=ZoneInfo((auth_info and auth_info.timezone) or "Etc/UTC"),
        )

        def function_without_couchers_stuff(req: Message, grpc_context: grpc.ServicerContext) -> Message | None:
            couchers_context = make_interactive_context(
                grpc_context=grpc_context,
                user_id=auth_info.user_id if auth_info else None,
                is_api_key=auth_info.is_api_key if auth_info else False,
                token=auth_info.token if auth_info else None,
                localization=loc_context,
                sofa=sofa,
            )

            with session_scope() as session:
                try:
                    # servicers take an extra `session` argument, hence the type: ignore
                    _res = prev_function(req, couchers_context, session)  # type: ignore[call-arg, arg-type]
                    res = cast(Message, _res)
                    finished = perf_counter_ns()
                    duration = (finished - start) / 1e6  # ms
                    _store_log(
                        method=method,
                        duration=duration,
                        user_id=couchers_context._user_id,
                        is_api_key=cast(bool, couchers_context._is_api_key),
                        request=req,
                        response=res,
                        ip_address=headers.ip_address,
                        user_agent=headers.user_agent,
                        sofa=sofa,
                    )
                    # histogram takes seconds
                    observe_in_servicer_duration_histogram(method, couchers_context._user_id, "", "", duration / 1000)
                except Exception as e:
                    finished = perf_counter_ns()
                    duration = (finished - start) / 1e6  # ms

                    # a status code is set on the context only if the servicer failed via abort()
                    if couchers_context._grpc_context:
                        context_code = couchers_context._grpc_context.code()  # type: ignore[attr-defined]
                        code = getattr(context_code, "name", None)
                    else:
                        code = None

                    traceback = "".join(format_exception(type(e), e, e.__traceback__))
                    _store_log(
                        method=method,
                        status_code=code,
                        duration=duration,
                        user_id=couchers_context._user_id,
                        is_api_key=cast(bool, couchers_context._is_api_key),
                        request=req,
                        response=None,
                        traceback=traceback,
                        ip_address=headers.ip_address,
                        user_agent=headers.user_agent,
                        sofa=sofa,
                    )
                    observe_in_servicer_duration_histogram(
                        method, couchers_context._user_id, code or "", type(e).__name__, duration / 1000
                    )

                    # no status code: an unexpected error rather than a deliberate abort
                    # NOTE(review): set_tag mutates the current Sentry scope, so these tags may
                    # carry over to later events on the same scope — consider an isolated scope
                    if not code:
                        sentry_sdk.set_tag("context", "servicer")
                        sentry_sdk.set_tag("method", method)
                        sentry_sdk.capture_exception(e)

                    raise e

            if auth_info and not auth_info.is_api_key:
                # check the two cookies are in sync & that language preference cookie is correct
                if headers.user_id != str(auth_info.user_id):
                    couchers_context.set_cookies(
                        create_session_cookies(auth_info.token, auth_info.user_id, auth_info.token_expiry)
                    )
                if auth_info.ui_language_preference and auth_info.ui_language_preference != headers.ui_lang:
                    couchers_context.set_cookies(create_lang_cookie(auth_info.ui_language_preference))

            if new_sofa_cookie:
                couchers_context.set_cookies([new_sofa_cookie])

            # if the RPC is no longer active (e.g. cancelled), fail it instead of sending cookies
            if not grpc_context.is_active():
                grpc_context.abort(grpc.StatusCode.INTERNAL, CALL_CANCELLED_ERROR_MESSAGE)

            couchers_context._send_cookies()

            return res

        return grpc.unary_unary_rpc_method_handler(
            function_without_couchers_stuff,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
@dataclass(frozen=True, slots=True, kw_only=True)
class CouchersHeaders:
    """Values extracted from an incoming call's metadata by ``parse_headers``."""

    # session token (from the cookie) or API key (from `authorization`); excluded from repr()
    token: str | None = field(repr=False)
    # True when `token` came from the `authorization` header
    is_api_key: bool
    # value of the `x-couchers-real-ip` header, if it is a str
    ip_address: str | None
    # value of the `user-agent` header, if it is a str
    user_agent: str | None
    # as returned by parse_ui_lang_cookie
    ui_lang: str | None
    # as returned by parse_user_id_cookie; compared with the session's user id (as str)
    # to decide whether session cookies need re-issuing
    user_id: str | None
    # as returned by parse_sofa_cookie; None means a new sofa cookie will be generated
    sofa: str | None
def parse_headers(headers: Mapping[str, str | bytes]) -> CouchersHeaders:
    """
    Extract credentials and client details from gRPC invocation metadata.

    The token comes either from the ``cookie`` header (a session cookie) or from
    the ``authorization`` header (an API key, reported with ``is_api_key=True``).
    With neither present the token is ``None``.

    Raises:
        BadHeaders: if both ``cookie`` and ``authorization`` are present.
    """
    has_cookie = "cookie" in headers
    has_authorization = "authorization" in headers

    # for security reasons, only one of "cookie" or "authorization" may be supplied
    if has_cookie and has_authorization:
        raise BadHeaders("Both cookies and authorization are present in headers")

    if has_cookie:
        token, is_api_key = parse_session_cookie(headers), False
    elif has_authorization:
        token, is_api_key = parse_api_key(headers), True
    else:
        token, is_api_key = None, False

    raw_ip = headers.get("x-couchers-real-ip")
    raw_user_agent = headers.get("user-agent")

    return CouchersHeaders(
        token=token,
        is_api_key=is_api_key,
        # metadata values may be bytes; only text values are kept
        ip_address=raw_ip if isinstance(raw_ip, str) else None,
        user_agent=raw_user_agent if isinstance(raw_user_agent, str) else None,
        ui_lang=parse_ui_lang_cookie(headers),
        user_id=parse_user_id_cookie(headers),
        sofa=parse_sofa_cookie(headers),
    )
class BadHeaders(Exception):
    """Raised by ``parse_headers`` when a call's metadata is unacceptable, e.g. when
    both ``cookie`` and ``authorization`` headers are supplied."""
class AbortError(Exception):
    """
    Signals that a call must be rejected with a specific gRPC status.

    Attributes:
        msg: the error message to send to the client.
        code: the gRPC status code to abort with.
    """

    def __init__(self, msg: str, code: grpc.StatusCode) -> None:
        # pass the message on so str(e), e.args and tracebacks show it
        # (previously Exception.__init__ was never called and they were empty)
        super().__init__(msg)
        self.msg = msg
        self.code = code
def find_auth_level(pool: DescriptorPool, method: str) -> AuthLevel.ValueType:
    """
    Return the ``auth_level`` option declared on the service that owns ``method``.

    Args:
        pool: descriptor pool containing the API's service descriptors.
        method: full gRPC method path, e.g. "/org.couchers.api.core.API/GetUser".

    Raises:
        AbortError: UNIMPLEMENTED if ``method`` is not of the form
            "/<service>/<method>" or its service is not in ``pool``; INTERNAL (via
            ``validate_auth_level``) if the auth level is missing or unsupported.
    """
    # method is of the form "/org.couchers.api.core.API/GetUser"; the path presumably
    # reflects the client's request, so reject other shapes instead of letting the
    # tuple unpacking raise an uncaught ValueError
    parts = method.split("/")
    if len(parts) != 3:
        raise AbortError(NONEXISTENT_API_CALL_ERROR_MESSAGE, grpc.StatusCode.UNIMPLEMENTED)
    _, service_name, _ = parts

    try:
        service: ServiceDescriptor = pool.FindServiceByName(service_name)  # type: ignore[no-untyped-call]
    except KeyError:
        raise AbortError(NONEXISTENT_API_CALL_ERROR_MESSAGE, grpc.StatusCode.UNIMPLEMENTED) from None

    level = service.GetOptions().Extensions[annotations_pb2.auth_level]

    validate_auth_level(level)

    return level
def validate_auth_level(auth_level: AuthLevel.ValueType) -> None:
    """
    Check that ``auth_level`` is one of the supported, explicitly-set levels.

    Raises:
        AbortError: INTERNAL with MISSING_AUTH_LEVEL_ERROR_MESSAGE if the level is
            AUTH_LEVEL_UNKNOWN or any value outside the supported set.
    """
    supported_levels = (
        annotations_pb2.AUTH_LEVEL_OPEN,
        annotations_pb2.AUTH_LEVEL_JAILED,
        annotations_pb2.AUTH_LEVEL_SECURE,
        annotations_pb2.AUTH_LEVEL_EDITOR,
        annotations_pb2.AUTH_LEVEL_ADMIN,
    )
    # AUTH_LEVEL_UNKNOWN means the option was never set on the service, which is a server bug
    never_set = auth_level == annotations_pb2.AUTH_LEVEL_UNKNOWN
    if never_set or auth_level not in supported_levels:
        raise AbortError(MISSING_AUTH_LEVEL_ERROR_MESSAGE, grpc.StatusCode.INTERNAL)
def check_permissions(auth_info: UserAuthInfo | None, auth_level: AuthLevel.ValueType) -> None:
    """
    Decide whether the caller may use a service with the given ``auth_level``.

    Returns ``None`` when allowed. Otherwise it raises ``AbortError``:
      * no valid session and the service is not AUTH_LEVEL_OPEN: UNAUTHENTICATED;
      * AUTH_LEVEL_ADMIN without superuser rights, or AUTH_LEVEL_EDITOR without
        editor rights: PERMISSION_DENIED;
      * a jailed user calling a service other than OPEN or JAILED: the
        permission-denied message with status UNAUTHENTICATED.
    """
    if auth_info is None:
        # anonymous callers may only reach open services
        if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
            raise AbortError(UNAUTHORIZED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED)
        return

    # a valid user session was found - check role-gated services
    lacks_admin = auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not auth_info.is_superuser
    lacks_editor = auth_level == annotations_pb2.AUTH_LEVEL_EDITOR and not auth_info.is_editor
    if lacks_admin or lacks_editor:
        raise AbortError(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.PERMISSION_DENIED)

    # NOTE(review): jailed users get the permission-denied message but an UNAUTHENTICATED
    # status; kept unchanged in case clients rely on it — confirm this is intentional
    jail_allowed_levels = (annotations_pb2.AUTH_LEVEL_OPEN, annotations_pb2.AUTH_LEVEL_JAILED)
    if auth_info.is_jailed and auth_level not in jail_allowed_levels:
        raise AbortError(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED)
504class MediaInterceptor(grpc.ServerInterceptor):
505 """
506 Extracts an "Authorization: Bearer <hex>" header and calls the
507 is_authorized function. Terminates the call with an HTTP error
508 code if not authorized.
510 Also adds a session to called APIs.
511 """
513 def __init__(self, is_authorized: Callable[[str], bool]):
514 self._is_authorized = is_authorized
516 def intercept_service[T, R](
517 self,
518 continuation: Cont[T, R],
519 handler_call_details: grpc.HandlerCallDetails,
520 ) -> grpc.RpcMethodHandler[T, R]:
521 handler = continuation(handler_call_details)
522 if not handler: 522 ↛ 523line 522 didn't jump to line 523 because the condition on line 522 was never true
523 raise RuntimeError("No handler")
525 prev_func = handler.unary_unary
526 if not prev_func: 526 ↛ 527line 526 didn't jump to line 527 because the condition on line 526 was never true
527 raise RuntimeError(f"No prev_function, {handler}")
529 metadata = dict(handler_call_details.invocation_metadata)
531 token = parse_api_key(metadata)
533 if not token or not self._is_authorized(token): 533 ↛ 534line 533 didn't jump to line 534 because the condition on line 533 was never true
534 return unauthenticated_handler()
536 def function_without_session(request: T, grpc_context: grpc.ServicerContext) -> R:
537 with session_scope() as session:
538 return prev_func(request, make_media_context(grpc_context), session) # type: ignore[call-arg, arg-type]
540 return grpc.unary_unary_rpc_method_handler(
541 function_without_session,
542 request_deserializer=handler.request_deserializer,
543 response_serializer=handler.response_serializer,
544 )
547class OTelInterceptor(grpc.ServerInterceptor):
548 """
549 OpenTelemetry tracing
550 """
552 def __init__(self) -> None:
553 self.tracer = trace.get_tracer(__name__)
555 def intercept_service[T, R](
556 self,
557 continuation: Cont[T, R],
558 handler_call_details: grpc.HandlerCallDetails,
559 ) -> grpc.RpcMethodHandler[T, R]:
560 handler = continuation(handler_call_details)
561 if not handler:
562 raise RuntimeError("No handler")
564 prev_func = handler.unary_unary
565 if not prev_func:
566 raise RuntimeError(f"No prev_function, {handler}")
568 method = handler_call_details.method
570 # method is of the form "/org.couchers.api.core.API/GetUser"
571 _, service_name, method_name = method.split("/")
573 headers = dict(handler_call_details.invocation_metadata)
575 def tracing_function(request: T, context: grpc.ServicerContext) -> R:
576 with self.tracer.start_as_current_span("handler") as rollspan:
577 rollspan.set_attribute("rpc.method_full", method)
578 rollspan.set_attribute("rpc.service", service_name)
579 rollspan.set_attribute("rpc.method", method_name)
581 rollspan.set_attribute("rpc.thread", get_ident())
582 rollspan.set_attribute("rpc.pid", getpid())
584 res = prev_func(request, context)
586 rollspan.set_attribute("web.user_agent", headers.get("user-agent") or "")
587 rollspan.set_attribute("web.ip_address", headers.get("x-couchers-real-ip") or "")
589 return res
591 return grpc.unary_unary_rpc_method_handler(
592 tracing_function,
593 request_deserializer=handler.request_deserializer,
594 response_serializer=handler.response_serializer,
595 )
598class ErrorSanitizationInterceptor(grpc.ServerInterceptor):
599 """
600 If the call resulted in a non-gRPC error, this strips away the error details.
602 It's important to put this first, so that it does not interfere with other interceptors.
603 """
605 def intercept_service[T, R](
606 self,
607 continuation: Cont[T, R],
608 handler_call_details: grpc.HandlerCallDetails,
609 ) -> grpc.RpcMethodHandler[T, R]:
610 handler = continuation(handler_call_details)
611 if not handler: 611 ↛ 612line 611 didn't jump to line 612 because the condition on line 611 was never true
612 raise RuntimeError("No handler")
614 prev_func = handler.unary_unary
615 if not prev_func: 615 ↛ 616line 615 didn't jump to line 616 because the condition on line 615 was never true
616 raise RuntimeError(f"No prev_function, {handler}")
618 def sanitizing_function(req: T, context: grpc.ServicerContext) -> R:
619 try:
620 res = prev_func(req, context)
621 except Exception as e:
622 code = context.code() # type: ignore[attr-defined]
623 # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None
624 if not code:
625 logger.exception(e)
626 logger.info("Probably an unknown error! Sanitizing...")
627 context.abort(grpc.StatusCode.INTERNAL, UNKNOWN_ERROR_MESSAGE)
628 else:
629 logger.warning(f"RPC error: {code} in method {handler_call_details.method}")
630 raise e
631 return res
633 return grpc.unary_unary_rpc_method_handler(
634 sanitizing_function,
635 request_deserializer=handler.request_deserializer,
636 response_serializer=handler.response_serializer,
637 )