Coverage for src/couchers/interceptors.py: 86%
232 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-02 11:18 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-02 11:18 +0000
1import logging
2from collections.abc import Callable
3from copy import deepcopy
4from dataclasses import dataclass
5from datetime import datetime, timedelta
6from os import getpid
7from threading import get_ident
8from time import perf_counter_ns
9from traceback import format_exception
10from typing import Any, Never, NoReturn, cast
12import grpc
13import sentry_sdk
14from google.protobuf.descriptor import ServiceDescriptor
15from google.protobuf.message import Message
16from opentelemetry import trace
17from sqlalchemy import Function
18from sqlalchemy.sql import and_, func
20from couchers.constants import (
21 CALL_CANCELLED_ERROR_MESSAGE,
22 COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE,
23 MISSING_AUTH_LEVEL_ERROR_MESSAGE,
24 NONEXISTENT_API_CALL_ERROR_MESSAGE,
25 PERMISSION_DENIED_ERROR_MESSAGE,
26 UNAUTHORIZED_ERROR_MESSAGE,
27 UNKNOWN_ERROR_MESSAGE,
28)
29from couchers.context import CouchersContext, make_interactive_context, make_media_context
30from couchers.db import session_scope
31from couchers.descriptor_pool import get_descriptor_pool
32from couchers.metrics import observe_in_servicer_duration_histogram
33from couchers.models import APICall, User, UserActivity, UserSession
34from couchers.proto import annotations_pb2
35from couchers.sql import couchers_select as select
36from couchers.utils import (
37 create_lang_cookie,
38 create_session_cookies,
39 now,
40 parse_api_key,
41 parse_session_cookie,
42 parse_ui_lang_cookie,
43 parse_user_id_cookie,
44)
# Module-wide logger, named after this module so handlers/levels can target it.
logger = logging.getLogger(name=__name__)
49@dataclass(frozen=True, slots=True)
50class UserAuthInfo:
51 """
52 Information about an authenticated user session.
54 Returned by _try_get_and_update_user_details when a valid session is found.
55 """
57 user_id: int
58 is_jailed: bool
59 is_editor: bool
60 is_superuser: bool
61 token_expiry: datetime
62 ui_language_preference: str | None
def _binned_now() -> Function[Any]:
    """
    Build the SQL expression ``date_bin('1 hour', now(), '2000-01-01')``.

    Presumably PostgreSQL's ``date_bin``: the database's current time bucketed into
    hour-wide bins anchored at 2000-01-01. It is used as the ``UserActivity.period`` key.
    """
    stride = "1 hour"
    origin = "2000-01-01"
    return func.date_bin(stride, func.now(), origin)
def _try_get_and_update_user_details(
    token: str | None, is_api_key: bool, ip_address: str | None, user_agent: str | None
) -> UserAuthInfo | None:
    """
    Resolve ``token`` to a live session and record activity against it.

    On a match this:
    - refreshes the user's ``last_active`` (only if it is more than five minutes old);
    - stamps the session's ``last_seen`` and increments its ``api_calls``;
    - increments the ``UserActivity`` row for the current hour bin and this
      IP address / user agent pair, or creates it with ``api_calls=1``.

    These changes are committed before returning.

    Returns a UserAuthInfo snapshot, or None when ``token`` is empty or no visible
    user owns a valid session with this token and the requested ``is_api_key`` kind.
    """
    if not token:
        return None

    with session_scope() as session:
        # the activity row (if any) for this user, hour bin, IP address and user agent
        same_bucket_activity = and_(
            UserActivity.user_id == User.id,
            UserActivity.period == _binned_now(),
            UserActivity.ip_address == ip_address,
            UserActivity.user_agent == user_agent,
        )
        statement = (
            select(User, UserSession, UserActivity)
            .join(User, User.id == UserSession.user_id)
            .outerjoin(UserActivity, same_bucket_activity)
            .where(User.is_visible)
            .where(UserSession.token == token)
            .where(UserSession.is_valid)
            .where(UserSession.is_api_key == is_api_key)
        )
        row = session.execute(statement).one_or_none()

        if row is None:
            return None

        user, user_session, user_activity = row

        # only bump last_active when it is stale (presumably to limit write traffic)
        if now() - user.last_active > timedelta(minutes=5):
            user.last_active = func.now()

        user_session.last_seen = func.now()
        user_session.api_calls += 1

        if not user_activity:
            # first call in this hour bin from this IP address / user agent
            session.add(
                UserActivity(
                    user_id=user.id,
                    period=_binned_now(),
                    ip_address=ip_address,
                    user_agent=user_agent,
                    api_calls=1,
                )
            )
        else:
            user_activity.api_calls += 1

        session.commit()

        return UserAuthInfo(
            user_id=user.id,
            is_jailed=user.is_jailed,
            is_editor=user.is_editor,
            is_superuser=user.is_superuser,
            token_expiry=user_session.expiry,
            ui_language_preference=user.ui_language_preference,
        )
139# We have to lie with R | NoReturn to please mypy. It should be NoReturn.
140def abort_handler[T, R](
141 message: str,
142 status_code: grpc.StatusCode,
143) -> "grpc.RpcMethodHandler[T, R | NoReturn]":
144 def f(request: Any, context: CouchersContext) -> NoReturn:
145 context.abort(status_code, message)
147 return grpc.unary_unary_rpc_method_handler(f)
150def unauthenticated_handler[T, R](
151 message: str = UNAUTHORIZED_ERROR_MESSAGE,
152 status_code: grpc.StatusCode = grpc.StatusCode.UNAUTHENTICATED,
153) -> "grpc.RpcMethodHandler[T, R | NoReturn]":
154 return abort_handler(message, status_code)
def _sanitized_bytes(proto: Message | None) -> bytes | None:
    """
    Serialize ``proto`` with every field annotated ``sensitive`` cleared.

    The message is deep-copied first, so the caller's object is never modified.
    Sensitive fields are cleared at any depth: inside singular submessages,
    repeated submessages and message-valued map entries.

    Returns None when ``proto`` is falsy (e.g. None). Otherwise returns the
    serialized bytes of the sanitized copy.
    """
    if not proto:
        return None

    new_proto = deepcopy(proto)

    def _sanitize_message(message: Message) -> None:
        # ListFields() returns only populated fields, as an already-built list, so
        # clearing fields while looping is safe. Walking fields_by_name + getattr would
        # also visit unset submessages (protobuf hands back a default instance). That is
        # wasted work, and it recurses without bound on self-referential message types.
        for descriptor, value in message.ListFields():
            if descriptor.is_extension:
                # only regular fields are sanitized; ClearField does not accept extensions
                continue
            if descriptor.GetOptions().Extensions[annotations_pb2.sensitive]:
                message.ClearField(descriptor.name)
                continue
            message_type = descriptor.message_type
            if message_type is None:
                # scalar or enum field: nothing nested to sanitize
                continue
            if message_type.GetOptions().map_entry:
                # iterating a map yields its keys, not entry messages; only maps whose
                # values are messages can contain nested sensitive fields
                if message_type.fields_by_name["value"].message_type is not None:
                    for item in value.values():
                        _sanitize_message(item)
            elif descriptor.label == descriptor.LABEL_REPEATED:
                for item in value:
                    _sanitize_message(item)
            else:
                _sanitize_message(value)

    _sanitize_message(new_proto)

    return new_proto.SerializeToString()
def _store_log(
    *,
    method: str,
    status_code: grpc.StatusCode | None,
    duration: float,
    user_id: int | None,
    is_api_key: bool,
    request: Message,
    response: Message | None,
    traceback: str | None,
    perf_report: str | None,
    ip_address: str | None,
    user_agent: str | None,
) -> None:
    """
    Persist one APICall row describing a servicer invocation.

    Request and response are stored via _sanitized_bytes, i.e. with fields
    annotated ``sensitive`` cleared. The serialized response is cut to its first
    16 kB, and ``response_truncated`` records whether that happened.

    NOTE(review): CouchersMiddlewareInterceptor passes ``status_code`` as the
    status code's ``.name`` string (or None), not a ``grpc.StatusCode``. Confirm which
    the APICall column expects and tighten the annotation accordingly.
    """
    max_response_bytes = 16 * 1024  # 16 kB

    req_bytes = _sanitized_bytes(request)
    res_bytes = _sanitized_bytes(response)

    # truncate before opening the DB session; the slicing needs no database access
    response_truncated = False
    if res_bytes and len(res_bytes) > max_response_bytes:
        res_bytes = res_bytes[:max_response_bytes]
        response_truncated = True

    with session_scope() as session:
        session.add(
            APICall(
                is_api_key=is_api_key,
                method=method,
                status_code=status_code,
                duration=duration,
                user_id=user_id,
                request=req_bytes,
                response=res_bytes,
                response_truncated=response_truncated,
                traceback=traceback,
                perf_report=perf_report,
                ip_address=ip_address,
                user_agent=user_agent,
            )
        )

    # Lazy %-formatting: this runs on every API call, so only format when DEBUG is enabled.
    # %r matches the repr-based output of the previous f"{x=}" form.
    logger.debug("user_id=%r, method=%r, duration=%r ms", user_id, method, duration)
226type Cont[T, R] = Callable[[grpc.HandlerCallDetails], grpc.RpcMethodHandler[T, R] | None]
229class CouchersMiddlewareInterceptor(grpc.ServerInterceptor):
230 """
231 1. Does auth: extracts a session token from a cookie, and authenticates a user with that.
233 Sets context.user_id and context.token if authenticated, otherwise
234 terminates the call with an UNAUTHENTICATED error code.
236 2. Makes sure cookies are in sync.
238 3. Injects a session to get a database transaction.
240 4. Measures and logs the time it takes to service each incoming call.
241 """
243 def __init__(self) -> None:
244 self._pool = get_descriptor_pool()
246 def intercept_service[T = Message, R = Message](
247 self,
248 continuation: Cont[T, R],
249 handler_call_details: grpc.HandlerCallDetails,
250 ) -> "grpc.RpcMethodHandler[T, R | Never]":
251 start = perf_counter_ns()
253 method = handler_call_details.method
254 # method is of the form "/org.couchers.api.core.API/GetUser"
255 _, service_name, method_name = method.split("/")
257 try:
258 service: ServiceDescriptor = self._pool.FindServiceByName(service_name) # type: ignore[no-untyped-call]
259 service_options = service.GetOptions()
260 except KeyError:
261 return abort_handler(NONEXISTENT_API_CALL_ERROR_MESSAGE, grpc.StatusCode.UNIMPLEMENTED)
263 auth_level = service_options.Extensions[annotations_pb2.auth_level]
265 # if unknown auth level, then it wasn't set and something's wrong
266 if auth_level == annotations_pb2.AUTH_LEVEL_UNKNOWN:
267 return abort_handler(MISSING_AUTH_LEVEL_ERROR_MESSAGE, grpc.StatusCode.INTERNAL)
269 assert auth_level in [
270 annotations_pb2.AUTH_LEVEL_OPEN,
271 annotations_pb2.AUTH_LEVEL_JAILED,
272 annotations_pb2.AUTH_LEVEL_SECURE,
273 annotations_pb2.AUTH_LEVEL_EDITOR,
274 annotations_pb2.AUTH_LEVEL_ADMIN,
275 ]
277 headers = dict(handler_call_details.invocation_metadata)
279 if "cookie" in headers and "authorization" in headers:
280 # for security reasons, only one of "cookie" or "authorization" can be present
281 return unauthenticated_handler(COOKIES_AND_AUTH_HEADER_ERROR_MESSAGE)
282 elif "cookie" in headers:
283 # the session token is passed in cookies, i.e. in the `cookie` header
284 token, is_api_key = parse_session_cookie(headers), False
285 elif "authorization" in headers:
286 # the session token is passed in the `authorization` header
287 token, is_api_key = parse_api_key(headers), True
288 else:
289 # no session found
290 token, is_api_key = None, False
292 ip_address = cast(str | None, headers.get("x-couchers-real-ip"))
293 user_agent = cast(str | None, headers.get("user-agent"))
295 auth_info = _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent)
297 if not auth_info:
298 # Invalid or no session - clear credentials
299 token = None
300 is_api_key = False
302 # if this isn't an open service, fail
303 if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
304 return unauthenticated_handler(UNAUTHORIZED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED)
305 else:
306 # a valid user session was found - check permissions
307 if auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not auth_info.is_superuser:
308 return unauthenticated_handler(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.PERMISSION_DENIED)
310 if auth_level == annotations_pb2.AUTH_LEVEL_EDITOR and not auth_info.is_editor:
311 return unauthenticated_handler(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.PERMISSION_DENIED)
313 # if the user is jailed and this is isn't an open or jailed service, fail
314 if auth_info.is_jailed and auth_level not in [
315 annotations_pb2.AUTH_LEVEL_OPEN,
316 annotations_pb2.AUTH_LEVEL_JAILED,
317 ]:
318 return unauthenticated_handler(PERMISSION_DENIED_ERROR_MESSAGE, grpc.StatusCode.UNAUTHENTICATED)
320 handler = continuation(handler_call_details)
321 if not handler:
322 raise RuntimeError(f"No handler in '{method}'")
324 prev_function = handler.unary_unary
325 if not prev_function:
326 raise RuntimeError(f"No prev_function in '{method}', {handler}")
328 def function_without_couchers_stuff(req: Message, grpc_context: grpc.ServicerContext) -> Message | None:
329 couchers_context: CouchersContext = make_interactive_context(
330 grpc_context=grpc_context,
331 user_id=auth_info.user_id if auth_info else None,
332 is_api_key=is_api_key,
333 token=token,
334 ui_language_preference=auth_info.ui_language_preference if auth_info else None,
335 )
336 with session_scope() as session:
337 try:
338 _res = prev_function(req, couchers_context, session) # type: ignore[call-arg, arg-type]
339 res = cast(Message, _res)
340 finished = perf_counter_ns()
341 duration = (finished - start) / 1e6 # ms
342 _store_log(
343 method=method,
344 status_code=None,
345 duration=duration,
346 user_id=couchers_context._user_id,
347 is_api_key=cast(bool, couchers_context._is_api_key),
348 request=req,
349 response=res,
350 traceback=None,
351 perf_report=None,
352 ip_address=ip_address,
353 user_agent=user_agent,
354 )
355 observe_in_servicer_duration_histogram(method, couchers_context._user_id, "", "", duration / 1000)
356 except Exception as e:
357 finished = perf_counter_ns()
358 duration = (finished - start) / 1e6 # ms
359 code = getattr(couchers_context._grpc_context.code(), "name", None) # type: ignore[union-attr]
360 traceback = "".join(format_exception(type(e), e, e.__traceback__))
361 _store_log(
362 method=method,
363 status_code=code,
364 duration=duration,
365 user_id=couchers_context._user_id,
366 is_api_key=cast(bool, couchers_context._is_api_key),
367 request=req,
368 response=None,
369 traceback=traceback,
370 perf_report=None,
371 ip_address=ip_address,
372 user_agent=user_agent,
373 )
374 observe_in_servicer_duration_histogram(
375 method, couchers_context._user_id, code or "", type(e).__name__, duration / 1000
376 )
378 if not code:
379 sentry_sdk.set_tag("context", "servicer")
380 sentry_sdk.set_tag("method", method)
381 sentry_sdk.capture_exception(e)
383 raise e
385 if auth_info and not is_api_key:
386 # Sanity check. If auth_info is present, then we should have a token.
387 if token is None:
388 raise RuntimeError(f"{token=}, {auth_info.token_expiry=}")
390 # check the two cookies are in sync & that language preference cookie is correct
391 if parse_user_id_cookie(headers) != str(auth_info.user_id):
392 couchers_context.set_cookies(
393 create_session_cookies(token, auth_info.user_id, auth_info.token_expiry)
394 )
395 if auth_info.ui_language_preference and auth_info.ui_language_preference != parse_ui_lang_cookie(
396 headers
397 ):
398 couchers_context.set_cookies(create_lang_cookie(auth_info.ui_language_preference))
400 if not grpc_context.is_active():
401 grpc_context.abort(grpc.StatusCode.INTERNAL, CALL_CANCELLED_ERROR_MESSAGE)
403 couchers_context._send_cookies()
405 return res
407 return grpc.unary_unary_rpc_method_handler(
408 function_without_couchers_stuff,
409 request_deserializer=handler.request_deserializer,
410 response_serializer=handler.response_serializer,
411 )
414class MediaInterceptor(grpc.ServerInterceptor):
415 """
416 Extracts an "Authorization: Bearer <hex>" header and calls the
417 is_authorized function. Terminates the call with an HTTP error
418 code if not authorized.
420 Also adds a session to called APIs.
421 """
423 def __init__(self, is_authorized: Callable[[str], bool]):
424 self._is_authorized = is_authorized
426 def intercept_service[T, R](
427 self,
428 continuation: Cont[T, R],
429 handler_call_details: grpc.HandlerCallDetails,
430 ) -> "grpc.RpcMethodHandler[T, R | Never]":
431 handler = continuation(handler_call_details)
432 if not handler:
433 raise RuntimeError("No handler")
435 prev_func = handler.unary_unary
436 if not prev_func:
437 raise RuntimeError(f"No prev_function, {handler}")
439 metadata = dict(handler_call_details.invocation_metadata)
441 token = parse_api_key(metadata)
443 if not token or not self._is_authorized(token):
444 return unauthenticated_handler()
446 def function_without_session(request: T, grpc_context: grpc.ServicerContext) -> R:
447 with session_scope() as session:
448 return prev_func(request, make_media_context(grpc_context), session) # type: ignore[call-arg, arg-type]
450 return grpc.unary_unary_rpc_method_handler(
451 function_without_session,
452 request_deserializer=handler.request_deserializer,
453 response_serializer=handler.response_serializer,
454 )
457class OTelInterceptor(grpc.ServerInterceptor):
458 """
459 OpenTelemetry tracing
460 """
462 def __init__(self) -> None:
463 self.tracer = trace.get_tracer(__name__)
465 def intercept_service[T, R](
466 self,
467 continuation: Cont[T, R],
468 handler_call_details: grpc.HandlerCallDetails,
469 ) -> "grpc.RpcMethodHandler[T, R | Never]":
470 handler = continuation(handler_call_details)
471 if not handler:
472 raise RuntimeError("No handler")
474 prev_func = handler.unary_unary
475 if not prev_func:
476 raise RuntimeError(f"No prev_function, {handler}")
478 method = handler_call_details.method
480 # method is of the form "/org.couchers.api.core.API/GetUser"
481 _, service_name, method_name = method.split("/")
483 headers = dict(handler_call_details.invocation_metadata)
485 def tracing_function(request: T, context: grpc.ServicerContext) -> R:
486 with self.tracer.start_as_current_span("handler") as rollspan:
487 rollspan.set_attribute("rpc.method_full", method)
488 rollspan.set_attribute("rpc.service", service_name)
489 rollspan.set_attribute("rpc.method", method_name)
491 rollspan.set_attribute("rpc.thread", get_ident())
492 rollspan.set_attribute("rpc.pid", getpid())
494 res = prev_func(request, context)
496 rollspan.set_attribute("web.user_agent", headers.get("user-agent") or "")
497 rollspan.set_attribute("web.ip_address", headers.get("x-couchers-real-ip") or "")
499 return res
501 return grpc.unary_unary_rpc_method_handler(
502 tracing_function,
503 request_deserializer=handler.request_deserializer,
504 response_serializer=handler.response_serializer,
505 )
508class ErrorSanitizationInterceptor(grpc.ServerInterceptor):
509 """
510 If the call resulted in a non-gRPC error, this strips away the error details.
512 It's important to put this first, so that it does not interfere with other interceptors.
513 """
515 def intercept_service[T, R](
516 self,
517 continuation: Cont[T, R],
518 handler_call_details: grpc.HandlerCallDetails,
519 ) -> "grpc.RpcMethodHandler[T, R | Never]":
520 handler = continuation(handler_call_details)
521 if not handler:
522 raise RuntimeError("No handler")
524 prev_func = handler.unary_unary
525 if not prev_func:
526 raise RuntimeError(f"No prev_function, {handler}")
528 def sanitizing_function(req: T, context: grpc.ServicerContext) -> R:
529 try:
530 res = prev_func(req, context)
531 except Exception as e:
532 code = context.code() # type: ignore[attr-defined]
533 # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None
534 if not code:
535 logger.exception(e)
536 logger.info("Probably an unknown error! Sanitizing...")
537 context.abort(grpc.StatusCode.INTERNAL, UNKNOWN_ERROR_MESSAGE)
538 else:
539 logger.warning(f"RPC error: {code} in method {handler_call_details.method}")
540 raise e
541 return res
543 return grpc.unary_unary_rpc_method_handler(
544 sanitizing_function,
545 request_deserializer=handler.request_deserializer,
546 response_serializer=handler.response_serializer,
547 )