Coverage for app / backend / src / couchers / interceptors.py: 86%
262 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-03 06:18 +0000
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-03 06:18 +0000
1import logging
2from collections.abc import Callable, Mapping
3from copy import deepcopy
4from dataclasses import dataclass, field
5from datetime import datetime, timedelta
6from os import getpid
7from threading import get_ident
8from time import perf_counter_ns
9from traceback import format_exception
10from typing import Any, NoReturn, cast, overload
12import grpc
13import sentry_sdk
14from google.protobuf.descriptor import ServiceDescriptor
15from google.protobuf.descriptor_pool import DescriptorPool
16from google.protobuf.message import Message
17from opentelemetry import trace
18from sqlalchemy import Function, select
19from sqlalchemy.sql import and_, func
21from couchers.constants import (
22 CALL_CANCELLED_ERROR_MESSAGE,
23 COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE,
24 MISSING_AUTH_LEVEL_ERROR_MESSAGE,
25 NONEXISTENT_API_CALL_ERROR_MESSAGE,
26 PERMISSION_DENIED_ERROR_MESSAGE,
27 UNAUTHORIZED_ERROR_MESSAGE,
28 UNKNOWN_ERROR_MESSAGE,
29)
30from couchers.context import CouchersContext, make_interactive_context, make_media_context
31from couchers.db import session_scope
32from couchers.descriptor_pool import get_descriptor_pool
33from couchers.metrics import observe_in_servicer_duration_histogram
34from couchers.models import APICall, User, UserActivity, UserSession
35from couchers.proto import annotations_pb2
36from couchers.proto.annotations_pb2 import AuthLevel
37from couchers.utils import (
38 create_lang_cookie,
39 create_session_cookies,
40 generate_sofa_cookie,
41 now,
42 parse_api_key,
43 parse_session_cookie,
44 parse_sofa_cookie,
45 parse_ui_lang_cookie,
46 parse_user_id_cookie,
47)
# module-level logger, named after this module
logger = logging.getLogger(__name__)
52@dataclass(frozen=True, slots=True, kw_only=True)
53class UserAuthInfo:
54 """Information about an authenticated user session."""
56 user_id: int
57 is_jailed: bool
58 is_editor: bool
59 is_superuser: bool
60 token_expiry: datetime
61 ui_language_preference: str | None
62 token: str = field(repr=False)
63 is_api_key: bool
def _binned_now() -> Function[Any]:
    """
    Build the SQL expression ``date_bin('1 hour', now(), '2000-01-01')``, used as the activity period.
    """
    bin_width = "1 hour"
    bin_origin = "2000-01-01"
    return func.date_bin(bin_width, func.now(), bin_origin)
def _try_get_and_update_user_details(
    token: str | None, is_api_key: bool, ip_address: str | None, user_agent: str | None
) -> UserAuthInfo | None:
    """
    Look up the visible user and the valid session that ``token`` belongs to.

    On a match this also records the activity: the user's last active time is refreshed (only
    when it is more than five minutes old), the session's last seen time and call counter are
    updated, and the activity row for this period, IP address and user agent is incremented,
    or created when there is none yet.

    Returns a UserAuthInfo for the session, or None when there is no token or no matching session.
    """
    if not token:
        return None

    # the activity row is outer-joined, so it comes back as None when it does not exist yet
    same_activity_row = and_(
        UserActivity.user_id == User.id,
        UserActivity.period == _binned_now(),
        UserActivity.ip_address == ip_address,
        UserActivity.user_agent == user_agent,
    )
    query = (
        select(User, UserSession, UserActivity)
        .select_from(UserSession)
        .join(User, User.id == UserSession.user_id)
        .outerjoin(UserActivity, same_activity_row)
        .where(User.is_visible)
        .where(UserSession.token == token)
        .where(UserSession.is_valid)
        .where(UserSession.is_api_key == is_api_key)
    )

    with session_scope() as session:
        row = session.execute(query).one_or_none()

        if not row:
            return None

        user, user_session, user_activity = row._tuple()

        # update user last active time if it's been a while
        refresh_after = timedelta(minutes=5)
        if now() - user.last_active > refresh_after:
            user.last_active = func.now()

        # let's update the token
        user_session.last_seen = func.now()
        user_session.api_calls += 1

        if not user_activity:
            # first call in this period from this IP address and user agent
            session.add(
                UserActivity(
                    user_id=user.id,
                    period=_binned_now(),
                    ip_address=ip_address,
                    user_agent=user_agent,
                    api_calls=1,
                )
            )
        else:
            user_activity.api_calls += 1

        session.commit()

        return UserAuthInfo(
            user_id=user.id,
            is_jailed=user.is_jailed,
            is_editor=user.is_editor,
            is_superuser=user.is_superuser,
            token_expiry=user_session.expiry,
            ui_language_preference=user.ui_language_preference,
            token=token,
            is_api_key=is_api_key,
        )
143def abort_handler[T, R](
144 message: str,
145 status_code: grpc.StatusCode,
146) -> grpc.RpcMethodHandler[T, R]:
147 def f(request: Any, context: CouchersContext) -> NoReturn:
148 context.abort(status_code, message)
150 return grpc.unary_unary_rpc_method_handler(f)
153def unauthenticated_handler[T, R](
154 message: str = UNAUTHORIZED_ERROR_MESSAGE,
155 status_code: grpc.StatusCode = grpc.StatusCode.UNAUTHENTICATED,
156) -> grpc.RpcMethodHandler[T, R]:
157 return abort_handler(message, status_code)
@overload
def _sanitized_bytes(proto: Message) -> bytes: ...
@overload
def _sanitized_bytes(proto: None) -> None: ...
def _sanitized_bytes(proto: Message | None) -> bytes | None:
    """
    Remove fields marked sensitive and return serialized bytes

    Works on a deep copy, so the message passed in is never modified. Fields carrying the
    ``sensitive`` option are cleared at every nesting level, including inside repeated
    messages and message-valued maps.

    Returns None only when ``proto`` is None.
    """
    # test for None explicitly rather than relying on the truthiness of a message object
    if proto is None:
        return None

    new_proto = deepcopy(proto)

    def _sanitize_message(message: Message) -> None:
        for name, descriptor in message.DESCRIPTOR.fields_by_name.items():
            if descriptor.GetOptions().Extensions[annotations_pb2.sensitive]:
                message.ClearField(name)
                # the field is empty now, there is nothing left to descend into
                continue

            submessage_type = descriptor.message_type
            if not submessage_type:
                continue

            if submessage_type.GetOptions().map_entry:
                # a map field: iterating it yields keys, so walk the values, and only if they are messages
                if submessage_type.fields_by_name["value"].message_type:
                    for value in getattr(message, name).values():
                        _sanitize_message(value)
            elif descriptor.is_repeated:
                for item in getattr(message, name):
                    _sanitize_message(item)
            elif message.HasField(name):
                # only descend into submessages that are actually set; this also stops the
                # recursion on self-referential message types
                _sanitize_message(getattr(message, name))

    _sanitize_message(new_proto)

    return new_proto.SerializeToString()
def _store_log(
    *,
    method: str,
    status_code: str | None = None,
    duration: float,
    user_id: int | None,
    is_api_key: bool,
    request: Message,
    response: Message | None,
    traceback: str | None = None,
    perf_report: str | None = None,
    ip_address: str | None,
    user_agent: str | None,
    sofa: str | None,
) -> None:
    """
    Store one APICall row describing a call, then emit a debug log line.

    The request and response are stored sanitized and serialized. A response longer than
    16 kB is cut to that length and the row is flagged as truncated.
    """
    max_response_bytes = 16 * 1024  # 16 kB

    req_bytes = _sanitized_bytes(request)
    res_bytes = _sanitized_bytes(response)

    response_truncated = res_bytes is not None and len(res_bytes) > max_response_bytes
    if response_truncated:
        res_bytes = res_bytes[:max_response_bytes]

    api_call = APICall(
        is_api_key=is_api_key,
        method=method,
        status_code=status_code,
        duration=duration,
        user_id=user_id,
        request=req_bytes,
        response=res_bytes,
        response_truncated=response_truncated,
        traceback=traceback,
        perf_report=perf_report,
        ip_address=ip_address,
        user_agent=user_agent,
        sofa=sofa,
    )
    with session_scope() as session:
        session.add(api_call)

    logger.debug(f"{user_id=}, {method=}, {duration=} ms")
235type Cont[T, R] = Callable[[grpc.HandlerCallDetails], grpc.RpcMethodHandler[T, R] | None]
238class CouchersMiddlewareInterceptor(grpc.ServerInterceptor):
239 """
240 1. Does auth: extracts a session token from a cookie, and authenticates a user with that.
242 Sets context.user_id and context.token if authenticated, otherwise
243 terminates the call with an UNAUTHENTICATED error code.
245 2. Makes sure cookies are in sync.
247 3. Injects a session to get a database transaction.
249 4. Measures and logs the time it takes to service each incoming call.
250 """
252 def __init__(self) -> None:
253 self._pool = get_descriptor_pool()
255 def intercept_service[T = Message, R = Message](
256 self,
257 continuation: Cont[T, R],
258 handler_call_details: grpc.HandlerCallDetails,
259 ) -> grpc.RpcMethodHandler[T, R]:
260 start = perf_counter_ns()
262 method = handler_call_details.method
264 try:
265 auth_level = find_auth_level(self._pool, method)
266 except AbortError as ae:
267 return abort_handler(ae.msg, ae.code)
269 try:
270 headers = parse_headers(dict(handler_call_details.invocation_metadata))
271 except BadHeaders:
272 return unauthenticated_handler(COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE)
274 auth_info = _try_get_and_update_user_details(
275 headers.token, headers.is_api_key, headers.ip_address, headers.user_agent
276 )
278 try:
279 check_permissions(auth_info, auth_level)
280 except AbortError as ae:
281 return unauthenticated_handler(ae.msg, ae.code)
283 if not (handler := continuation(handler_call_details)): 283 ↛ 284line 283 didn't jump to line 284 because the condition on line 283 was never true
284 raise RuntimeError(f"No handler in '{method}'")
286 if not (prev_function := handler.unary_unary): 286 ↛ 287line 286 didn't jump to line 287 because the condition on line 286 was never true
287 raise RuntimeError(f"No prev_function in '{method}', {handler}")
289 if headers.sofa:
290 sofa = headers.sofa
291 new_sofa_cookie = None
292 else:
293 sofa, new_sofa_cookie = generate_sofa_cookie()
295 def function_without_couchers_stuff(req: Message, grpc_context: grpc.ServicerContext) -> Message | None:
296 couchers_context = make_interactive_context(
297 grpc_context=grpc_context,
298 user_id=auth_info.user_id if auth_info else None,
299 is_api_key=auth_info.is_api_key if auth_info else False,
300 token=auth_info.token if auth_info else None,
301 ui_language_preference=(auth_info.ui_language_preference if auth_info else None) or headers.ui_lang,
302 )
304 with session_scope() as session:
305 try:
306 _res = prev_function(req, couchers_context, session) # type: ignore[call-arg, arg-type]
307 res = cast(Message, _res)
308 finished = perf_counter_ns()
309 duration = (finished - start) / 1e6 # ms
310 _store_log(
311 method=method,
312 duration=duration,
313 user_id=couchers_context._user_id,
314 is_api_key=cast(bool, couchers_context._is_api_key),
315 request=req,
316 response=res,
317 ip_address=headers.ip_address,
318 user_agent=headers.user_agent,
319 sofa=sofa,
320 )
321 observe_in_servicer_duration_histogram(method, couchers_context._user_id, "", "", duration / 1000)
322 except Exception as e:
323 finished = perf_counter_ns()
324 duration = (finished - start) / 1e6 # ms
326 if couchers_context._grpc_context: 326 ↛ 330line 326 didn't jump to line 330 because the condition on line 326 was always true
327 context_code = couchers_context._grpc_context.code() # type: ignore[attr-defined]
328 code = getattr(context_code, "name", None)
329 else:
330 code = None
332 traceback = "".join(format_exception(type(e), e, e.__traceback__))
333 _store_log(
334 method=method,
335 status_code=code,
336 duration=duration,
337 user_id=couchers_context._user_id,
338 is_api_key=cast(bool, couchers_context._is_api_key),
339 request=req,
340 response=None,
341 traceback=traceback,
342 ip_address=headers.ip_address,
343 user_agent=headers.user_agent,
344 sofa=sofa,
345 )
346 observe_in_servicer_duration_histogram(
347 method, couchers_context._user_id, code or "", type(e).__name__, duration / 1000
348 )
350 if not code:
351 sentry_sdk.set_tag("context", "servicer")
352 sentry_sdk.set_tag("method", method)
353 sentry_sdk.capture_exception(e)
355 raise e
357 if auth_info and not auth_info.is_api_key:
358 # check the two cookies are in sync & that language preference cookie is correct
359 if headers.user_id != str(auth_info.user_id): 359 ↛ 363line 359 didn't jump to line 363 because the condition on line 359 was always true
360 couchers_context.set_cookies(
361 create_session_cookies(auth_info.token, auth_info.user_id, auth_info.token_expiry)
362 )
363 if auth_info.ui_language_preference and auth_info.ui_language_preference != headers.ui_lang:
364 couchers_context.set_cookies(create_lang_cookie(auth_info.ui_language_preference))
366 if new_sofa_cookie:
367 couchers_context.set_cookies([new_sofa_cookie])
369 if not grpc_context.is_active(): 369 ↛ 370line 369 didn't jump to line 370 because the condition on line 369 was never true
370 grpc_context.abort(grpc.StatusCode.INTERNAL, CALL_CANCELLED_ERROR_MESSAGE)
372 couchers_context._send_cookies()
374 return res
376 return grpc.unary_unary_rpc_method_handler(
377 function_without_couchers_stuff,
378 request_deserializer=handler.request_deserializer,
379 response_serializer=handler.response_serializer,
380 )
383@dataclass(frozen=True, slots=True, kw_only=True)
384class CouchersHeaders:
385 token: str | None = field(repr=False)
386 is_api_key: bool
387 ip_address: str | None
388 user_agent: str | None
389 ui_lang: str | None
390 user_id: str | None
391 sofa: str | None
def parse_headers(headers: Mapping[str, str | bytes]) -> CouchersHeaders:
    """
    Pull the session token, client details and cookie values out of the call metadata.

    Raises BadHeaders when both a `cookie` and an `authorization` header are present.
    """
    has_cookie = "cookie" in headers
    has_authorization = "authorization" in headers

    # for security reasons, only one of "cookie" or "authorization" can be present
    if has_cookie and has_authorization:
        raise BadHeaders("Both cookies and authorization are present in headers")

    # start from "no session found" and overwrite if a token was sent
    token: str | None = None
    is_api_key = False
    if has_cookie:
        # the session token is passed in cookies, i.e., in the `cookie` header
        token = parse_session_cookie(headers)
    elif has_authorization:
        # the session token is passed in the `authorization` header
        token = parse_api_key(headers)
        is_api_key = True

    raw_ip_address = headers.get("x-couchers-real-ip")
    raw_user_agent = headers.get("user-agent")

    return CouchersHeaders(
        token=token,
        is_api_key=is_api_key,
        # anything that is not a str (e.g. bytes) is dropped
        ip_address=raw_ip_address if isinstance(raw_ip_address, str) else None,
        user_agent=raw_user_agent if isinstance(raw_user_agent, str) else None,
        ui_lang=parse_ui_lang_cookie(headers),
        user_id=parse_user_id_cookie(headers),
        sofa=parse_sofa_cookie(headers),
    )
class BadHeaders(Exception):
    """Raised when the headers of a call cannot be accepted."""
class AbortError(Exception):
    """Carries the message and gRPC status code a call should be aborted with."""

    def __init__(self, msg: str, code: grpc.StatusCode):
        # pass both values on so that `args` stays (msg, code)
        super().__init__(msg, code)
        self.msg, self.code = msg, code
def find_auth_level(pool: DescriptorPool, method: str) -> AuthLevel.ValueType:
    """
    Return the auth level configured on the service that ``method`` belongs to.

    Raises AbortError with UNIMPLEMENTED when the method path is malformed or names a service
    that is not in the pool, and with INTERNAL (via validate_auth_level) when the service has
    no valid auth level.
    """
    # method is of the form "/org.couchers.api.core.API/GetUser"
    try:
        _, service_name, _ = method.split("/")
    except ValueError:
        # the path comes from the client; any other shape cannot name an existing API call
        raise AbortError(NONEXISTENT_API_CALL_ERROR_MESSAGE, grpc.StatusCode.UNIMPLEMENTED) from None

    try:
        service: ServiceDescriptor = pool.FindServiceByName(service_name)  # type: ignore[no-untyped-call]
        service_options = service.GetOptions()
    except KeyError:
        raise AbortError(NONEXISTENT_API_CALL_ERROR_MESSAGE, grpc.StatusCode.UNIMPLEMENTED) from None

    level = service_options.Extensions[annotations_pb2.auth_level]

    validate_auth_level(level)

    return level
def validate_auth_level(auth_level: AuthLevel.ValueType) -> None:
    """
    Raise AbortError with INTERNAL unless ``auth_level`` is one of the five usable auth levels.
    """
    usable_levels = {
        annotations_pb2.AUTH_LEVEL_OPEN,
        annotations_pb2.AUTH_LEVEL_JAILED,
        annotations_pb2.AUTH_LEVEL_SECURE,
        annotations_pb2.AUTH_LEVEL_EDITOR,
        annotations_pb2.AUTH_LEVEL_ADMIN,
    }
    if auth_level in usable_levels:
        return

    # AUTH_LEVEL_UNKNOWN ends up here as well: if unknown auth level, then it wasn't set and
    # something's wrong
    raise AbortError(MISSING_AUTH_LEVEL_ERROR_MESSAGE, grpc.StatusCode.INTERNAL)
def check_permissions(auth_info: UserAuthInfo | None, auth_level: AuthLevel.ValueType) -> None:
    """
    Raise AbortError unless the caller described by ``auth_info`` may use a service of ``auth_level``.

    Without a session only open services are allowed. With a session, admin services need a
    superuser, editor services need an editor, and jailed users are limited to open and jailed
    services.
    """
    if not auth_info:
        # if this isn't an open service, fail
        if auth_level == annotations_pb2.AUTH_LEVEL_OPEN:
            return
        raise AbortError(UNAUTHORIZED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED)

    # a valid user session was found - check permissions
    lacks_superuser = auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not auth_info.is_superuser
    if lacks_superuser:
        raise AbortError(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.PERMISSION_DENIED)

    lacks_editor = auth_level == annotations_pb2.AUTH_LEVEL_EDITOR and not auth_info.is_editor
    if lacks_editor:
        raise AbortError(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.PERMISSION_DENIED)

    # if the user is jailed and this isn't an open or jailed service, fail
    allowed_when_jailed = [
        annotations_pb2.AUTH_LEVEL_OPEN,
        annotations_pb2.AUTH_LEVEL_JAILED,
    ]
    if auth_info.is_jailed and auth_level not in allowed_when_jailed:
        # NOTE(review): permission-denied message paired with the UNAUTHENTICATED code - confirm this is intended
        raise AbortError(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED)
489class MediaInterceptor(grpc.ServerInterceptor):
490 """
491 Extracts an "Authorization: Bearer <hex>" header and calls the
492 is_authorized function. Terminates the call with an HTTP error
493 code if not authorized.
495 Also adds a session to called APIs.
496 """
498 def __init__(self, is_authorized: Callable[[str], bool]):
499 self._is_authorized = is_authorized
501 def intercept_service[T, R](
502 self,
503 continuation: Cont[T, R],
504 handler_call_details: grpc.HandlerCallDetails,
505 ) -> grpc.RpcMethodHandler[T, R]:
506 handler = continuation(handler_call_details)
507 if not handler: 507 ↛ 508line 507 didn't jump to line 508 because the condition on line 507 was never true
508 raise RuntimeError("No handler")
510 prev_func = handler.unary_unary
511 if not prev_func: 511 ↛ 512line 511 didn't jump to line 512 because the condition on line 511 was never true
512 raise RuntimeError(f"No prev_function, {handler}")
514 metadata = dict(handler_call_details.invocation_metadata)
516 token = parse_api_key(metadata)
518 if not token or not self._is_authorized(token): 518 ↛ 519line 518 didn't jump to line 519 because the condition on line 518 was never true
519 return unauthenticated_handler()
521 def function_without_session(request: T, grpc_context: grpc.ServicerContext) -> R:
522 with session_scope() as session:
523 return prev_func(request, make_media_context(grpc_context), session) # type: ignore[call-arg, arg-type]
525 return grpc.unary_unary_rpc_method_handler(
526 function_without_session,
527 request_deserializer=handler.request_deserializer,
528 response_serializer=handler.response_serializer,
529 )
532class OTelInterceptor(grpc.ServerInterceptor):
533 """
534 OpenTelemetry tracing
535 """
537 def __init__(self) -> None:
538 self.tracer = trace.get_tracer(__name__)
540 def intercept_service[T, R](
541 self,
542 continuation: Cont[T, R],
543 handler_call_details: grpc.HandlerCallDetails,
544 ) -> grpc.RpcMethodHandler[T, R]:
545 handler = continuation(handler_call_details)
546 if not handler:
547 raise RuntimeError("No handler")
549 prev_func = handler.unary_unary
550 if not prev_func:
551 raise RuntimeError(f"No prev_function, {handler}")
553 method = handler_call_details.method
555 # method is of the form "/org.couchers.api.core.API/GetUser"
556 _, service_name, method_name = method.split("/")
558 headers = dict(handler_call_details.invocation_metadata)
560 def tracing_function(request: T, context: grpc.ServicerContext) -> R:
561 with self.tracer.start_as_current_span("handler") as rollspan:
562 rollspan.set_attribute("rpc.method_full", method)
563 rollspan.set_attribute("rpc.service", service_name)
564 rollspan.set_attribute("rpc.method", method_name)
566 rollspan.set_attribute("rpc.thread", get_ident())
567 rollspan.set_attribute("rpc.pid", getpid())
569 res = prev_func(request, context)
571 rollspan.set_attribute("web.user_agent", headers.get("user-agent") or "")
572 rollspan.set_attribute("web.ip_address", headers.get("x-couchers-real-ip") or "")
574 return res
576 return grpc.unary_unary_rpc_method_handler(
577 tracing_function,
578 request_deserializer=handler.request_deserializer,
579 response_serializer=handler.response_serializer,
580 )
583class ErrorSanitizationInterceptor(grpc.ServerInterceptor):
584 """
585 If the call resulted in a non-gRPC error, this strips away the error details.
587 It's important to put this first, so that it does not interfere with other interceptors.
588 """
590 def intercept_service[T, R](
591 self,
592 continuation: Cont[T, R],
593 handler_call_details: grpc.HandlerCallDetails,
594 ) -> grpc.RpcMethodHandler[T, R]:
595 handler = continuation(handler_call_details)
596 if not handler: 596 ↛ 597line 596 didn't jump to line 597 because the condition on line 596 was never true
597 raise RuntimeError("No handler")
599 prev_func = handler.unary_unary
600 if not prev_func: 600 ↛ 601line 600 didn't jump to line 601 because the condition on line 600 was never true
601 raise RuntimeError(f"No prev_function, {handler}")
603 def sanitizing_function(req: T, context: grpc.ServicerContext) -> R:
604 try:
605 res = prev_func(req, context)
606 except Exception as e:
607 code = context.code() # type: ignore[attr-defined]
608 # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None
609 if not code:
610 logger.exception(e)
611 logger.info("Probably an unknown error! Sanitizing...")
612 context.abort(grpc.StatusCode.INTERNAL, UNKNOWN_ERROR_MESSAGE)
613 else:
614 logger.warning(f"RPC error: {code} in method {handler_call_details.method}")
615 raise e
616 return res
618 return grpc.unary_unary_rpc_method_handler(
619 sanitizing_function,
620 request_deserializer=handler.request_deserializer,
621 response_serializer=handler.response_serializer,
622 )