Coverage for app/backend/src/couchers/interceptors.py: 87%
334 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-21 09:29 +0000
1import logging
2from collections.abc import Callable, Mapping
3from copy import deepcopy
4from dataclasses import dataclass, field
5from datetime import datetime, timedelta
6from functools import cache
7from os import getpid
8from threading import get_ident
9from time import perf_counter_ns
10from traceback import format_exception
11from typing import Any, NoReturn, cast, overload
12from zoneinfo import ZoneInfo
14import grpc
15import sentry_sdk
16from google.protobuf.descriptor import Descriptor, ServiceDescriptor
17from google.protobuf.descriptor_pool import DescriptorPool
18from google.protobuf.message import Message
19from opentelemetry import trace
20from sqlalchemy import Function, literal_column, select
21from sqlalchemy.dialects.postgresql import insert as pg_insert
22from sqlalchemy.sql import func
24from couchers.config import config
25from couchers.constants import (
26 CALL_CANCELLED_ERROR_MESSAGE,
27 COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE,
28 MISSING_AUTH_LEVEL_ERROR_MESSAGE,
29 NONEXISTENT_API_CALL_ERROR_MESSAGE,
30 PERMISSION_DENIED_ERROR_MESSAGE,
31 UNAUTHORIZED_ERROR_MESSAGE,
32 UNKNOWN_ERROR_MESSAGE,
33)
34from couchers.context import CouchersContext, make_interactive_context, make_media_context
35from couchers.db import session_scope
36from couchers.descriptor_pool import get_descriptor_pool
37from couchers.i18n import LocalizationContext
38from couchers.i18n.locales import to_supported_locale
39from couchers.metrics import (
40 observe_api_call,
41 observe_in_servicer_duration_histogram,
42 observe_in_servicer_perf_histograms,
43 observe_in_servicer_pool_wait_histogram,
44 observe_in_servicer_serde_histogram,
45 observe_in_servicer_setup_errors_counter,
46 observe_in_servicer_setup_histogram,
47)
48from couchers.models import APICall, ClientPlatform, User, UserActivity, UserSession
49from couchers.perf import PerfResult, read_perf, start_perf
50from couchers.proto import annotations_pb2
51from couchers.proto.annotations_pb2 import AuthLevel
52from couchers.utils import (
53 create_lang_cookie,
54 create_session_cookies,
55 generate_sofa_cookie,
56 now,
57 parse_api_key,
58 parse_session_cookie,
59 parse_sofa_cookie,
60 parse_ui_lang_cookie,
61 parse_user_id_cookie,
62)
# module-level logger, named after this module
logger = logging.getLogger(name=__name__)
67@dataclass(frozen=True, slots=True, kw_only=True)
68class UserAuthInfo:
69 """Information about an authenticated user session."""
71 user_id: int
72 is_jailed: bool
73 is_editor: bool
74 is_superuser: bool
75 token_expiry: datetime
76 ui_language_preference: str | None
77 timezone: str | None
78 token: str = field(repr=False)
79 is_api_key: bool
def _binned_now() -> Function[Any]:
    """
    SQL expression for the current time, truncated into 1-hour buckets.

    Uses Postgres ``date_bin`` with a 1-hour stride anchored at ``'2000-01-01'::timestamptz``, so every call
    evaluated within the same bucket yields the same value.
    """
    bucket_width = literal_column("interval '1 hour'")
    bucket_origin = literal_column("'2000-01-01'::timestamptz")
    return func.date_bin(bucket_width, func.now(), bucket_origin)
def _try_get_and_update_user_details(
    token: str | None,
    is_api_key: bool,
    ip_address: str | None,
    user_agent: str | None,
    sofa: str | None,
    client_platform: ClientPlatform | None,
) -> UserAuthInfo | None:
    """
    Look up the valid session for ``token`` and record that it was just used.

    The session must be valid, belong to a visible user, and match ``is_api_key``. When one is found this:
    refreshes the user's last-active time if it is more than five minutes old, stamps the session's last-seen time
    and increments its API call count, and upserts the UserActivity row for the current activity period (see
    ``_binned_now``) keyed by user, ip_address, user_agent and sofa. It then commits.

    Returns a UserAuthInfo snapshot, or None when ``token`` is empty or no matching session exists.
    """
    if not token:
        return None

    session_lookup = (
        select(User, UserSession, User.is_jailed)
        .select_from(UserSession)
        .join(User, User.id == UserSession.user_id)
        .where(User.is_visible)
        .where(UserSession.token == token)
        .where(UserSession.is_valid)
        .where(UserSession.is_api_key == is_api_key)
    )

    with session_scope() as session:
        row = session.execute(session_lookup).one_or_none()
        if not row:
            return None

        user, user_session, is_jailed = row._tuple()

        # throttle writes: only refresh last_active once it is more than five minutes stale
        if now() - user.last_active > timedelta(minutes=5):
            user.last_active = func.now()

        user_session.last_seen = func.now()
        user_session.api_calls += 1

        # upsert so concurrent requests for the same activity tuple don't race to insert and violate the index
        activity_row = pg_insert(UserActivity).values(
            user_id=user.id,
            period=_binned_now(),
            ip_address=ip_address,
            user_agent=user_agent,
            sofa=sofa,
            client_platform=client_platform,
            api_calls=1,
        )
        activity_upsert = activity_row.on_conflict_do_update(
            index_elements=[
                UserActivity.user_id,
                UserActivity.period,
                UserActivity.ip_address,
                UserActivity.user_agent,
                UserActivity.sofa,
            ],
            set_={
                "api_calls": UserActivity.api_calls + 1,
                # keep the already-stored platform when this request didn't declare one
                "client_platform": func.coalesce(
                    activity_row.excluded.client_platform, UserActivity.client_platform
                ),
            },
        )
        session.execute(activity_upsert)

        # build before committing to avoid expire_on_commit reloading these attributes
        auth_info = UserAuthInfo(
            user_id=user.id,
            is_jailed=is_jailed,
            is_editor=user.is_editor,
            is_superuser=user.is_superuser,
            token_expiry=user_session.expiry,
            ui_language_preference=user.ui_language_preference,
            timezone=user.timezone,
            token=token,
            is_api_key=is_api_key,
        )

        session.commit()

    return auth_info
178def abort_handler[T, R](
179 message: str,
180 status_code: grpc.StatusCode,
181) -> grpc.RpcMethodHandler[T, R]:
182 def f(request: Any, context: CouchersContext) -> NoReturn:
183 context.abort(status_code, message)
185 return grpc.unary_unary_rpc_method_handler(f)
188def unauthenticated_handler[T, R](
189 message: str = UNAUTHORIZED_ERROR_MESSAGE,
190 status_code: grpc.StatusCode = grpc.StatusCode.UNAUTHENTICATED,
191) -> grpc.RpcMethodHandler[T, R]:
192 return abort_handler(message, status_code)
@cache
def _descriptor_has_sensitive(descriptor: Descriptor) -> bool:
    """
    Report whether ``descriptor``, or any message type reachable through its fields, has a field marked sensitive.

    Message types are visited breadth-first with a seen-set, so recursive message types terminate. The answer is
    memoized per descriptor.
    """
    queue = [descriptor]
    seen = {descriptor}
    cursor = 0
    while cursor < len(queue):
        current = queue[cursor]
        cursor += 1
        if any(fd.GetOptions().Extensions[annotations_pb2.sensitive] for fd in current.fields):
            return True
        for fd in current.fields:
            child = fd.message_type
            if child is not None and child not in seen:
                seen.add(child)
                queue.append(child)
    return False
213@dataclass(frozen=True, slots=True)
214class _SanitizePlan:
215 fields_to_clear: tuple[str, ...]
216 fields_to_recurse: tuple[tuple[str, bool], ...] # (field name, is_repeated)
@cache
def _sanitize_plan(descriptor: Descriptor) -> _SanitizePlan:
    """
    Work out, once per message type, which fields to clear and which sub-message fields to walk into.

    A field marked sensitive is cleared wholesale (never walked). A message-typed field is walked only if its type
    can contain a sensitive field somewhere.
    """
    sensitive_names: list[str] = []
    walkable: list[tuple[str, bool]] = []
    for fd in descriptor.fields:
        if fd.GetOptions().Extensions[annotations_pb2.sensitive]:
            sensitive_names.append(fd.name)
            continue
        child_type = fd.message_type
        if child_type is not None and _descriptor_has_sensitive(child_type):
            walkable.append((fd.name, fd.is_repeated))
    return _SanitizePlan(tuple(sensitive_names), tuple(walkable))
def _sanitize_message(message: Message) -> None:
    """
    Clear every sensitive field of ``message`` in place, recursing into sub-messages that can contain one.

    Which fields to clear or descend into comes from the cached ``_sanitize_plan`` for the message's type.

    Only sub-messages that are actually present are visited. An unset singular sub-message is returned as a default
    instance, and calling the mutating ClearField on it marks it present in the parent. That would add an empty
    sub-message to the serialized (logged) bytes that the real message never had.
    """
    plan = _sanitize_plan(message.DESCRIPTOR)
    for name in plan.fields_to_clear:
        message.ClearField(name)
    for name, is_repeated in plan.fields_to_recurse:
        if is_repeated:
            container = getattr(message, name)
            # map<K, V> fields are "repeated" too, but iterating a map yields its keys; the messages are its values
            children = container.values() if isinstance(container, Mapping) else container
            for child in children:
                _sanitize_message(child)
        elif message.HasField(name):
            _sanitize_message(getattr(message, name))
@overload
def _sanitized_bytes(proto: Message) -> bytes: ...
@overload
def _sanitized_bytes(proto: None) -> None: ...
def _sanitized_bytes(proto: Message | None) -> bytes | None:
    """
    Serialize ``proto`` with every field marked sensitive removed; ``None`` maps to ``None``.

    Sensitivity is static per message type, so the descriptor analysis is cached. Messages whose type has no
    sensitive field anywhere serialize directly, without a copy or walk. Otherwise a deep copy is sanitized, so the
    caller's message is never modified.
    """
    # explicit None check, matching the overloads, instead of relying on a Message's truthiness
    if proto is None:
        return None

    if not _descriptor_has_sensitive(proto.DESCRIPTOR):
        return proto.SerializeToString()

    scrubbed = deepcopy(proto)
    _sanitize_message(scrubbed)
    return scrubbed.SerializeToString()
def _store_log(
    *,
    method: str,
    status_code: str | None = None,
    duration: float,
    user_id: int | None,
    is_api_key: bool,
    request: Message,
    response: Message | None,
    traceback: str | None = None,
    perf_report: str | None = None,
    perf: PerfResult | None = None,
    client_platform: ClientPlatform | None = None,
    ip_address: str | None,
    user_agent: str | None,
    sofa: str | None,
) -> None:
    """
    Persist one APICall row describing a finished or failed call.

    Request and response are stored with sensitive fields removed (see ``_sanitized_bytes``). A response longer
    than 16 kB is cut to that size and flagged with ``response_truncated``. ``duration`` is in milliseconds. The
    DB perf columns are filled from ``perf`` when given, else left NULL.
    """
    req_bytes = _sanitized_bytes(request)
    res_bytes = _sanitized_bytes(response)

    # truncate before opening the DB session: plain byte slicing doesn't need to hold a connection
    max_response_bytes = 16 * 1024  # 16 kB
    response_truncated = False
    if res_bytes and len(res_bytes) > max_response_bytes:
        res_bytes = res_bytes[:max_response_bytes]
        response_truncated = True

    with session_scope() as session:
        session.add(
            APICall(
                is_api_key=is_api_key,
                method=method,
                status_code=status_code,
                duration=duration,
                user_id=user_id,
                request=req_bytes,
                response=res_bytes,
                response_truncated=response_truncated,
                traceback=traceback,
                perf_report=perf_report,
                db_query_count=perf.db_query_count if perf else None,
                db_write_query_count=perf.db_write_query_count if perf else None,
                db_time_ms=perf.db_time_ms if perf else None,
                cpu_ms=perf.cpu_ms if perf else None,
                client_platform=client_platform,
                ip_address=ip_address,
                user_agent=user_agent,
                sofa=sofa,
            )
        )
    # lazy %-formatting: this runs on every API call, so skip building the string when DEBUG is off.
    # %r matches the repr formatting of the f"{x=}" form used previously.
    logger.debug("user_id=%r, method=%r, duration=%r ms", user_id, method, duration)
319type Cont[T, R] = Callable[[grpc.HandlerCallDetails], grpc.RpcMethodHandler[T, R] | None]
class CouchersMiddlewareInterceptor(grpc.ServerInterceptor):
    """
    Main middleware for the API server. For each incoming call it:

    1. Does auth: looks up the service's auth level from its proto annotations, extracts a session token from the
       cookie or authorization header, and authenticates a user with that.

       The handler's context is built with the user id, API-key flag and token when a session was found. Calls
       that fail check_permissions are answered with that AbortError's status (UNAUTHENTICATED or
       PERMISSION_DENIED) instead of reaching the handler.

    2. Makes sure cookies are in sync (session cookies, language cookie and sofa cookie).

    3. Injects a session to get a database transaction; the handler receives it as an extra third argument.

    4. Measures and logs the time it takes to service each incoming call: an APICall row via _store_log, plus
       setup, pool-wait, serde, duration and perf metrics.
    """

    def __init__(self) -> None:
        # descriptor pool that find_auth_level searches for each service's auth_level annotation
        self._pool = get_descriptor_pool()

    def intercept_service[T = Message, R = Message](
        self,
        continuation: Cont[T, R],
        handler_call_details: grpc.HandlerCallDetails,
    ) -> grpc.RpcMethodHandler[T, R]:
        """
        Authenticate and authorize the call, then return a handler that wraps the real one.

        Failures during setup (everything up to building the localization context) are answered by an aborting
        handler rather than raised:
        - AbortErrors from the auth-level lookup and the permission check keep their message and code.
        - BadHeaders from parse_headers becomes UNAUTHENTICATED.
        - Any other exception is counted, reported to Sentry, and becomes INTERNAL.
        """
        # the end-to-end duration logged by the wrapped handler is measured from here
        start = perf_counter_ns()

        method = handler_call_details.method

        # accounting for the auth/setup phase; the handler re-arms its own below
        start_perf()

        try:
            try:
                auth_level = find_auth_level(self._pool, method)
            except AbortError as ae:
                return abort_handler(ae.msg, ae.code)

            try:
                headers = parse_headers(dict(handler_call_details.invocation_metadata))
            except BadHeaders:
                return unauthenticated_handler(COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE)

            # if this is not present in prod, it's a Big Bug in config
            assert config.DEV or headers.ip_address is not None

            # None when there is no token or no matching valid session
            auth_info = _try_get_and_update_user_details(
                headers.token,
                headers.is_api_key,
                headers.ip_address,
                headers.user_agent,
                headers.sofa,
                headers.client_platform,
            )

            try:
                check_permissions(auth_info, auth_level)
            except AbortError as ae:
                return unauthenticated_handler(ae.msg, ae.code)

            if not (handler := continuation(handler_call_details)):
                raise RuntimeError(f"No handler in '{method}'")

            # only unary-unary methods are supported by this wrapper
            if not (prev_function := handler.unary_unary):
                raise RuntimeError(f"No prev_function in '{method}', {handler}")

            # reuse the caller's sofa cookie, or mint a new one to be sent back with the response
            if headers.sofa:
                sofa = headers.sofa
                new_sofa_cookie = None
            else:
                sofa, new_sofa_cookie = generate_sofa_cookie()

            # signed-in users use only their saved preference (the ui-lang cookie is ignored for them);
            # anonymous callers use the cookie
            locale = to_supported_locale((auth_info.ui_language_preference if auth_info else headers.ui_lang) or "")
            loc_context = LocalizationContext(
                locale=locale,
                # the user's timezone if signed in and set, otherwise UTC
                timezone=ZoneInfo((auth_info and auth_info.timezone) or "Etc/UTC"),
            )

            observe_in_servicer_setup_histogram(method, read_perf())
        except Exception as e:
            # unexpected setup failure: report it and hide the details from the client
            observe_in_servicer_setup_errors_counter(method, type(e).__name__)
            sentry_sdk.set_tag("context", "servicer_setup")
            sentry_sdk.set_tag("method", method)
            sentry_sdk.capture_exception(e)
            return abort_handler(UNKNOWN_ERROR_MESSAGE, grpc.StatusCode.INTERNAL)

        def function_without_couchers_stuff(req: Message, grpc_context: grpc.ServicerContext) -> Message | None:
            """
            Run the real handler inside a fresh DB session, then log and measure the outcome.

            On success: the APICall row and metrics are recorded, cookies are synced and sent, and the response is
            returned.

            On exception: the row and metrics are recorded with the traceback and the gRPC status code (if the
            handler set one). Errors without a status code are also sent to Sentry. The exception is then re-raised.
            """
            couchers_context = make_interactive_context(
                grpc_context=grpc_context,
                user_id=auth_info.user_id if auth_info else None,
                is_api_key=auth_info.is_api_key if auth_info else False,
                token=auth_info.token if auth_info else None,
                localization=loc_context,
                sofa=sofa,
            )

            with session_scope() as session:
                # force the checkout now so its wait is timed here rather than hiding in the handler's first query
                pool_wait_start = perf_counter_ns()
                session.connection()
                observe_in_servicer_pool_wait_histogram(method, (perf_counter_ns() - pool_wait_start) / 1e9)
                start_perf()
                try:
                    _res = prev_function(req, couchers_context, session)  # type: ignore[call-arg, arg-type]
                    res = cast(Message, _res)
                    # flush so pending ORM writes execute (and are counted) before we snapshot; a handler that only
                    # session.add(...)s and returns would otherwise flush at commit, after read_perf()
                    session.flush()
                    perf = read_perf()
                    finished = perf_counter_ns()
                    duration = (finished - start) / 1e6  # ms
                    _store_log(
                        method=method,
                        duration=duration,
                        user_id=couchers_context._user_id,
                        is_api_key=cast(bool, couchers_context._is_api_key),
                        request=req,
                        response=res,
                        perf=perf,
                        client_platform=headers.client_platform,
                        ip_address=headers.ip_address,
                        user_agent=headers.user_agent,
                        sofa=sofa,
                    )
                    # the histogram takes seconds, hence / 1000 on the millisecond duration
                    observe_in_servicer_duration_histogram(method, couchers_context._user_id, "", "", duration / 1000)
                    observe_api_call(method, headers.client_platform)
                    observe_in_servicer_perf_histograms(method, perf)
                except Exception as e:
                    perf = read_perf()
                    finished = perf_counter_ns()
                    duration = (finished - start) / 1e6  # ms

                    # name of the status code the handler set (e.g. via abort()), or None if it set none
                    if couchers_context._grpc_context:
                        context_code = couchers_context._grpc_context.code()  # type: ignore[attr-defined]
                        code = getattr(context_code, "name", None)
                    else:
                        code = None

                    traceback = "".join(format_exception(type(e), e, e.__traceback__))
                    _store_log(
                        method=method,
                        status_code=code,
                        duration=duration,
                        user_id=couchers_context._user_id,
                        is_api_key=cast(bool, couchers_context._is_api_key),
                        request=req,
                        response=None,
                        traceback=traceback,
                        perf=perf,
                        client_platform=headers.client_platform,
                        ip_address=headers.ip_address,
                        user_agent=headers.user_agent,
                        sofa=sofa,
                    )
                    observe_in_servicer_duration_histogram(
                        method, couchers_context._user_id, code or "", type(e).__name__, duration / 1000
                    )
                    observe_api_call(method, headers.client_platform)
                    observe_in_servicer_perf_histograms(method, perf)

                    # no status code means an unexpected error rather than a deliberate abort: report it
                    if not code:
                        sentry_sdk.set_tag("context", "servicer")
                        sentry_sdk.set_tag("method", method)
                        sentry_sdk.set_tag("user_agent", headers.user_agent)
                        sentry_sdk.set_tag("ui_lang", loc_context.locale)
                        sentry_sdk.set_user(
                            {
                                "id": couchers_context._user_id,
                                "ip_address": headers.ip_address,
                                # only a prefix of the sofa identifier is sent to Sentry
                                "sofa": sofa[:12],
                            }
                        )
                        sentry_sdk.capture_exception(e)

                    raise e

            if auth_info and not auth_info.is_api_key:
                # check the two cookies are in sync & that language preference cookie is correct
                if headers.user_id != str(auth_info.user_id):
                    couchers_context.set_cookies(
                        create_session_cookies(auth_info.token, auth_info.user_id, auth_info.token_expiry)
                    )
                if auth_info.ui_language_preference and auth_info.ui_language_preference != headers.ui_lang:
                    couchers_context.set_cookies(create_lang_cookie(auth_info.ui_language_preference))

            if new_sofa_cookie:
                couchers_context.set_cookies([new_sofa_cookie])

            # the client went away while the handler ran: end the RPC instead of sending cookies and a response
            if not grpc_context.is_active():
                grpc_context.abort(grpc.StatusCode.INTERNAL, CALL_CANCELLED_ERROR_MESSAGE)

            couchers_context._send_cookies()

            return res

        def timed_serde[A, B](fn: Callable[[A], B], direction: str) -> Callable[[A], B]:
            """Wrap a (de)serializer so each call's wall time is recorded in the serde histogram under ``direction``."""

            def wrapped(arg: A) -> B:
                t0 = perf_counter_ns()
                result = fn(arg)
                observe_in_servicer_serde_histogram(method, direction, (perf_counter_ns() - t0) / 1e9)
                return result

            return wrapped

        # always set for our generated-proto methods, but grpc types them as optional
        assert handler.request_deserializer is not None and handler.response_serializer is not None
        return grpc.unary_unary_rpc_method_handler(
            function_without_couchers_stuff,
            request_deserializer=timed_serde(handler.request_deserializer, "deserialize"),
            response_serializer=timed_serde(handler.response_serializer, "serialize"),
        )
@dataclass(kw_only=True, frozen=True, slots=True)
class CouchersHeaders:
    """The pieces of call metadata the API middleware uses, as extracted by parse_headers()."""

    # session token (cookie) or API key (authorization header); excluded from repr to keep it out of logs
    token: str | None = field(repr=False)
    # True when the token came from the authorization header
    is_api_key: bool
    # from the x-couchers-real-ip header
    ip_address: str | None
    # from the user-agent header
    user_agent: str | None
    # from x-couchers-client-platform, only when it names a ClientPlatform member
    client_platform: ClientPlatform | None
    # values returned by parse_ui_lang_cookie / parse_user_id_cookie / parse_sofa_cookie
    ui_lang: str | None
    user_id: str | None
    sofa: str | None
def parse_headers(headers: Mapping[str, str | bytes]) -> CouchersHeaders:
    """
    Extract auth and client information from a call's metadata.

    The token comes from the ``cookie`` header (a session cookie) or the ``authorization`` header (an API key).

    Raises BadHeaders if both are present. Header values that arrive as bytes are treated as absent for ip_address,
    user_agent and client_platform.
    """
    has_cookie = "cookie" in headers
    has_authorization = "authorization" in headers

    if has_cookie and has_authorization:
        # for security reasons, only one of "cookie" or "authorization" can be present
        raise BadHeaders("Both cookies and authorization are present in headers")

    if has_cookie:
        token, is_api_key = parse_session_cookie(headers), False
    elif has_authorization:
        token, is_api_key = parse_api_key(headers), True
    else:
        # anonymous call
        token, is_api_key = None, False

    def _text(value: str | bytes | None) -> str | None:
        return value if isinstance(value, str) else None

    ip_address = _text(headers.get("x-couchers-real-ip"))
    user_agent = _text(headers.get("user-agent"))

    # the client (web app or native app) declares its platform via this header
    platform_name = _text(headers.get("x-couchers-client-platform"))
    if platform_name is not None and platform_name in ClientPlatform.__members__:
        client_platform = ClientPlatform[platform_name]
    else:
        client_platform = None

    ui_lang = parse_ui_lang_cookie(headers)
    user_id = parse_user_id_cookie(headers)
    sofa = parse_sofa_cookie(headers)

    return CouchersHeaders(
        token=token,
        is_api_key=is_api_key,
        ip_address=ip_address,
        user_agent=user_agent,
        client_platform=client_platform,
        ui_lang=ui_lang,
        user_id=user_id,
        sofa=sofa,
    )
class BadHeaders(Exception):
    """Malformed auth metadata; parse_headers raises it when both a cookie and an authorization header are sent."""
class AbortError(Exception):
    """
    A call should be rejected with gRPC status ``code`` and client-facing message ``msg``.

    Raised by find_auth_level, validate_auth_level and check_permissions; the interceptor turns it into an
    aborting handler.
    """

    def __init__(self, msg: str, code: grpc.StatusCode):
        # forward both to Exception so args, tracebacks and pickling carry them
        super().__init__(msg, code)
        self.msg = msg
        self.code = code

    def __str__(self) -> str:
        # show the message itself rather than the (msg, code) args tuple
        return self.msg
def find_auth_level(pool: DescriptorPool, method: str) -> AuthLevel.ValueType:
    """
    Return the ``auth_level`` annotation of the service that ``method`` belongs to.

    Raises AbortError with UNIMPLEMENTED when either:
    - ``method`` doesn't split into the three "/"-separated parts of a "/<service>/<method>" path, or
    - the named service isn't in ``pool``.

    Raises AbortError with INTERNAL (via validate_auth_level) when the service has no valid auth level.
    """
    # method is of the form "/org.couchers.api.core.API/GetUser"
    path_parts = method.split("/")
    if len(path_parts) != 3:
        # a malformed path can't name one of our services; reject it like an unknown service instead of letting a
        # ValueError escape (which the interceptor would report as an INTERNAL setup failure)
        raise AbortError(NONEXISTENT_API_CALL_ERROR_MESSAGE, grpc.StatusCode.UNIMPLEMENTED)
    service_name = path_parts[1]

    try:
        service: ServiceDescriptor = pool.FindServiceByName(service_name)  # type: ignore[no-untyped-call]
    except KeyError:
        raise AbortError(NONEXISTENT_API_CALL_ERROR_MESSAGE, grpc.StatusCode.UNIMPLEMENTED) from None

    level = service.GetOptions().Extensions[annotations_pb2.auth_level]

    validate_auth_level(level)

    return level
def validate_auth_level(auth_level: AuthLevel.ValueType) -> None:
    """
    Raise AbortError (INTERNAL, missing-auth-level message) unless ``auth_level`` is a concrete, known level.

    AUTH_LEVEL_UNKNOWN means the service never set a level, which is a server-side mistake. It is rejected along
    with any value outside the recognized set.
    """
    recognized = {
        annotations_pb2.AUTH_LEVEL_OPEN,
        annotations_pb2.AUTH_LEVEL_JAILED,
        annotations_pb2.AUTH_LEVEL_SECURE,
        annotations_pb2.AUTH_LEVEL_EDITOR,
        annotations_pb2.AUTH_LEVEL_ADMIN,
    }
    if auth_level == annotations_pb2.AUTH_LEVEL_UNKNOWN or auth_level not in recognized:
        raise AbortError(MISSING_AUTH_LEVEL_ERROR_MESSAGE, grpc.StatusCode.INTERNAL)
def check_permissions(auth_info: UserAuthInfo | None, auth_level: AuthLevel.ValueType) -> None:
    """
    Raise AbortError unless the caller described by ``auth_info`` may use a service at ``auth_level``.

    - No session: only AUTH_LEVEL_OPEN is allowed (UNAUTHENTICATED otherwise).
    - AUTH_LEVEL_ADMIN needs a superuser and AUTH_LEVEL_EDITOR an editor (PERMISSION_DENIED otherwise).
    - Jailed users may only use OPEN or JAILED services. Otherwise they get the permission-denied message with
      the UNAUTHENTICATED code.
    """
    if not auth_info:
        if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
            raise AbortError(UNAUTHORIZED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED)
        return

    lacks_required_role = (auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not auth_info.is_superuser) or (
        auth_level == annotations_pb2.AUTH_LEVEL_EDITOR and not auth_info.is_editor
    )
    if lacks_required_role:
        raise AbortError(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.PERMISSION_DENIED)

    jail_allowed = (annotations_pb2.AUTH_LEVEL_OPEN, annotations_pb2.AUTH_LEVEL_JAILED)
    if auth_info.is_jailed and auth_level not in jail_allowed:
        raise AbortError(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED)
648class MediaInterceptor(grpc.ServerInterceptor):
649 """
650 Extracts an "Authorization: Bearer <hex>" header and calls the
651 is_authorized function. Terminates the call with an HTTP error
652 code if not authorized.
654 Also adds a session to called APIs.
655 """
657 def __init__(self, is_authorized: Callable[[str], bool]):
658 self._is_authorized = is_authorized
660 def intercept_service[T, R](
661 self,
662 continuation: Cont[T, R],
663 handler_call_details: grpc.HandlerCallDetails,
664 ) -> grpc.RpcMethodHandler[T, R]:
665 handler = continuation(handler_call_details)
666 if not handler: 666 ↛ 667line 666 didn't jump to line 667 because the condition on line 666 was never true
667 raise RuntimeError("No handler")
669 prev_func = handler.unary_unary
670 if not prev_func: 670 ↛ 671line 670 didn't jump to line 671 because the condition on line 670 was never true
671 raise RuntimeError(f"No prev_function, {handler}")
673 metadata = dict(handler_call_details.invocation_metadata)
675 token = parse_api_key(metadata)
677 if not token or not self._is_authorized(token): 677 ↛ 678line 677 didn't jump to line 678 because the condition on line 677 was never true
678 return unauthenticated_handler()
680 def function_without_session(request: T, grpc_context: grpc.ServicerContext) -> R:
681 with session_scope() as session:
682 return prev_func(request, make_media_context(grpc_context), session) # type: ignore[call-arg, arg-type]
684 return grpc.unary_unary_rpc_method_handler(
685 function_without_session,
686 request_deserializer=handler.request_deserializer,
687 response_serializer=handler.response_serializer,
688 )
691class OTelInterceptor(grpc.ServerInterceptor):
692 """
693 OpenTelemetry tracing
694 """
696 def __init__(self) -> None:
697 self.tracer = trace.get_tracer(__name__)
699 def intercept_service[T, R](
700 self,
701 continuation: Cont[T, R],
702 handler_call_details: grpc.HandlerCallDetails,
703 ) -> grpc.RpcMethodHandler[T, R]:
704 handler = continuation(handler_call_details)
705 if not handler:
706 raise RuntimeError("No handler")
708 prev_func = handler.unary_unary
709 if not prev_func:
710 raise RuntimeError(f"No prev_function, {handler}")
712 method = handler_call_details.method
714 # method is of the form "/org.couchers.api.core.API/GetUser"
715 _, service_name, method_name = method.split("/")
717 headers = dict(handler_call_details.invocation_metadata)
719 def tracing_function(request: T, context: grpc.ServicerContext) -> R:
720 with self.tracer.start_as_current_span("handler") as rollspan:
721 rollspan.set_attribute("rpc.method_full", method)
722 rollspan.set_attribute("rpc.service", service_name)
723 rollspan.set_attribute("rpc.method", method_name)
725 rollspan.set_attribute("rpc.thread", get_ident())
726 rollspan.set_attribute("rpc.pid", getpid())
728 res = prev_func(request, context)
730 rollspan.set_attribute("web.user_agent", headers.get("user-agent") or "")
731 rollspan.set_attribute("web.ip_address", headers.get("x-couchers-real-ip") or "")
733 return res
735 return grpc.unary_unary_rpc_method_handler(
736 tracing_function,
737 request_deserializer=handler.request_deserializer,
738 response_serializer=handler.response_serializer,
739 )
742class ErrorSanitizationInterceptor(grpc.ServerInterceptor):
743 """
744 If the call resulted in a non-gRPC error, this strips away the error details.
746 It's important to put this first, so that it does not interfere with other interceptors.
747 """
749 def intercept_service[T, R](
750 self,
751 continuation: Cont[T, R],
752 handler_call_details: grpc.HandlerCallDetails,
753 ) -> grpc.RpcMethodHandler[T, R]:
754 handler = continuation(handler_call_details)
755 if not handler: 755 ↛ 756line 755 didn't jump to line 756 because the condition on line 755 was never true
756 raise RuntimeError("No handler")
758 prev_func = handler.unary_unary
759 if not prev_func: 759 ↛ 760line 759 didn't jump to line 760 because the condition on line 759 was never true
760 raise RuntimeError(f"No prev_function, {handler}")
762 def sanitizing_function(req: T, context: grpc.ServicerContext) -> R:
763 try:
764 res = prev_func(req, context)
765 except Exception as e:
766 code = context.code() # type: ignore[attr-defined]
767 # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None
768 if not code:
769 logger.exception(e)
770 logger.info("Probably an unknown error! Sanitizing...")
771 context.abort(grpc.StatusCode.INTERNAL, UNKNOWN_ERROR_MESSAGE)
772 else:
773 logger.warning(f"RPC error: {code} in method {handler_call_details.method}")
774 raise e
775 return res
777 return grpc.unary_unary_rpc_method_handler(
778 sanitizing_function,
779 request_deserializer=handler.request_deserializer,
780 response_serializer=handler.response_serializer,
781 )