Coverage for src/couchers/interceptors.py: 86%
226 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-18 14:01 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-18 14:01 +0000
1import logging
2from collections.abc import Callable
3from copy import deepcopy
4from datetime import datetime, timedelta
5from os import getpid
6from threading import get_ident
7from time import perf_counter_ns
8from traceback import format_exception
9from typing import Any, Never, NoReturn, cast
11import grpc
12import sentry_sdk
13from google.protobuf.descriptor import ServiceDescriptor
14from google.protobuf.message import Message
15from opentelemetry import trace
16from sqlalchemy import Function
17from sqlalchemy.sql import and_, func
19from couchers.constants import UNKNOWN_ERROR_MESSAGE
20from couchers.context import CouchersContext, make_interactive_context, make_media_context
21from couchers.db import session_scope
22from couchers.descriptor_pool import get_descriptor_pool
23from couchers.metrics import observe_in_servicer_duration_histogram
24from couchers.models import APICall, User, UserActivity, UserSession
25from couchers.proto import annotations_pb2
26from couchers.sql import couchers_select as select
27from couchers.utils import (
28 create_lang_cookie,
29 create_session_cookies,
30 now,
31 parse_api_key,
32 parse_session_cookie,
33 parse_ui_lang_cookie,
34 parse_user_id_cookie,
35)
37logger = logging.getLogger(__name__)
def _binned_now() -> Function[Any]:
    """
    Builds the SQL expression `date_bin('1 hour', now(), '2000-01-01')`.

    It is used both to look up and to create the UserActivity row for the current period.
    """
    stride = "1 hour"
    origin = "2000-01-01"
    return func.date_bin(stride, func.now(), origin)
def _try_get_and_update_user_details(
    token: str | None, is_api_key: bool, ip_address: str | None, user_agent: str | None
) -> tuple[int, bool, bool, datetime, str | None] | None:
    """
    Looks up the session and user that belong to this token.

    On a hit this also refreshes the user's last active time (only if it is more than 5 minutes old), sets the
    session's last seen time, and increments the API call counters on the session and on the UserActivity row for
    the current period, ip address and user agent (creating that row if it does not exist yet).

    Returns None when there is no token or no visible user with a valid session of the requested kind (API key or
    not). Otherwise returns (user id, is jailed, is superuser, session expiry, UI language preference).
    """
    if not token:
        return None

    # matches the activity row for the current period, ip address and user agent, if there already is one
    same_activity = and_(
        UserActivity.user_id == User.id,
        UserActivity.period == _binned_now(),
        UserActivity.ip_address == ip_address,
        UserActivity.user_agent == user_agent,
    )

    query = (
        select(User, UserSession, UserActivity)
        .join(User, User.id == UserSession.user_id)
        .outerjoin(UserActivity, same_activity)
        .where(User.is_visible)
        .where(UserSession.token == token)
        .where(UserSession.is_valid)
        .where(UserSession.is_api_key == is_api_key)
    )

    with session_scope() as session:
        row = session.execute(query).one_or_none()
        if not row:
            return None

        user, user_session, user_activity = row

        # only touch the user's last active time if it's been a while
        stale_after = timedelta(minutes=5)
        if now() - user.last_active > stale_after:
            user.last_active = func.now()

        # the token is updated on every call
        user_session.last_seen = func.now()
        user_session.api_calls += 1

        if not user_activity:
            # first call in this period from this ip address and user agent
            session.add(
                UserActivity(
                    user_id=user.id,
                    period=_binned_now(),
                    ip_address=ip_address,
                    user_agent=user_agent,
                    api_calls=1,
                )
            )
        else:
            user_activity.api_calls += 1

        session.commit()

        return (
            user.id,
            user.is_jailed,
            user.is_superuser,
            user_session.expiry,
            user.ui_language_preference,
        )
105# We have to lie with R | NoReturn to please mypy. It should be NoReturn.
106def abort_handler[T, R](
107 message: str,
108 status_code: grpc.StatusCode,
109) -> "grpc.RpcMethodHandler[T, R | NoReturn]":
110 def f(request: Any, context: CouchersContext) -> NoReturn:
111 context.abort(status_code, message)
113 return grpc.unary_unary_rpc_method_handler(f)
116def unauthenticated_handler[T, R](
117 message: str = "Unauthorized",
118 status_code: grpc.StatusCode = grpc.StatusCode.UNAUTHENTICATED,
119) -> "grpc.RpcMethodHandler[T, R | NoReturn]":
120 return abort_handler(message, status_code)
def _sanitized_bytes(proto: Message | None) -> bytes | None:
    """
    Remove fields marked sensitive and return serialized bytes

    The message is deep-copied first, so the caller's message is never modified. Fields annotated with the
    `sensitive` option are cleared at every nesting level: in singular sub-messages, in repeated messages and in
    message-valued map entries.

    Returns None if no message was given.
    """
    if not proto:
        return None

    new_proto = deepcopy(proto)

    def _sanitize_message(message: Message) -> None:
        for name, descriptor in message.DESCRIPTOR.fields_by_name.items():
            if descriptor.GetOptions().Extensions[annotations_pb2.sensitive]:
                message.ClearField(name)
                # the field is empty now, so there is nothing left to descend into
                continue

            entry_type = descriptor.message_type
            if not entry_type:
                # scalar field that isn't sensitive
                continue

            if descriptor.label != descriptor.LABEL_REPEATED:
                # Only descend into sub-messages that are actually set. Reading an unset sub-message hands back an
                # empty default instance, so descending unconditionally never terminates for message types that
                # (directly or indirectly) contain themselves.
                if message.HasField(name):
                    _sanitize_message(getattr(message, name))
                continue

            container = getattr(message, name)

            if entry_type.GetOptions().map_entry:
                # Map fields are described as repeated "entry" messages, but iterating the container yields the
                # keys. The values are what may need sanitizing, and only if they are messages themselves.
                if entry_type.fields_by_name["value"].message_type:
                    for value in container.values():
                        _sanitize_message(value)
                continue

            for item in container:
                _sanitize_message(item)

    _sanitize_message(new_proto)

    return new_proto.SerializeToString()
def _store_log(
    *,
    method: str,
    status_code: grpc.StatusCode | None,
    duration: float,
    user_id: int | None,
    is_api_key: bool,
    request: Message,
    response: Message | None,
    traceback: str | None,
    perf_report: str | None,
    ip_address: str | None,
    user_agent: str | None,
) -> None:
    """
    Stores one APICall row describing this call and writes a debug log line.

    Request and response are stored as serialized bytes with sensitive fields removed. Response bytes longer than
    16 kB are cut off at that length and the row is flagged as truncated.
    """
    max_response_bytes = 16 * 1024  # 16 kB

    req_bytes = _sanitized_bytes(request)
    res_bytes = _sanitized_bytes(response)

    if res_bytes is not None and len(res_bytes) > max_response_bytes:
        res_bytes, response_truncated = res_bytes[:max_response_bytes], True
    else:
        response_truncated = False

    api_call = APICall(
        is_api_key=is_api_key,
        method=method,
        status_code=status_code,
        duration=duration,
        user_id=user_id,
        request=req_bytes,
        response=res_bytes,
        response_truncated=response_truncated,
        traceback=traceback,
        perf_report=perf_report,
        ip_address=ip_address,
        user_agent=user_agent,
    )

    with session_scope() as session:
        session.add(api_call)
        logger.debug(f"{user_id=}, {method=}, {duration=} ms")
192type Cont[T, R] = Callable[[grpc.HandlerCallDetails], grpc.RpcMethodHandler[T, R] | None]
195class CouchersMiddlewareInterceptor(grpc.ServerInterceptor):
196 """
197 1. Does auth: extracts a session token from a cookie, and authenticates a user with that.
199 Sets context.user_id and context.token if authenticated, otherwise
200 terminates the call with an UNAUTHENTICATED error code.
202 2. Makes sure cookies are in sync.
204 3. Injects a session to get a database transaction.
206 4. Measures and logs the time it takes to service each incoming call.
207 """
209 def __init__(self) -> None:
210 self._pool = get_descriptor_pool()
212 def intercept_service[T = Message, R = Message](
213 self,
214 continuation: Cont[T, R],
215 handler_call_details: grpc.HandlerCallDetails,
216 ) -> "grpc.RpcMethodHandler[T, R | Never]":
217 start = perf_counter_ns()
219 method = handler_call_details.method
220 # method is of the form "/org.couchers.api.core.API/GetUser"
221 _, service_name, method_name = method.split("/")
223 try:
224 service: ServiceDescriptor = self._pool.FindServiceByName(service_name) # type: ignore[no-untyped-call]
225 service_options = service.GetOptions()
226 except KeyError:
227 return abort_handler(
228 "API call does not exist. Please refresh and try again.", grpc.StatusCode.UNIMPLEMENTED
229 )
231 auth_level: Any = service_options.Extensions[annotations_pb2.auth_level] # type: ignore[index]
233 # if unknown auth level, then it wasn't set and something's wrong
234 if auth_level == annotations_pb2.AUTH_LEVEL_UNKNOWN:
235 return abort_handler("Internal authentication error.", grpc.StatusCode.INTERNAL)
237 assert auth_level in [
238 annotations_pb2.AUTH_LEVEL_OPEN,
239 annotations_pb2.AUTH_LEVEL_JAILED,
240 annotations_pb2.AUTH_LEVEL_SECURE,
241 annotations_pb2.AUTH_LEVEL_ADMIN,
242 ]
244 headers = dict(handler_call_details.invocation_metadata)
246 if "cookie" in headers and "authorization" in headers:
247 # for security reasons, only one of "cookie" or "authorization" can be present
248 return unauthenticated_handler('Both "cookie" and "authorization" in request')
249 elif "cookie" in headers:
250 # the session token is passed in cookies, i.e. in the `cookie` header
251 token, is_api_key = parse_session_cookie(headers), False
252 elif "authorization" in headers:
253 # the session token is passed in the `authorization` header
254 token, is_api_key = parse_api_key(headers), True
255 else:
256 # no session found
257 token, is_api_key = None, False
259 ip_address = cast(str | None, headers.get("x-couchers-real-ip"))
260 user_agent = cast(str | None, headers.get("user-agent"))
262 auth_info = _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent)
263 # auth_info is now filled if and only if this is a valid session
264 if not auth_info:
265 token = None
266 is_api_key = False
267 token_expiry = None
268 user_id = None
269 ui_language_preference = None
271 # if no session was found and this isn't an open service, fail
272 if not auth_info:
273 if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
274 # NOTE: do not translate this string; it's used in a hacky way in the frontend
275 return unauthenticated_handler("Unauthorized")
276 else:
277 # a valid user session was found
278 user_id, is_jailed, is_superuser, token_expiry, ui_language_preference = auth_info
280 if auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not is_superuser:
281 # NOTE: do not translate this string; it's used in a hacky way in the frontend
282 return unauthenticated_handler("Permission denied", grpc.StatusCode.PERMISSION_DENIED)
284 # if the user is jailed and this is isn't an open or jailed service, fail
285 if is_jailed and auth_level not in [annotations_pb2.AUTH_LEVEL_OPEN, annotations_pb2.AUTH_LEVEL_JAILED]:
286 # NOTE: do not translate this string; it's used in a hacky way in the frontend
287 return unauthenticated_handler("Permission denied")
289 handler = continuation(handler_call_details)
290 if not handler:
291 raise RuntimeError(f"No handler in '{method}'")
293 prev_function = handler.unary_unary
294 if not prev_function:
295 raise RuntimeError(f"No prev_function in '{method}', {handler}")
297 def function_without_couchers_stuff(req: Message, grpc_context: grpc.ServicerContext) -> Message | None:
298 couchers_context: CouchersContext = make_interactive_context(
299 grpc_context=grpc_context,
300 user_id=user_id,
301 is_api_key=is_api_key,
302 token=token,
303 ui_language_preference=ui_language_preference,
304 )
305 with session_scope() as session:
306 try:
307 _res = prev_function(req, couchers_context, session) # type: ignore[call-arg, arg-type]
308 res = cast(Message, _res)
309 finished = perf_counter_ns()
310 duration = (finished - start) / 1e6 # ms
311 _store_log(
312 method=method,
313 status_code=None,
314 duration=duration,
315 user_id=couchers_context._user_id,
316 is_api_key=cast(bool, couchers_context._is_api_key),
317 request=req,
318 response=res,
319 traceback=None,
320 perf_report=None,
321 ip_address=ip_address,
322 user_agent=user_agent,
323 )
324 observe_in_servicer_duration_histogram(method, couchers_context._user_id, "", "", duration / 1000)
325 except Exception as e:
326 finished = perf_counter_ns()
327 duration = (finished - start) / 1e6 # ms
328 code = getattr(couchers_context._grpc_context.code(), "name", None) # type: ignore[union-attr]
329 traceback = "".join(format_exception(type(e), e, e.__traceback__))
330 _store_log(
331 method=method,
332 status_code=code,
333 duration=duration,
334 user_id=couchers_context._user_id,
335 is_api_key=cast(bool, couchers_context._is_api_key),
336 request=req,
337 response=None,
338 traceback=traceback,
339 perf_report=None,
340 ip_address=ip_address,
341 user_agent=user_agent,
342 )
343 observe_in_servicer_duration_histogram(
344 method, couchers_context._user_id, code or "", type(e).__name__, duration / 1000
345 )
347 if not code:
348 sentry_sdk.set_tag("context", "servicer")
349 sentry_sdk.set_tag("method", method)
350 sentry_sdk.capture_exception(e)
352 raise e
354 if user_id and not is_api_key:
355 # Sanity check. If user_id is present, then we should have a token.
356 if token is None or token_expiry is None:
357 raise RuntimeError(f"{token=}, {token_expiry=}")
359 # check the two cookies are in sync & that language preference cookie is correct
360 if parse_user_id_cookie(headers) != str(user_id):
361 couchers_context.set_cookies(create_session_cookies(token, user_id, token_expiry))
362 if ui_language_preference and ui_language_preference != parse_ui_lang_cookie(headers):
363 couchers_context.set_cookies(create_lang_cookie(ui_language_preference))
365 if not grpc_context.is_active():
366 grpc_context.abort(grpc.StatusCode.INTERNAL, "Call cancelled.")
368 couchers_context._send_cookies()
370 return res
372 return grpc.unary_unary_rpc_method_handler(
373 function_without_couchers_stuff,
374 request_deserializer=handler.request_deserializer,
375 response_serializer=handler.response_serializer,
376 )
379class MediaInterceptor(grpc.ServerInterceptor):
380 """
381 Extracts an "Authorization: Bearer <hex>" header and calls the
382 is_authorized function. Terminates the call with an HTTP error
383 code if not authorized.
385 Also adds a session to called APIs.
386 """
388 def __init__(self, is_authorized: Callable[[str], bool]):
389 self._is_authorized = is_authorized
391 def intercept_service[T, R](
392 self,
393 continuation: Cont[T, R],
394 handler_call_details: grpc.HandlerCallDetails,
395 ) -> "grpc.RpcMethodHandler[T, R | Never]":
396 handler = continuation(handler_call_details)
397 if not handler:
398 raise RuntimeError("No handler")
400 prev_func = handler.unary_unary
401 if not prev_func:
402 raise RuntimeError(f"No prev_function, {handler}")
404 metadata = dict(handler_call_details.invocation_metadata)
406 token = parse_api_key(metadata)
408 if not token or not self._is_authorized(token):
409 return unauthenticated_handler()
411 def function_without_session(request: T, grpc_context: grpc.ServicerContext) -> R:
412 with session_scope() as session:
413 return prev_func(request, make_media_context(grpc_context), session) # type: ignore[call-arg, arg-type]
415 return grpc.unary_unary_rpc_method_handler(
416 function_without_session,
417 request_deserializer=handler.request_deserializer,
418 response_serializer=handler.response_serializer,
419 )
422class OTelInterceptor(grpc.ServerInterceptor):
423 """
424 OpenTelemetry tracing
425 """
427 def __init__(self) -> None:
428 self.tracer = trace.get_tracer(__name__)
430 def intercept_service[T, R](
431 self,
432 continuation: Cont[T, R],
433 handler_call_details: grpc.HandlerCallDetails,
434 ) -> "grpc.RpcMethodHandler[T, R | Never]":
435 handler = continuation(handler_call_details)
436 if not handler:
437 raise RuntimeError("No handler")
439 prev_func = handler.unary_unary
440 if not prev_func:
441 raise RuntimeError(f"No prev_function, {handler}")
443 method = handler_call_details.method
445 # method is of the form "/org.couchers.api.core.API/GetUser"
446 _, service_name, method_name = method.split("/")
448 headers = dict(handler_call_details.invocation_metadata)
450 def tracing_function(request: T, context: grpc.ServicerContext) -> R:
451 with self.tracer.start_as_current_span("handler") as rollspan:
452 rollspan.set_attribute("rpc.method_full", method)
453 rollspan.set_attribute("rpc.service", service_name)
454 rollspan.set_attribute("rpc.method", method_name)
456 rollspan.set_attribute("rpc.thread", get_ident())
457 rollspan.set_attribute("rpc.pid", getpid())
459 res = prev_func(request, context)
461 rollspan.set_attribute("web.user_agent", headers.get("user-agent") or "")
462 rollspan.set_attribute("web.ip_address", headers.get("x-couchers-real-ip") or "")
464 return res
466 return grpc.unary_unary_rpc_method_handler(
467 tracing_function,
468 request_deserializer=handler.request_deserializer,
469 response_serializer=handler.response_serializer,
470 )
473class ErrorSanitizationInterceptor(grpc.ServerInterceptor):
474 """
475 If the call resulted in a non-gRPC error, this strips away the error details.
477 It's important to put this first, so that it does not interfere with other interceptors.
478 """
480 def intercept_service[T, R](
481 self,
482 continuation: Cont[T, R],
483 handler_call_details: grpc.HandlerCallDetails,
484 ) -> "grpc.RpcMethodHandler[T, R | Never]":
485 handler = continuation(handler_call_details)
486 if not handler:
487 raise RuntimeError("No handler")
489 prev_func = handler.unary_unary
490 if not prev_func:
491 raise RuntimeError(f"No prev_function, {handler}")
493 def sanitizing_function(req: T, context: grpc.ServicerContext) -> R:
494 try:
495 res = prev_func(req, context)
496 except Exception as e:
497 code = context.code() # type: ignore[attr-defined]
498 # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None
499 if not code:
500 logger.exception(e)
501 logger.info("Probably an unknown error! Sanitizing...")
502 context.abort(grpc.StatusCode.INTERNAL, UNKNOWN_ERROR_MESSAGE)
503 else:
504 logger.warning(f"RPC error: {code} in method {handler_call_details.method}")
505 raise e
506 return res
508 return grpc.unary_unary_rpc_method_handler(
509 sanitizing_function,
510 request_deserializer=handler.request_deserializer,
511 response_serializer=handler.response_serializer,
512 )