Coverage for src / couchers / interceptors.py: 85%
255 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-13 12:05 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-13 12:05 +0000
1import logging
2from collections.abc import Callable, Mapping
3from copy import deepcopy
4from dataclasses import dataclass, field
5from datetime import datetime, timedelta
6from os import getpid
7from threading import get_ident
8from time import perf_counter_ns
9from traceback import format_exception
10from typing import Any, NoReturn, cast, overload
12import grpc
13import sentry_sdk
14from google.protobuf.descriptor import ServiceDescriptor
15from google.protobuf.descriptor_pool import DescriptorPool
16from google.protobuf.message import Message
17from opentelemetry import trace
18from sqlalchemy import Function, select
19from sqlalchemy.sql import and_, func
21from couchers.constants import (
22 CALL_CANCELLED_ERROR_MESSAGE,
23 COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE,
24 MISSING_AUTH_LEVEL_ERROR_MESSAGE,
25 NONEXISTENT_API_CALL_ERROR_MESSAGE,
26 PERMISSION_DENIED_ERROR_MESSAGE,
27 UNAUTHORIZED_ERROR_MESSAGE,
28 UNKNOWN_ERROR_MESSAGE,
29)
30from couchers.context import CouchersContext, make_interactive_context, make_media_context
31from couchers.db import session_scope
32from couchers.descriptor_pool import get_descriptor_pool
33from couchers.metrics import observe_in_servicer_duration_histogram
34from couchers.models import APICall, User, UserActivity, UserSession
35from couchers.proto import annotations_pb2
36from couchers.proto.annotations_pb2 import AuthLevel
37from couchers.utils import (
38 create_lang_cookie,
39 create_session_cookies,
40 now,
41 parse_api_key,
42 parse_session_cookie,
43 parse_ui_lang_cookie,
44 parse_user_id_cookie,
45)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
50@dataclass(frozen=True, slots=True, kw_only=True)
51class UserAuthInfo:
52 """Information about an authenticated user session."""
54 user_id: int
55 is_jailed: bool
56 is_editor: bool
57 is_superuser: bool
58 token_expiry: datetime
59 ui_language_preference: str | None
60 token: str = field(repr=False)
61 is_api_key: bool
def _binned_now() -> Function[Any]:
    """Builds the SQL expression ``date_bin('1 hour', now(), '2000-01-01')``, i.e. the current hour bucket."""
    bin_width = "1 hour"
    bin_origin = "2000-01-01"
    return func.date_bin(bin_width, func.now(), bin_origin)
def _try_get_and_update_user_details(
    token: str | None, is_api_key: bool, ip_address: str | None, user_agent: str | None
) -> UserAuthInfo | None:
    """
    Looks up the session and user that belong to this token and records the activity.

    When a session is found, this refreshes the user's last active time (if it is more than five minutes old),
    updates the session's last seen time and API call count, and bumps (or creates) the activity row for the
    current hour, IP address and user agent.

    Returns UserAuthInfo if a valid session of the requested kind exists for a visible user, None otherwise.
    """
    if not token:
        return None

    # the activity row we want is the one for the current hour and this exact client
    same_activity_bucket = and_(
        UserActivity.user_id == User.id,
        UserActivity.period == _binned_now(),
        UserActivity.ip_address == ip_address,
        UserActivity.user_agent == user_agent,
    )
    lookup = (
        select(User, UserSession, UserActivity)
        .select_from(UserSession)
        .join(User, User.id == UserSession.user_id)
        .outerjoin(UserActivity, same_activity_bucket)
        .where(User.is_visible)
        .where(UserSession.token == token)
        .where(UserSession.is_valid)
        .where(UserSession.is_api_key == is_api_key)
    )

    with session_scope() as session:
        row = session.execute(lookup).one_or_none()
        if not row:
            return None

        user, user_session, user_activity = row._tuple()

        # avoid writing to the user row on every single call
        stale_after = timedelta(minutes=5)
        if now() - user.last_active > stale_after:
            user.last_active = func.now()

        # the session itself is touched on every call
        user_session.last_seen = func.now()
        user_session.api_calls += 1

        if not user_activity:
            # first call in this hour from this client
            session.add(
                UserActivity(
                    user_id=user.id,
                    period=_binned_now(),
                    ip_address=ip_address,
                    user_agent=user_agent,
                    api_calls=1,
                )
            )
        else:
            user_activity.api_calls += 1

        session.commit()

        return UserAuthInfo(
            user_id=user.id,
            is_jailed=user.is_jailed,
            is_editor=user.is_editor,
            is_superuser=user.is_superuser,
            token_expiry=user_session.expiry,
            ui_language_preference=user.ui_language_preference,
            token=token,
            is_api_key=is_api_key,
        )
141def abort_handler[T, R](
142 message: str,
143 status_code: grpc.StatusCode,
144) -> grpc.RpcMethodHandler[T, R]:
145 def f(request: Any, context: CouchersContext) -> NoReturn:
146 context.abort(status_code, message)
148 return grpc.unary_unary_rpc_method_handler(f)
151def unauthenticated_handler[T, R](
152 message: str = UNAUTHORIZED_ERROR_MESSAGE,
153 status_code: grpc.StatusCode = grpc.StatusCode.UNAUTHENTICATED,
154) -> grpc.RpcMethodHandler[T, R]:
155 return abort_handler(message, status_code)
@overload
def _sanitized_bytes(proto: Message) -> bytes: ...
@overload
def _sanitized_bytes(proto: None) -> None: ...
def _sanitized_bytes(proto: Message | None) -> bytes | None:
    """
    Remove fields marked sensitive and return serialized bytes

    The message is deep-copied first, so the object passed in is never modified. Fields annotated with the
    `sensitive` option are cleared at every nesting level, including inside repeated and map fields.

    Returns None if (and only if) no message was given.
    """
    if proto is None:
        return None

    new_proto = deepcopy(proto)

    def _sanitize_message(message: Message) -> None:
        for name, descriptor in message.DESCRIPTOR.fields_by_name.items():
            if descriptor.GetOptions().Extensions[annotations_pb2.sensitive]:
                message.ClearField(name)
                # the field is empty now, so there is nothing left to descend into
                continue

            field_type = descriptor.message_type
            if not field_type:
                # scalar field that isn't sensitive
                continue

            if field_type.GetOptions().map_entry:
                # iterating a map field yields its keys, so walk the values, and only if they are messages
                if field_type.fields_by_name["value"].message_type:
                    for msg in getattr(message, name).values():
                        _sanitize_message(msg)
            elif descriptor.is_repeated:
                for msg in getattr(message, name):
                    _sanitize_message(msg)
            elif message.HasField(name):
                # only descend into submessages that are actually set: unset ones have nothing to clear, and
                # skipping them keeps the recursion finite for message types that contain themselves
                _sanitize_message(getattr(message, name))

    _sanitize_message(new_proto)

    return new_proto.SerializeToString()
def _store_log(
    *,
    method: str,
    status_code: str | None = None,
    duration: float,
    user_id: int | None,
    is_api_key: bool,
    request: Message,
    response: Message | None,
    traceback: str | None = None,
    perf_report: str | None = None,
    ip_address: str | None,
    user_agent: str | None,
) -> None:
    """
    Stores one APICall row describing a finished call, then emits a debug log line.

    Request and response are stored as sanitized, serialized bytes. A response longer than 16 kB is cut to that
    length and the row is flagged as truncated.
    """
    request_bytes = _sanitized_bytes(request)
    response_bytes = _sanitized_bytes(response)

    with session_scope() as session:
        max_response_bytes = 16 * 1024  # 16 kB
        response_truncated = response_bytes is not None and len(response_bytes) > max_response_bytes
        if response_truncated:
            response_bytes = response_bytes[:max_response_bytes]

        api_call = APICall(
            is_api_key=is_api_key,
            method=method,
            status_code=status_code,
            duration=duration,
            user_id=user_id,
            request=request_bytes,
            response=response_bytes,
            response_truncated=response_truncated,
            traceback=traceback,
            perf_report=perf_report,
            ip_address=ip_address,
            user_agent=user_agent,
        )
        session.add(api_call)

    logger.debug(f"{user_id=}, {method=}, {duration=} ms")
# The continuation passed to an interceptor: takes the call details and returns a method handler, or None.
type Cont[T, R] = Callable[[grpc.HandlerCallDetails], grpc.RpcMethodHandler[T, R] | None]
class CouchersMiddlewareInterceptor(grpc.ServerInterceptor):
    """
    1. Does auth: extracts a session token from a cookie or an authorization header, and authenticates a user
       with that.

       Passes the user ID, token, API key flag and language preference into the Couchers context if
       authenticated, otherwise terminates the call with an error handler unless the service is open.

    2. Makes sure cookies are in sync.

    3. Injects a session to get a database transaction.

    4. Measures and logs the time it takes to service each incoming call.
    """

    def __init__(self) -> None:
        # descriptor pool used to look up the auth level annotation of each service
        self._pool = get_descriptor_pool()

    def intercept_service[T = Message, R = Message](
        self,
        continuation: Cont[T, R],
        handler_call_details: grpc.HandlerCallDetails,
    ) -> grpc.RpcMethodHandler[T, R]:
        """
        Authenticates and authorizes the call, then wraps the next unary-unary handler.

        Returns an aborting handler if the method is unknown, the headers are inconsistent, or the permission
        check fails. Raises RuntimeError if the continuation yields no handler or no unary-unary function.
        """
        # taken before auth, so the logged duration covers everything this interceptor does
        start = perf_counter_ns()

        method = handler_call_details.method

        try:
            auth_level = find_auth_level(self._pool, method)
        except AbortError as ae:
            return abort_handler(ae.msg, ae.code)

        try:
            headers = parse_headers(dict(handler_call_details.invocation_metadata))
        except BadHeaders:
            return unauthenticated_handler(COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE)

        # None if there is no token or no matching valid session
        auth_info = _try_get_and_update_user_details(
            headers.token, headers.is_api_key, headers.ip_address, headers.user_agent
        )

        try:
            check_permissions(auth_info, auth_level)
        except AbortError as ae:
            # the message and code are taken from the error, so this is not necessarily UNAUTHENTICATED
            return unauthenticated_handler(ae.msg, ae.code)

        if not (handler := continuation(handler_call_details)):
            raise RuntimeError(f"No handler in '{method}'")

        if not (prev_function := handler.unary_unary):
            raise RuntimeError(f"No prev_function in '{method}', {handler}")

        def function_without_couchers_stuff(req: Message, grpc_context: grpc.ServicerContext) -> Message | None:
            # the wrapped servicer function takes a Couchers context and a session instead of the raw gRPC context
            couchers_context = make_interactive_context(
                grpc_context=grpc_context,
                user_id=auth_info.user_id if auth_info else None,
                is_api_key=auth_info.is_api_key if auth_info else False,
                token=auth_info.token if auth_info else None,
                ui_language_preference=auth_info.ui_language_preference if auth_info else None,
            )

            with session_scope() as session:
                try:
                    _res = prev_function(req, couchers_context, session)  # type: ignore[call-arg, arg-type]
                    res = cast(Message, _res)
                    finished = perf_counter_ns()
                    duration = (finished - start) / 1e6  # ms
                    _store_log(
                        method=method,
                        duration=duration,
                        user_id=couchers_context._user_id,
                        is_api_key=cast(bool, couchers_context._is_api_key),
                        request=req,
                        response=res,
                        ip_address=headers.ip_address,
                        user_agent=headers.user_agent,
                    )
                    # duration is in ms, the value passed to the histogram is duration / 1000
                    observe_in_servicer_duration_histogram(method, couchers_context._user_id, "", "", duration / 1000)
                except Exception as e:
                    finished = perf_counter_ns()
                    duration = (finished - start) / 1e6  # ms

                    # name of the status code set on the gRPC context, or None if there is none
                    if couchers_context._grpc_context:
                        context_code = couchers_context._grpc_context.code()  # type: ignore[attr-defined]
                        code = getattr(context_code, "name", None)
                    else:
                        code = None

                    traceback = "".join(format_exception(type(e), e, e.__traceback__))
                    _store_log(
                        method=method,
                        status_code=code,
                        duration=duration,
                        user_id=couchers_context._user_id,
                        is_api_key=cast(bool, couchers_context._is_api_key),
                        request=req,
                        response=None,
                        traceback=traceback,
                        ip_address=headers.ip_address,
                        user_agent=headers.user_agent,
                    )
                    observe_in_servicer_duration_histogram(
                        method, couchers_context._user_id, code or "", type(e).__name__, duration / 1000
                    )

                    # only errors without a status code are reported to Sentry
                    if not code:
                        sentry_sdk.set_tag("context", "servicer")
                        sentry_sdk.set_tag("method", method)
                        sentry_sdk.capture_exception(e)

                    raise e

                if auth_info and not auth_info.is_api_key:
                    # check the two cookies are in sync & that language preference cookie is correct
                    if headers.user_id != str(auth_info.user_id):
                        couchers_context.set_cookies(
                            create_session_cookies(auth_info.token, auth_info.user_id, auth_info.token_expiry)
                        )
                    if auth_info.ui_language_preference and auth_info.ui_language_preference != headers.ui_lang:
                        couchers_context.set_cookies(create_lang_cookie(auth_info.ui_language_preference))

                # NOTE(review): presumably here so a cancelled call fails before the session scope exits - confirm
                if not grpc_context.is_active():
                    grpc_context.abort(grpc.StatusCode.INTERNAL, CALL_CANCELLED_ERROR_MESSAGE)

                couchers_context._send_cookies()

                return res

        # keep the (de)serializers of the handler that is being wrapped
        return grpc.unary_unary_rpc_method_handler(
            function_without_couchers_stuff,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
368@dataclass(frozen=True, slots=True, kw_only=True)
369class CouchersHeaders:
370 token: str | None = field(repr=False)
371 is_api_key: bool
372 ip_address: str | None
373 user_agent: str | None
374 ui_lang: str | None
375 user_id: str | None
def parse_headers(headers: Mapping[str, str | bytes]) -> CouchersHeaders:
    """
    Extracts the credentials and client details from the request headers.

    Raises BadHeaders if both a `cookie` and an `authorization` header are present.
    """
    has_cookie = "cookie" in headers
    has_authorization = "authorization" in headers

    if has_cookie and has_authorization:
        # for security reasons, only one of "cookie" or "authorization" can be present
        raise BadHeaders("Both cookies and authorization are present in headers")

    if has_cookie:
        # the session token is passed in cookies, i.e., in the `cookie` header
        token = parse_session_cookie(headers)
        is_api_key = False
    elif has_authorization:
        # the session token is passed in the `authorization` header
        token = parse_api_key(headers)
        is_api_key = True
    else:
        # no session found
        token = None
        is_api_key = False

    raw_ip_address = headers.get("x-couchers-real-ip")
    raw_user_agent = headers.get("user-agent")

    # header values may be bytes; only strings are kept
    return CouchersHeaders(
        token=token,
        is_api_key=is_api_key,
        ip_address=raw_ip_address if isinstance(raw_ip_address, str) else None,
        user_agent=raw_user_agent if isinstance(raw_user_agent, str) else None,
        ui_lang=parse_ui_lang_cookie(headers),
        user_id=parse_user_id_cookie(headers),
    )
class BadHeaders(Exception):
    """Raised when the request headers cannot be accepted, e.g. cookies and authorization are both present."""
class AbortError(Exception):
    """Carries the message and gRPC status code a call should be aborted with."""

    def __init__(self, msg: str, code: grpc.StatusCode):
        self.msg, self.code = msg, code
def find_auth_level(pool: DescriptorPool, method: str) -> AuthLevel.ValueType:
    """
    Looks up the auth level that the service owning `method` is annotated with.

    Raises AbortError with UNIMPLEMENTED if the method path is malformed or its service is not in the pool, and
    with INTERNAL if the service has no valid auth level.
    """
    # method is of the form "/org.couchers.api.core.API/GetUser"
    try:
        _, service_name, _ = method.split("/")
    except ValueError:
        # a path with the wrong number of segments cannot name a known call, so treat it like one
        raise AbortError(NONEXISTENT_API_CALL_ERROR_MESSAGE, grpc.StatusCode.UNIMPLEMENTED) from None

    try:
        service: ServiceDescriptor = pool.FindServiceByName(service_name)  # type: ignore[no-untyped-call]
        service_options = service.GetOptions()
    except KeyError:
        raise AbortError(NONEXISTENT_API_CALL_ERROR_MESSAGE, grpc.StatusCode.UNIMPLEMENTED) from None

    level = service_options.Extensions[annotations_pb2.auth_level]

    validate_auth_level(level)

    return level
def validate_auth_level(auth_level: AuthLevel.ValueType) -> None:
    """Raises AbortError with INTERNAL unless the auth level is one of the five known, usable levels."""
    usable_levels = (
        annotations_pb2.AUTH_LEVEL_OPEN,
        annotations_pb2.AUTH_LEVEL_JAILED,
        annotations_pb2.AUTH_LEVEL_SECURE,
        annotations_pb2.AUTH_LEVEL_EDITOR,
        annotations_pb2.AUTH_LEVEL_ADMIN,
    )

    # if unknown auth level, then it wasn't set and something's wrong
    is_unset = auth_level == annotations_pb2.AUTH_LEVEL_UNKNOWN
    if is_unset or auth_level not in usable_levels:
        raise AbortError(MISSING_AUTH_LEVEL_ERROR_MESSAGE, grpc.StatusCode.INTERNAL)
def check_permissions(auth_info: UserAuthInfo | None, auth_level: AuthLevel.ValueType) -> None:
    """
    Raises AbortError if the caller may not use a service with this auth level.

    Without a session only open services are allowed. With a session, admin services need a superuser, editor
    services need an editor, and jailed users are limited to open and jailed services.
    """
    if not auth_info:
        # if this isn't an open service, fail
        if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
            raise AbortError(UNAUTHORIZED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED)
        return

    # a valid user session was found - check permissions
    needs_superuser = auth_level == annotations_pb2.AUTH_LEVEL_ADMIN
    if needs_superuser and not auth_info.is_superuser:
        raise AbortError(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.PERMISSION_DENIED)

    needs_editor = auth_level == annotations_pb2.AUTH_LEVEL_EDITOR
    if needs_editor and not auth_info.is_editor:
        raise AbortError(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.PERMISSION_DENIED)

    # if the user is jailed and this isn't an open or jailed service, fail
    allowed_while_jailed = auth_level in [
        annotations_pb2.AUTH_LEVEL_OPEN,
        annotations_pb2.AUTH_LEVEL_JAILED,
    ]
    if auth_info.is_jailed and not allowed_while_jailed:
        raise AbortError(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED)
471class MediaInterceptor(grpc.ServerInterceptor):
472 """
473 Extracts an "Authorization: Bearer <hex>" header and calls the
474 is_authorized function. Terminates the call with an HTTP error
475 code if not authorized.
477 Also adds a session to called APIs.
478 """
480 def __init__(self, is_authorized: Callable[[str], bool]):
481 self._is_authorized = is_authorized
483 def intercept_service[T, R](
484 self,
485 continuation: Cont[T, R],
486 handler_call_details: grpc.HandlerCallDetails,
487 ) -> grpc.RpcMethodHandler[T, R]:
488 handler = continuation(handler_call_details)
489 if not handler: 489 ↛ 490line 489 didn't jump to line 490 because the condition on line 489 was never true
490 raise RuntimeError("No handler")
492 prev_func = handler.unary_unary
493 if not prev_func: 493 ↛ 494line 493 didn't jump to line 494 because the condition on line 493 was never true
494 raise RuntimeError(f"No prev_function, {handler}")
496 metadata = dict(handler_call_details.invocation_metadata)
498 token = parse_api_key(metadata)
500 if not token or not self._is_authorized(token): 500 ↛ 501line 500 didn't jump to line 501 because the condition on line 500 was never true
501 return unauthenticated_handler()
503 def function_without_session(request: T, grpc_context: grpc.ServicerContext) -> R:
504 with session_scope() as session:
505 return prev_func(request, make_media_context(grpc_context), session) # type: ignore[call-arg, arg-type]
507 return grpc.unary_unary_rpc_method_handler(
508 function_without_session,
509 request_deserializer=handler.request_deserializer,
510 response_serializer=handler.response_serializer,
511 )
514class OTelInterceptor(grpc.ServerInterceptor):
515 """
516 OpenTelemetry tracing
517 """
519 def __init__(self) -> None:
520 self.tracer = trace.get_tracer(__name__)
522 def intercept_service[T, R](
523 self,
524 continuation: Cont[T, R],
525 handler_call_details: grpc.HandlerCallDetails,
526 ) -> grpc.RpcMethodHandler[T, R]:
527 handler = continuation(handler_call_details)
528 if not handler:
529 raise RuntimeError("No handler")
531 prev_func = handler.unary_unary
532 if not prev_func:
533 raise RuntimeError(f"No prev_function, {handler}")
535 method = handler_call_details.method
537 # method is of the form "/org.couchers.api.core.API/GetUser"
538 _, service_name, method_name = method.split("/")
540 headers = dict(handler_call_details.invocation_metadata)
542 def tracing_function(request: T, context: grpc.ServicerContext) -> R:
543 with self.tracer.start_as_current_span("handler") as rollspan:
544 rollspan.set_attribute("rpc.method_full", method)
545 rollspan.set_attribute("rpc.service", service_name)
546 rollspan.set_attribute("rpc.method", method_name)
548 rollspan.set_attribute("rpc.thread", get_ident())
549 rollspan.set_attribute("rpc.pid", getpid())
551 res = prev_func(request, context)
553 rollspan.set_attribute("web.user_agent", headers.get("user-agent") or "")
554 rollspan.set_attribute("web.ip_address", headers.get("x-couchers-real-ip") or "")
556 return res
558 return grpc.unary_unary_rpc_method_handler(
559 tracing_function,
560 request_deserializer=handler.request_deserializer,
561 response_serializer=handler.response_serializer,
562 )
565class ErrorSanitizationInterceptor(grpc.ServerInterceptor):
566 """
567 If the call resulted in a non-gRPC error, this strips away the error details.
569 It's important to put this first, so that it does not interfere with other interceptors.
570 """
572 def intercept_service[T, R](
573 self,
574 continuation: Cont[T, R],
575 handler_call_details: grpc.HandlerCallDetails,
576 ) -> grpc.RpcMethodHandler[T, R]:
577 handler = continuation(handler_call_details)
578 if not handler: 578 ↛ 579line 578 didn't jump to line 579 because the condition on line 578 was never true
579 raise RuntimeError("No handler")
581 prev_func = handler.unary_unary
582 if not prev_func: 582 ↛ 583line 582 didn't jump to line 583 because the condition on line 582 was never true
583 raise RuntimeError(f"No prev_function, {handler}")
585 def sanitizing_function(req: T, context: grpc.ServicerContext) -> R:
586 try:
587 res = prev_func(req, context)
588 except Exception as e:
589 code = context.code() # type: ignore[attr-defined]
590 # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None
591 if not code:
592 logger.exception(e)
593 logger.info("Probably an unknown error! Sanitizing...")
594 context.abort(grpc.StatusCode.INTERNAL, UNKNOWN_ERROR_MESSAGE)
595 else:
596 logger.warning(f"RPC error: {code} in method {handler_call_details.method}")
597 raise e
598 return res
600 return grpc.unary_unary_rpc_method_handler(
601 sanitizing_function,
602 request_deserializer=handler.request_deserializer,
603 response_serializer=handler.response_serializer,
604 )