Coverage for src/couchers/interceptors.py: 88%
205 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-02 20:25 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-02 20:25 +0000
1import logging
2from copy import deepcopy
3from datetime import timedelta
4from os import getpid
5from threading import get_ident
6from time import perf_counter_ns
7from traceback import format_exception
9import grpc
10import sentry_sdk
11from opentelemetry import trace
12from sqlalchemy.sql import and_, func
14from couchers import errors
15from couchers.context import CouchersContext, make_interactive_user_context, make_media_context
16from couchers.db import session_scope
17from couchers.descriptor_pool import get_descriptor_pool
18from couchers.metrics import observe_in_servicer_duration_histogram
19from couchers.models import APICall, User, UserActivity, UserSession
20from couchers.sql import couchers_select as select
21from couchers.utils import (
22 create_lang_cookie,
23 create_session_cookies,
24 now,
25 parse_api_key,
26 parse_session_cookie,
27 parse_ui_lang_cookie,
28 parse_user_id_cookie,
29)
30from proto import annotations_pb2
# Module logger, named after this module so log records can be filtered per module.
logger = logging.getLogger(name=__name__)
def _binned_now():
    """
    SQL expression for the current database time, truncated to the start of its 1-hour bucket.

    Buckets are aligned to 2000-01-01, so every timestamp within the same hour maps to the same
    value. This is what ``UserActivity.period`` is keyed on.
    """
    bucket_width = "1 hour"
    bucket_origin = "2000-01-01"
    current_db_time = func.now()
    return func.date_bin(bucket_width, current_db_time, bucket_origin)
def _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent):
    """
    Authenticate ``token`` and record this API call against its user and session.

    Returns ``None`` if ``token`` is empty, or if it does not belong to a valid session of the
    requested kind (``is_api_key``) owned by a visible user.

    Otherwise, in one transaction it:
      * sets ``User.last_active`` to now if the stored value is more than 5 minutes old,
      * sets ``UserSession.last_seen`` to now and increments the session's ``api_calls``,
      * increments ``api_calls`` on the ``UserActivity`` row for the current hour bucket and this
        (ip_address, user_agent) pair, creating that row if it does not exist yet.

    It then commits and returns
    ``(user_id, is_jailed, is_superuser, session_expiry, ui_language_preference)``.
    """
    if not token:
        return None

    with session_scope() as session:
        # Outer join: a missing activity row for this hour/IP/user agent must not hide the user.
        activity_in_current_bucket = and_(
            UserActivity.user_id == User.id,
            UserActivity.period == _binned_now(),
            UserActivity.ip_address == ip_address,
            UserActivity.user_agent == user_agent,
        )
        statement = (
            select(User, UserSession, UserActivity)
            .join(User, User.id == UserSession.user_id)
            .outerjoin(UserActivity, activity_in_current_bucket)
            .where(User.is_visible)
            .where(UserSession.token == token)
            .where(UserSession.is_valid)
            .where(UserSession.is_api_key == is_api_key)
        )
        row = session.execute(statement).one_or_none()

        if not row:
            return None

        user, user_session, user_activity = row

        # Only touch last_active when it is more than 5 minutes stale.
        if now() - user.last_active > timedelta(minutes=5):
            user.last_active = func.now()

        user_session.last_seen = func.now()
        user_session.api_calls += 1

        if not user_activity:
            session.add(
                UserActivity(
                    user_id=user.id,
                    period=_binned_now(),
                    ip_address=ip_address,
                    user_agent=user_agent,
                    api_calls=1,
                )
            )
        else:
            user_activity.api_calls += 1

        session.commit()

        return user.id, user.is_jailed, user.is_superuser, user_session.expiry, user.ui_language_preference
def abort_handler(message, status_code):
    """
    Build a unary-unary gRPC handler that rejects every call with ``status_code`` and ``message``.
    """

    def _reject(request, context):
        # context.abort raises, so this handler never returns a response
        context.abort(status_code, message)

    return grpc.unary_unary_rpc_method_handler(_reject)
def unauthenticated_handler(message="Unauthorized", status_code=grpc.StatusCode.UNAUTHENTICATED):
    """
    Build a handler that rejects the call.

    Defaults to status UNAUTHENTICATED with the message "Unauthorized"; both can be overridden
    (for example, PERMISSION_DENIED for admin-only services).
    """
    return abort_handler(message=message, status_code=status_code)
def _sanitized_bytes(proto):
    """
    Serialize a protobuf message with every field marked ``sensitive`` cleared.

    The work is done on a deep copy, so the caller's message is left untouched. Set singular
    sub-messages, repeated message fields and message-valued map fields are sanitized
    recursively.

    Returns ``None`` for a falsy ``proto`` (e.g. ``None``); otherwise returns the serialized
    bytes of the sanitized copy.
    """
    if not proto:
        return None

    new_proto = deepcopy(proto)

    def _sanitize_message(message):
        # ListFields() yields only fields that are actually set. Walking DESCRIPTOR.fields_by_name
        # with getattr instead has three problems:
        #   * it materialises default sub-messages, and ClearField on them marks them present in
        #     the copy;
        #   * it recurses without bound on self-referential message types;
        #   * it iterates map fields by key (strings), which have no DESCRIPTOR.
        for descriptor, value in message.ListFields():
            if descriptor.is_extension:
                # extensions were never visited by the fields_by_name walk; keep that behaviour
                continue
            if descriptor.GetOptions().Extensions[annotations_pb2.sensitive]:
                message.ClearField(descriptor.name)
                # nothing left to recurse into once the field is cleared
                continue
            if descriptor.message_type is None:
                # scalar field: nothing nested to sanitize
                continue
            if descriptor.message_type.GetOptions().map_entry:
                # map<K, V>: only message-typed values can carry sensitive fields
                if descriptor.message_type.fields_by_name["value"].message_type is not None:
                    for item in value.values():
                        _sanitize_message(item)
            elif descriptor.label == descriptor.LABEL_REPEATED:
                for item in value:
                    _sanitize_message(item)
            else:
                _sanitize_message(value)

    _sanitize_message(new_proto)

    return new_proto.SerializeToString()
def _store_log(
    *,
    method,
    status_code,
    duration,
    user_id,
    is_api_key,
    request,
    response,
    traceback,
    perf_report,
    ip_address,
    user_agent,
):
    """
    Persist one ``APICall`` row describing a serviced RPC.

    Request and response are stored with their ``sensitive`` fields stripped (via
    ``_sanitized_bytes``). The serialized response is cut to its first 16 kB, and
    ``response_truncated`` records whether that happened. ``duration`` is logged with an
    "ms" suffix.
    """
    request_bytes = _sanitized_bytes(request)
    response_bytes = _sanitized_bytes(response)

    max_response_length = 16 * 1024  # 16 kB
    was_truncated = bool(response_bytes) and len(response_bytes) > max_response_length
    if was_truncated:
        response_bytes = response_bytes[:max_response_length]

    with session_scope() as session:
        api_call = APICall(
            is_api_key=is_api_key,
            method=method,
            status_code=status_code,
            duration=duration,
            user_id=user_id,
            request=request_bytes,
            response=response_bytes,
            response_truncated=was_truncated,
            traceback=traceback,
            perf_report=perf_report,
            ip_address=ip_address,
            user_agent=user_agent,
        )
        session.add(api_call)

    logger.debug(f"{user_id=}, {method=}, {duration=} ms")
class CouchersMiddlewareInterceptor(grpc.ServerInterceptor):
    """
    1. Does auth: extracts a session token from the `cookie` header (session cookies) or the
       `authorization` header (API keys), and authenticates a user with that.

       The authenticated user id and token are handed to the servicer through its CouchersContext.
       If no valid session is found, calls to any service that is not AUTH_LEVEL_OPEN are
       terminated with an UNAUTHENTICATED error code. Admin services additionally require a
       superuser (PERMISSION_DENIED otherwise). Jailed users may only call open or jailed services.

    2. Makes sure cookies are in sync.

    3. Injects a session to get a database transaction.

    4. Measures and logs the time it takes to service each incoming call.
    """

    def __init__(self):
        # descriptor pool used to read the auth_level option off each service definition
        self._pool = get_descriptor_pool()

    def intercept_service(self, continuation, handler_call_details):
        start = perf_counter_ns()

        method = handler_call_details.method
        # method is of the form "/org.couchers.api.core.API/GetUser"
        _, service_name, method_name = method.split("/")

        try:
            service_options = self._pool.FindServiceByName(service_name).GetOptions()
        except KeyError:
            return abort_handler(
                "API call does not exist. Please refresh and try again.", grpc.StatusCode.UNIMPLEMENTED
            )

        auth_level = service_options.Extensions[annotations_pb2.auth_level]

        # if unknown auth level, then it wasn't set and something's wrong
        if auth_level == annotations_pb2.AUTH_LEVEL_UNKNOWN:
            return abort_handler("Internal authentication error.", grpc.StatusCode.INTERNAL)

        assert auth_level in [
            annotations_pb2.AUTH_LEVEL_OPEN,
            annotations_pb2.AUTH_LEVEL_JAILED,
            annotations_pb2.AUTH_LEVEL_SECURE,
            annotations_pb2.AUTH_LEVEL_ADMIN,
        ]

        headers = dict(handler_call_details.invocation_metadata)

        if "cookie" in headers and "authorization" in headers:
            # for security reasons, only one of "cookie" or "authorization" can be present
            return unauthenticated_handler('Both "cookie" and "authorization" in request')
        elif "cookie" in headers:
            # the session token is passed in cookies, i.e. in the `cookie` header
            token, is_api_key = parse_session_cookie(headers), False
        elif "authorization" in headers:
            # the session token is passed in the `authorization` header
            token, is_api_key = parse_api_key(headers), True
        else:
            # no session found
            token, is_api_key = None, False

        ip_address = headers.get("x-couchers-real-ip")
        user_agent = headers.get("user-agent")

        auth_info = _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent)
        # auth_info is now filled if and only if this is a valid session
        if not auth_info:
            # if no session was found and this isn't an open service, fail
            if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
                return unauthenticated_handler()
            # open service: proceed anonymously, discarding whatever (invalid) credentials were sent
            token = None
            is_api_key = False
            token_expiry = None
            user_id = None
            ui_language_preference = None
        else:
            # a valid user session was found
            user_id, is_jailed, is_superuser, token_expiry, ui_language_preference = auth_info

            if auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not is_superuser:
                return unauthenticated_handler("Permission denied", grpc.StatusCode.PERMISSION_DENIED)

            # if the user is jailed and this isn't an open or jailed service, fail
            if is_jailed and auth_level not in [annotations_pb2.AUTH_LEVEL_OPEN, annotations_pb2.AUTH_LEVEL_JAILED]:
                return unauthenticated_handler("Permission denied")

        handler = continuation(handler_call_details)
        if handler is None:
            # The service exists but no handler is registered for this method. Without this check
            # the `.unary_unary` access below raises AttributeError inside the interceptor.
            return abort_handler(
                "API call does not exist. Please refresh and try again.", grpc.StatusCode.UNIMPLEMENTED
            )
        prev_function = handler.unary_unary

        def function_without_couchers_stuff(req, grpc_context):
            couchers_context: CouchersContext = make_interactive_user_context(
                grpc_context=grpc_context,
                user_id=user_id,
                is_api_key=is_api_key,
                token=token,
                ui_language_preference=ui_language_preference,
            )

            def record_call(*, status_code, response, traceback, exception_name):
                # Store the APICall row and observe the duration histogram (which takes seconds).
                duration = (perf_counter_ns() - start) / 1e6  # ms
                _store_log(
                    method=method,
                    status_code=status_code,
                    duration=duration,
                    user_id=couchers_context._user_id,
                    is_api_key=couchers_context._is_api_key,
                    request=req,
                    response=response,
                    traceback=traceback,
                    perf_report=None,
                    ip_address=ip_address,
                    user_agent=user_agent,
                )
                observe_in_servicer_duration_histogram(
                    method, couchers_context._user_id, status_code or "", exception_name, duration / 1000
                )

            with session_scope() as session:
                try:
                    res = prev_function(req, couchers_context, session)
                    record_call(status_code=None, response=res, traceback=None, exception_name="")
                except Exception as e:
                    # the gRPC code is set only when the call was failed through abort()
                    code = getattr(couchers_context._grpc_context.code(), "name", None)
                    traceback = "".join(format_exception(type(e), e, e.__traceback__))
                    record_call(
                        status_code=code, response=None, traceback=traceback, exception_name=type(e).__name__
                    )

                    if not code:
                        sentry_sdk.set_tag("context", "servicer")
                        sentry_sdk.set_tag("method", method)
                        sentry_sdk.capture_exception(e)

                    raise

            if user_id and not is_api_key:
                # check the two cookies are in sync & that language preference cookie is correct
                if parse_user_id_cookie(headers) != str(user_id):
                    couchers_context.set_cookies(create_session_cookies(token, user_id, token_expiry))
                if ui_language_preference and ui_language_preference != parse_ui_lang_cookie(headers):
                    couchers_context.set_cookies(create_lang_cookie(ui_language_preference))

            try:
                couchers_context._send_cookies()
            except grpc.RpcError as e:
                # Log details when client disconnects during cookie sending
                # Some RpcErrors don't have code() or details() methods, so use getattr
                code = getattr(e, "code", lambda: "unknown")()
                details = getattr(e, "details", lambda: "unknown")()
                logger.exception(f"RpcError during _send_cookies(): code={code}, details={details}, method={method}")
                raise

            return res

        return grpc.unary_unary_rpc_method_handler(
            function_without_couchers_stuff,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
class MediaInterceptor(grpc.ServerInterceptor):
    """
    Extracts an "Authorization: Bearer <hex>" header and calls the
    is_authorized function. Terminates the call with a gRPC UNAUTHENTICATED
    status if not authorized.

    Also adds a database session to called APIs.
    """

    def __init__(self, is_authorized):
        # callable taking the parsed token; its truthiness decides whether the call may proceed
        self._is_authorized = is_authorized

    def intercept_service(self, continuation, handler_call_details):
        metadata = dict(handler_call_details.invocation_metadata)
        token = parse_api_key(metadata)

        # reject before looking up the handler; unauthorized calls never reach the servicer
        if not token or not self._is_authorized(token):
            return unauthenticated_handler()

        handler = continuation(handler_call_details)
        if handler is None:
            # no handler for this method: returning None lets gRPC answer UNIMPLEMENTED
            return None
        prev_func = handler.unary_unary

        def function_without_session(request, grpc_context):
            with session_scope() as session:
                return prev_func(request, make_media_context(grpc_context), session)

        return grpc.unary_unary_rpc_method_handler(
            function_without_session,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
class OTelInterceptor(grpc.ServerInterceptor):
    """
    OpenTelemetry tracing.

    Wraps every unary-unary call in a "handler" span annotated with the full RPC method, service
    and method names, the serving thread id and process id, and the client's user agent and IP
    address (from the `user-agent` and `x-couchers-real-ip` metadata).
    """

    def __init__(self):
        self.tracer = trace.get_tracer(__name__)

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            # no handler for this method: returning None lets gRPC answer UNIMPLEMENTED
            return None
        prev_func = handler.unary_unary
        method = handler_call_details.method

        # method is of the form "/org.couchers.api.core.API/GetUser"
        _, service_name, method_name = method.split("/")

        headers = dict(handler_call_details.invocation_metadata)

        def tracing_function(request, context):
            with self.tracer.start_as_current_span("handler") as rollspan:
                rollspan.set_attribute("rpc.method_full", method)
                rollspan.set_attribute("rpc.service", service_name)
                rollspan.set_attribute("rpc.method", method_name)

                rollspan.set_attribute("rpc.thread", get_ident())
                rollspan.set_attribute("rpc.pid", getpid())

                # Set client attributes before running the handler, so failed calls (where
                # prev_func raises) are attributed too.
                rollspan.set_attribute("web.user_agent", headers.get("user-agent") or "")
                rollspan.set_attribute("web.ip_address", headers.get("x-couchers-real-ip") or "")

                return prev_func(request, context)

        return grpc.unary_unary_rpc_method_handler(
            tracing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
class ErrorSanitizationInterceptor(grpc.ServerInterceptor):
    """
    If the call resulted in a non-gRPC error, this strips away the error details.

    Exceptions raised without a gRPC status code set on the context are logged and replaced by an
    INTERNAL abort carrying ``errors.UNKNOWN_ERROR``. Calls that were failed through ``abort()``
    (code set) are logged as a warning and re-raised unchanged.

    It's important to put this first, so that it does not interfere with other interceptors.
    """

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            # no handler for this method: returning None lets gRPC answer UNIMPLEMENTED
            return None
        prev_func = handler.unary_unary

        def sanitizing_function(req, context):
            try:
                return prev_func(req, context)
            except Exception as e:
                code = context.code()
                # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None
                if not code:
                    logger.exception(e)
                    logger.info("Probably an unknown error! Sanitizing...")
                    # abort() raises, replacing the original exception with a generic INTERNAL error
                    context.abort(grpc.StatusCode.INTERNAL, errors.UNKNOWN_ERROR)
                else:
                    logger.warning(f"RPC error: {code} in method {handler_call_details.method}")
                    raise

        return grpc.unary_unary_rpc_method_handler(
            sanitizing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )