Coverage for src/couchers/interceptors.py: 90%
234 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-04-25 03:06 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-04-25 03:06 +0000
1import logging
2from copy import deepcopy
3from datetime import timedelta
4from os import getpid
5from threading import get_ident
6from time import perf_counter_ns
7from traceback import format_exception
9import grpc
10import sentry_sdk
11from opentelemetry import trace
12from sqlalchemy.sql import and_, func
14from couchers import errors
15from couchers.db import session_scope
16from couchers.descriptor_pool import get_descriptor_pool
17from couchers.metrics import observe_in_servicer_duration_histogram
18from couchers.models import APICall, User, UserActivity, UserSession
19from couchers.sql import couchers_select as select
20from couchers.utils import (
21 create_lang_cookie,
22 create_session_cookies,
23 now,
24 parse_api_key,
25 parse_session_cookie,
26 parse_ui_lang_cookie,
27 parse_user_id_cookie,
28)
29from proto import annotations_pb2
31logger = logging.getLogger(__name__)
def _binned_now():
    """
    Builds the SQL expression `date_bin('1 hour', now(), '2000-01-01')`, i.e. the database's current time binned into
    one hour wide buckets that are anchored at 2000-01-01.
    """
    bin_width = "1 hour"
    bin_origin = "2000-01-01"
    return func.date_bin(bin_width, func.now(), bin_origin)
def _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent):
    """
    Looks up the visible user and the valid session that belong to this token.

    When a match is found this also refreshes the user's last active time (only if the stored value is more than five
    minutes old), refreshes the session's last seen time, and bumps the API call counter on both the session and the
    UserActivity row for the current one hour bin, IP address and user agent (the row is created if it is missing).

    Returns None if there is no token or no matching session, otherwise the tuple
    (user id, is_jailed, is_superuser, session expiry, ui_language_preference).
    """
    if not token:
        return None

    with session_scope() as session:
        # the activity row is matched on user, hourly bin, IP address and user agent
        same_activity_bucket = and_(
            UserActivity.user_id == User.id,
            UserActivity.period == _binned_now(),
            UserActivity.ip_address == ip_address,
            UserActivity.user_agent == user_agent,
        )
        lookup = (
            select(User, UserSession, UserActivity)
            .join(User, User.id == UserSession.user_id)
            .outerjoin(UserActivity, same_activity_bucket)
            .where(User.is_visible)
            .where(UserSession.token == token)
            .where(UserSession.is_valid)
            .where(UserSession.is_api_key == is_api_key)
        )
        row = session.execute(lookup).one_or_none()

        if not row:
            return None

        user, user_session, user_activity = row

        # only touch the user's last active time if it's gone stale
        stale_after = timedelta(minutes=5)
        if now() - user.last_active > stale_after:
            user.last_active = func.now()

        # the session itself is updated on every call
        user_session.last_seen = func.now()
        user_session.api_calls += 1

        if not user_activity:
            # first call in this bucket: start a fresh counter
            session.add(
                UserActivity(
                    user_id=user.id,
                    period=_binned_now(),
                    ip_address=ip_address,
                    user_agent=user_agent,
                    api_calls=1,
                )
            )
        else:
            user_activity.api_calls += 1

        session.commit()

        return user.id, user.is_jailed, user.is_superuser, user_session.expiry, user.ui_language_preference
def abort_handler(message, status_code):
    """
    Builds a unary-unary method handler that aborts every call it receives with the given status code and message.
    """

    def _abort(request, context):
        context.abort(status_code, message)

    return grpc.unary_unary_rpc_method_handler(_abort)
def unauthenticated_handler(message="Unauthorized", status_code=grpc.StatusCode.UNAUTHENTICATED):
    """
    Shorthand for abort_handler whose defaults reject the call as UNAUTHENTICATED with the message "Unauthorized".
    """
    return abort_handler(message=message, status_code=status_code)
class AuthValidatorInterceptor(grpc.ServerInterceptor):
    """
    Extracts a session token from a cookie or an API key from the authorization header, and authenticates a user with
    that.

    The required auth level is read from the service's `auth_level` option. If a valid session is found, the wrapped
    handler is called with context.user_id, context.token (a `(token, expiry)` tuple), context.is_api_key and
    context.ui_language_preference set. For open services without a valid session those attributes are still set, but
    to None/False. In every other case the call is terminated with an error status.
    """

    def __init__(self):
        self._pool = get_descriptor_pool()

    def intercept_service(self, continuation, handler_call_details):
        """
        Authenticates the call and returns either a handler that aborts it or the next handler wrapped so that the
        auth info is attached to the context.
        """
        method = handler_call_details.method
        # method is of the form "/org.couchers.api.core.API/GetUser"
        try:
            # a path that doesn't have exactly this shape fails the unpacking with ValueError; treat it the same as
            # a service that doesn't exist instead of letting the exception escape the interceptor
            _, service_name, _ = method.split("/")
            service_options = self._pool.FindServiceByName(service_name).GetOptions()
        except (ValueError, KeyError):
            return abort_handler(
                "API call does not exist. Please refresh and try again.", grpc.StatusCode.UNIMPLEMENTED
            )

        auth_level = service_options.Extensions[annotations_pb2.auth_level]

        # if the auth level is unknown (wasn't set) or is not one we know how to enforce, something's wrong. This is
        # an explicit check rather than an assert so that it still fails closed when asserts are disabled (-O)
        if auth_level not in (
            annotations_pb2.AUTH_LEVEL_OPEN,
            annotations_pb2.AUTH_LEVEL_JAILED,
            annotations_pb2.AUTH_LEVEL_SECURE,
            annotations_pb2.AUTH_LEVEL_ADMIN,
        ):
            return abort_handler("Internal authentication error.", grpc.StatusCode.INTERNAL)

        headers = dict(handler_call_details.invocation_metadata)

        if "cookie" in headers and "authorization" in headers:
            # for security reasons, only one of "cookie" or "authorization" can be present
            return unauthenticated_handler('Both "cookie" and "authorization" in request')
        elif "cookie" in headers:
            # the session token is passed in cookies, i.e. in the `cookie` header
            token, is_api_key = parse_session_cookie(headers), False
        elif "authorization" in headers:
            # the session token is passed in the `authorization` header
            token, is_api_key = parse_api_key(headers), True
        else:
            # no session found
            token, is_api_key = None, False

        ip_address = headers.get("x-couchers-real-ip")
        user_agent = headers.get("user-agent")

        auth_info = _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent)
        # auth_info is now filled if and only if this is a valid session
        if auth_info:
            # a valid user session was found
            user_id, is_jailed, is_superuser, token_expiry, ui_language_preference = auth_info

            if auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not is_superuser:
                return unauthenticated_handler("Permission denied", grpc.StatusCode.PERMISSION_DENIED)

            # if the user is jailed and this is isn't an open or jailed service, fail
            if is_jailed and auth_level not in [annotations_pb2.AUTH_LEVEL_OPEN, annotations_pb2.AUTH_LEVEL_JAILED]:
                return unauthenticated_handler("Permission denied")
        else:
            # if no session was found and this isn't an open service, fail
            if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
                return unauthenticated_handler()

            # open service without a session: the call goes ahead anonymously
            token = None
            is_api_key = False
            token_expiry = None
            user_id = None
            ui_language_preference = None

        handler = continuation(handler_call_details)
        if handler is None:
            # the service is known but nothing further down the chain serves this method
            return abort_handler(
                "API call does not exist. Please refresh and try again.", grpc.StatusCode.UNIMPLEMENTED
            )
        user_aware_function = handler.unary_unary

        def user_unaware_function(req, context):
            context.user_id = user_id
            context.token = (token, token_expiry)
            context.is_api_key = is_api_key
            context.ui_language_preference = ui_language_preference
            return user_aware_function(req, context)

        return grpc.unary_unary_rpc_method_handler(
            user_unaware_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
class CookieInterceptor(grpc.ServerInterceptor):
    """
    Syncs up the couchers-sesh and couchers-user-id cookies & sets lang cookie

    Reads context.user_id, context.is_api_key, context.token and context.ui_language_preference after the wrapped
    handler has run, so those need to have been set on the context by then. Nothing is sent for API key calls or when
    there is no user.
    """

    def intercept_service(self, continuation, handler_call_details):
        """
        Wraps the next handler so that, after it returns, any missing or out of date cookies are sent to the client
        as `set-cookie` initial metadata.
        """
        headers = dict(handler_call_details.invocation_metadata)
        cookie_user_id = parse_user_id_cookie(headers)
        cookie_ui_lang = parse_ui_lang_cookie(headers)

        handler = continuation(handler_call_details)
        if handler is None:
            # nothing further down the chain serves this call, so there is nothing to wrap
            return None
        user_aware_function = handler.unary_unary

        def user_unaware_function(req, context):
            res = user_aware_function(req, context)

            if context.user_id and not context.is_api_key:
                cookies = []

                # check the two cookies are in sync & that language preference cookie is correct
                token, expiry = context.token
                if cookie_user_id != str(context.user_id):
                    cookies.extend(
                        ("set-cookie", cookie) for cookie in create_session_cookies(token, context.user_id, expiry)
                    )
                if context.ui_language_preference and context.ui_language_preference != cookie_ui_lang:
                    cookies.extend(
                        ("set-cookie", cookie) for cookie in create_lang_cookie(context.ui_language_preference)
                    )

                if cookies:
                    try:
                        context.send_initial_metadata(cookies)
                    except ValueError:
                        # best effort: the response itself is still returned
                        logger.info("Tried to send initial metadata but wasn't allowed to")

            return res

        return grpc.unary_unary_rpc_method_handler(
            user_unaware_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
class ManualAuthValidatorInterceptor(grpc.ServerInterceptor):
    """
    Parses an API key out of the call metadata with parse_api_key and passes it to the is_authorized callable given
    at construction.

    The call only proceeds down the chain if a key is present and is_authorized returns a truthy value for it;
    otherwise it is terminated by the default unauthenticated_handler().
    """

    def __init__(self, is_authorized):
        self._is_authorized = is_authorized

    def intercept_service(self, continuation, handler_call_details):
        headers = dict(handler_call_details.invocation_metadata)
        api_key = parse_api_key(headers)

        if api_key and self._is_authorized(api_key):
            return continuation(handler_call_details)

        return unauthenticated_handler()
class OTelInterceptor(grpc.ServerInterceptor):
    """
    OpenTelemetry tracing

    Runs the wrapped handler inside a span named "handler" that carries the method, service, thread id, process id,
    user agent and IP address as attributes.
    """

    def __init__(self):
        self.tracer = trace.get_tracer(__name__)

    def intercept_service(self, continuation, handler_call_details):
        """
        Wraps the next handler so that it executes inside a tracing span.
        """
        handler = continuation(handler_call_details)
        if handler is None:
            # nothing further down the chain serves this call, so there is nothing to trace
            return None
        prev_func = handler.unary_unary
        method = handler_call_details.method

        # method is of the form "/org.couchers.api.core.API/GetUser"
        parts = method.split("/")
        if len(parts) == 3:
            _, service_name, method_name = parts
        else:
            # don't fail the call over a path we can't take apart; the full path is still recorded on the span
            service_name, method_name = "", ""

        headers = dict(handler_call_details.invocation_metadata)

        def tracing_function(request, context):
            with self.tracer.start_as_current_span("handler") as rollspan:
                rollspan.set_attribute("rpc.method_full", method)
                rollspan.set_attribute("rpc.service", service_name)
                rollspan.set_attribute("rpc.method", method_name)

                rollspan.set_attribute("rpc.thread", get_ident())
                rollspan.set_attribute("rpc.pid", getpid())

                # set these before running the handler so that they are also present when the handler raises
                rollspan.set_attribute("web.user_agent", headers.get("user-agent") or "")
                rollspan.set_attribute("web.ip_address", headers.get("x-couchers-real-ip") or "")

                return prev_func(request, context)

        return grpc.unary_unary_rpc_method_handler(
            tracing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
class SessionInterceptor(grpc.ServerInterceptor):
    """
    Adds a session from session_scope() as the last argument. This needs to be the last interceptor since it changes the
    function signature by adding another argument.
    """

    def intercept_service(self, continuation, handler_call_details):
        """
        Wraps the next handler so that it is called as `handler(request, context, session)` inside a session_scope().
        """
        handler = continuation(handler_call_details)
        if handler is None:
            # no handler serves this call, so there is nothing to wrap
            return None
        prev_func = handler.unary_unary

        def function_without_session(request, context):
            with session_scope() as session:
                return prev_func(request, context, session)

        return grpc.unary_unary_rpc_method_handler(
            function_without_session,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
class TracingInterceptor(grpc.ServerInterceptor):
    """
    Measures and logs the time it takes to service each incoming call.

    Every call is stored as an APICall row (with sanitized request/response bytes) and observed in the servicer
    duration histogram. Exceptions are logged the same way, reported to Sentry if no gRPC status code was set, and
    then re-raised.
    """

    def _sanitized_bytes(self, proto):
        """
        Remove fields marked sensitive and return serialized bytes

        Works on a deep copy, so the message that was passed in is not modified. Returns None if proto is falsy
        (e.g. None).
        """
        if not proto:
            return None

        new_proto = deepcopy(proto)

        def _sanitize_message(message):
            for name, descriptor in message.DESCRIPTOR.fields_by_name.items():
                if descriptor.GetOptions().Extensions[annotations_pb2.sensitive]:
                    message.ClearField(name)
                    # the field is empty now, nothing left to descend into
                    continue

                field_type = descriptor.message_type
                if not field_type:
                    # scalar field that isn't sensitive
                    continue

                if field_type.GetOptions().map_entry:
                    # map fields are described as repeated entry messages, but iterating the field yields its keys,
                    # so walk the values instead (and only when those are messages themselves)
                    if field_type.fields_by_name["value"].message_type:
                        for value in getattr(message, name).values():
                            _sanitize_message(value)
                elif descriptor.label == descriptor.LABEL_REPEATED:
                    for msg in getattr(message, name):
                        _sanitize_message(msg)
                elif message.HasField(name):
                    # only descend into submessages that are actually set; this also keeps self-referencing message
                    # types from recursing through default instances
                    _sanitize_message(getattr(message, name))

        _sanitize_message(new_proto)

        return new_proto.SerializeToString()

    def _store_log(
        self,
        method,
        status_code,
        duration,
        user_id,
        is_api_key,
        request,
        response,
        traceback,
        perf_report,
        ip_address,
        user_agent,
    ):
        """
        Stores one APICall row describing this call.

        Request and response are sanitized before being stored, and the serialized response is cut off at 16 kB
        (response_truncated records whether that happened). duration is in milliseconds.
        """
        req_bytes = self._sanitized_bytes(request)
        res_bytes = self._sanitized_bytes(response)
        with session_scope() as session:
            response_truncated = False
            truncate_res_bytes_length = 16 * 1024  # 16 kB
            if res_bytes and len(res_bytes) > truncate_res_bytes_length:
                res_bytes = res_bytes[:truncate_res_bytes_length]
                response_truncated = True
            session.add(
                APICall(
                    is_api_key=is_api_key,
                    method=method,
                    status_code=status_code,
                    duration=duration,
                    user_id=user_id,
                    request=req_bytes,
                    response=res_bytes,
                    response_truncated=response_truncated,
                    traceback=traceback,
                    perf_report=perf_report,
                    ip_address=ip_address,
                    user_agent=user_agent,
                )
            )
            logger.debug(f"{user_id=}, {method=}, {duration=} ms")

    def intercept_service(self, continuation, handler_call_details):
        """
        Wraps the next handler so that its duration and outcome are recorded.
        """
        handler = continuation(handler_call_details)
        if handler is None:
            # nothing further down the chain serves this call, so there is nothing to measure
            return None
        prev_func = handler.unary_unary
        method = handler_call_details.method

        headers = dict(handler_call_details.invocation_metadata)
        ip_address = headers.get("x-couchers-real-ip")
        user_agent = headers.get("user-agent")

        def tracing_function(request, context):
            # taken outside of the try block so that it is always bound when the except block computes the duration
            start = perf_counter_ns()
            try:
                res = prev_func(request, context)
                finished = perf_counter_ns()
                duration = (finished - start) / 1e6  # ms
                user_id = getattr(context, "user_id", None)
                is_api_key = getattr(context, "is_api_key", None)
                self._store_log(
                    method, None, duration, user_id, is_api_key, request, res, None, None, ip_address, user_agent
                )
                observe_in_servicer_duration_histogram(method, user_id, "", "", duration / 1000)
            except Exception as e:
                finished = perf_counter_ns()
                duration = (finished - start) / 1e6  # ms
                code = getattr(context.code(), "name", None)
                traceback = "".join(format_exception(type(e), e, e.__traceback__))
                user_id = getattr(context, "user_id", None)
                is_api_key = getattr(context, "is_api_key", None)
                self._store_log(
                    method, code, duration, user_id, is_api_key, request, None, traceback, None, ip_address, user_agent
                )
                observe_in_servicer_duration_histogram(method, user_id, code or "", type(e).__name__, duration / 1000)

                # no status code means this wasn't a deliberate abort, so report it
                if not code:
                    sentry_sdk.set_tag("context", "servicer")
                    sentry_sdk.set_tag("method", method)
                    sentry_sdk.capture_exception(e)

                raise
            return res

        return grpc.unary_unary_rpc_method_handler(
            tracing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
class ErrorSanitizationInterceptor(grpc.ServerInterceptor):
    """
    If the call resulted in a non-gRPC error, this strips away the error details.

    It's important to put this first, so that it does not interfere with other interceptors.
    """

    def intercept_service(self, continuation, handler_call_details):
        """
        Wraps the next handler so that exceptions raised without a gRPC status code being set are logged and replaced
        by an INTERNAL error carrying errors.UNKNOWN_ERROR. Exceptions with a status code are logged and re-raised.
        """
        handler = continuation(handler_call_details)
        if handler is None:
            # nothing further down the chain serves this call, so there is nothing to wrap
            return None
        prev_func = handler.unary_unary

        def sanitizing_function(req, context):
            try:
                res = prev_func(req, context)
            except Exception as e:
                code = context.code()
                # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None
                if not code:
                    logger.exception(e)
                    logger.info("Probably an unknown error! Sanitizing...")
                    context.abort(grpc.StatusCode.INTERNAL, errors.UNKNOWN_ERROR)
                else:
                    logger.warning(f"RPC error: {code} in method {handler_call_details.method}")
                    raise
            return res

        return grpc.unary_unary_rpc_method_handler(
            sanitizing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )