Coverage for src/couchers/interceptors.py: 90%
199 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-08-28 14:55 +0000
1import logging
2from copy import deepcopy
3from datetime import timedelta
4from os import getpid
5from threading import get_ident
6from time import perf_counter_ns
7from traceback import format_exception
9import grpc
10import sentry_sdk
11from opentelemetry import trace
12from sqlalchemy.sql import and_, func
14from couchers import errors
15from couchers.context import CouchersContext, make_interactive_user_context, make_media_context
16from couchers.db import session_scope
17from couchers.descriptor_pool import get_descriptor_pool
18from couchers.metrics import observe_in_servicer_duration_histogram
19from couchers.models import APICall, User, UserActivity, UserSession
20from couchers.sql import couchers_select as select
21from couchers.utils import (
22 create_lang_cookie,
23 create_session_cookies,
24 now,
25 parse_api_key,
26 parse_session_cookie,
27 parse_ui_lang_cookie,
28 parse_user_id_cookie,
29)
30from proto import annotations_pb2
32logger = logging.getLogger(__name__)
def _binned_now():
    """
    Builds the SQL `date_bin` expression that buckets the database's `now()` into "1 hour" bins.

    The bins use "2000-01-01" as their origin. This is a SQL expression, not a Python datetime, so it is only
    evaluated once it is sent to the database.
    """
    bin_width = "1 hour"
    bin_origin = "2000-01-01"
    return func.date_bin(bin_width, func.now(), bin_origin)
39def _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent):
40 """
41 Tries to get session and user info corresponding to this token.
43 Also updates the user last active time, token last active time, and increments API call count.
44 """
45 if not token:
46 return None
48 with session_scope() as session:
49 result = session.execute(
50 select(User, UserSession, UserActivity)
51 .join(User, User.id == UserSession.user_id)
52 .outerjoin(
53 UserActivity,
54 and_(
55 UserActivity.user_id == User.id,
56 UserActivity.period == _binned_now(),
57 UserActivity.ip_address == ip_address,
58 UserActivity.user_agent == user_agent,
59 ),
60 )
61 .where(User.is_visible)
62 .where(UserSession.token == token)
63 .where(UserSession.is_valid)
64 .where(UserSession.is_api_key == is_api_key)
65 ).one_or_none()
67 if not result:
68 return None
69 else:
70 user, user_session, user_activity = result
72 # update user last active time if it's been a while
73 if now() - user.last_active > timedelta(minutes=5):
74 user.last_active = func.now()
76 # let's update the token
77 user_session.last_seen = func.now()
78 user_session.api_calls += 1
80 if user_activity:
81 user_activity.api_calls += 1
82 else:
83 session.add(
84 UserActivity(
85 user_id=user.id,
86 period=_binned_now(),
87 ip_address=ip_address,
88 user_agent=user_agent,
89 api_calls=1,
90 )
91 )
93 session.commit()
95 return user.id, user.is_jailed, user.is_superuser, user_session.expiry, user.ui_language_preference
def abort_handler(message, status_code):
    """
    Builds a unary-unary RPC handler that ignores the request and aborts the call with the given status code and
    message.
    """

    def _abort_call(_request, context):
        context.abort(status_code, message)

    return grpc.unary_unary_rpc_method_handler(_abort_call)
def unauthenticated_handler(message="Unauthorized", status_code=grpc.StatusCode.UNAUTHENTICATED):
    """
    Shorthand for abort_handler that defaults to the UNAUTHENTICATED status code and the message "Unauthorized".
    """
    return abort_handler(message=message, status_code=status_code)
109def _sanitized_bytes(proto):
110 """
111 Remove fields marked sensitive and return serialized bytes
112 """
113 if not proto:
114 return None
116 new_proto = deepcopy(proto)
118 def _sanitize_message(message):
119 for name, descriptor in message.DESCRIPTOR.fields_by_name.items():
120 if descriptor.GetOptions().Extensions[annotations_pb2.sensitive]:
121 message.ClearField(name)
122 if descriptor.message_type:
123 submessage = getattr(message, name)
124 if not submessage:
125 continue
126 if descriptor.label == descriptor.LABEL_REPEATED:
127 for msg in submessage:
128 _sanitize_message(msg)
129 else:
130 _sanitize_message(submessage)
132 _sanitize_message(new_proto)
134 return new_proto.SerializeToString()
def _store_log(
    *,
    method,
    status_code,
    duration,
    user_id,
    is_api_key,
    request,
    response,
    traceback,
    perf_report,
    ip_address,
    user_agent,
):
    """
    Stores one APICall row describing a serviced call, then writes a debug log line about it.

    The request and response are stored as serialized bytes with sensitive fields removed. A response larger than
    16 kB is cut off at that size and the row is flagged as truncated.
    """
    max_response_bytes = 16 * 1024  # 16 kB

    sanitized_request = _sanitized_bytes(request)
    sanitized_response = _sanitized_bytes(response)

    # a missing or empty response is never considered truncated
    response_truncated = bool(sanitized_response) and len(sanitized_response) > max_response_bytes
    if response_truncated:
        sanitized_response = sanitized_response[:max_response_bytes]

    api_call = APICall(
        is_api_key=is_api_key,
        method=method,
        status_code=status_code,
        duration=duration,
        user_id=user_id,
        request=sanitized_request,
        response=sanitized_response,
        response_truncated=response_truncated,
        traceback=traceback,
        perf_report=perf_report,
        ip_address=ip_address,
        user_agent=user_agent,
    )
    with session_scope() as session:
        session.add(api_call)
    logger.debug(f"{user_id=}, {method=}, {duration=} ms")
class CouchersMiddlewareInterceptor(grpc.ServerInterceptor):
    """
    1. Does auth: extracts a session token from a cookie (or an API key from the authorization header), and
       authenticates a user with that.

       Passes the user id and token on to the servicer through the context if authenticated, otherwise
       terminates the call with an UNAUTHENTICATED error code (unless the service is open).

    2. Makes sure cookies are in sync.

    3. Injects a session to get a database transaction.

    4. Measures and logs the time it takes to service each incoming call.
    """

    # shown both when the service is unknown and when no handler is registered for the method
    _CALL_NOT_FOUND_MESSAGE = "API call does not exist. Please refresh and try again."

    def __init__(self):
        self._pool = get_descriptor_pool()

    def intercept_service(self, continuation, handler_call_details):
        """
        Authenticates the call and wraps the real handler with context creation, a database session, call logging
        and cookie syncing.

        Returns an aborting handler instead if the call doesn't exist or the caller isn't allowed to make it.
        """
        start = perf_counter_ns()

        method = handler_call_details.method
        # method is of the form "/org.couchers.api.core.API/GetUser"
        _, service_name, _ = method.split("/")

        try:
            service_options = self._pool.FindServiceByName(service_name).GetOptions()
        except KeyError:
            return abort_handler(self._CALL_NOT_FOUND_MESSAGE, grpc.StatusCode.UNIMPLEMENTED)

        auth_level = service_options.Extensions[annotations_pb2.auth_level]

        # if unknown auth level, then it wasn't set and something's wrong
        if auth_level == annotations_pb2.AUTH_LEVEL_UNKNOWN:
            return abort_handler("Internal authentication error.", grpc.StatusCode.INTERNAL)

        assert auth_level in [
            annotations_pb2.AUTH_LEVEL_OPEN,
            annotations_pb2.AUTH_LEVEL_JAILED,
            annotations_pb2.AUTH_LEVEL_SECURE,
            annotations_pb2.AUTH_LEVEL_ADMIN,
        ]

        headers = dict(handler_call_details.invocation_metadata)

        if "cookie" in headers and "authorization" in headers:
            # for security reasons, only one of "cookie" or "authorization" can be present
            return unauthenticated_handler('Both "cookie" and "authorization" in request')
        elif "cookie" in headers:
            # the session token is passed in cookies, i.e. in the `cookie` header
            token, is_api_key = parse_session_cookie(headers), False
        elif "authorization" in headers:
            # the session token is passed in the `authorization` header
            token, is_api_key = parse_api_key(headers), True
        else:
            # no session found
            token, is_api_key = None, False

        ip_address = headers.get("x-couchers-real-ip")
        user_agent = headers.get("user-agent")

        auth_info = _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent)
        # auth_info is now filled if and only if this is a valid session
        if not auth_info:
            # if no session was found and this isn't an open service, fail
            if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
                return unauthenticated_handler()

            # open service called without a valid session: continue as a logged out caller
            token = None
            is_api_key = False
            token_expiry = None
            user_id = None
            ui_language_preference = None
        else:
            # a valid user session was found
            user_id, is_jailed, is_superuser, token_expiry, ui_language_preference = auth_info

            if auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not is_superuser:
                return unauthenticated_handler("Permission denied", grpc.StatusCode.PERMISSION_DENIED)

            # if the user is jailed and this is isn't an open or jailed service, fail
            if is_jailed and auth_level not in [annotations_pb2.AUTH_LEVEL_OPEN, annotations_pb2.AUTH_LEVEL_JAILED]:
                return unauthenticated_handler("Permission denied")

        handler = continuation(handler_call_details)
        if handler is None:
            # the continuation returns None when no handler is registered for this method
            return abort_handler(self._CALL_NOT_FOUND_MESSAGE, grpc.StatusCode.UNIMPLEMENTED)
        prev_function = handler.unary_unary

        def function_without_couchers_stuff(req, grpc_context):
            couchers_context: CouchersContext = make_interactive_user_context(
                grpc_context=grpc_context,
                user_id=user_id,
                is_api_key=is_api_key,
                token=token,
                ui_language_preference=ui_language_preference,
            )
            with session_scope() as session:
                try:
                    res = prev_function(req, couchers_context, session)
                    finished = perf_counter_ns()
                    duration = (finished - start) / 1e6  # ms
                    # the user id and api key flag are read back from the context rather than from the closure
                    _store_log(
                        method=method,
                        status_code=None,
                        duration=duration,
                        user_id=couchers_context._user_id,
                        is_api_key=couchers_context._is_api_key,
                        request=req,
                        response=res,
                        traceback=None,
                        perf_report=None,
                        ip_address=ip_address,
                        user_agent=user_agent,
                    )
                    observe_in_servicer_duration_histogram(method, couchers_context._user_id, "", "", duration / 1000)
                except Exception as e:
                    finished = perf_counter_ns()
                    duration = (finished - start) / 1e6  # ms
                    # code is only set if the servicer set a status (e.g. through abort()), otherwise it's None
                    code = getattr(couchers_context._grpc_context.code(), "name", None)
                    traceback = "".join(format_exception(type(e), e, e.__traceback__))
                    _store_log(
                        method=method,
                        status_code=code,
                        duration=duration,
                        user_id=couchers_context._user_id,
                        is_api_key=couchers_context._is_api_key,
                        request=req,
                        response=None,
                        traceback=traceback,
                        perf_report=None,
                        ip_address=ip_address,
                        user_agent=user_agent,
                    )
                    observe_in_servicer_duration_histogram(
                        method, couchers_context._user_id, code or "", type(e).__name__, duration / 1000
                    )

                    # no status code means this wasn't a deliberate RPC error, so report it
                    if not code:
                        sentry_sdk.set_tag("context", "servicer")
                        sentry_sdk.set_tag("method", method)
                        sentry_sdk.capture_exception(e)

                    raise

            if user_id and not is_api_key:
                # check the two cookies are in sync & that language preference cookie is correct
                if parse_user_id_cookie(headers) != str(user_id):
                    couchers_context.set_cookies(create_session_cookies(token, user_id, token_expiry))
                if ui_language_preference and ui_language_preference != parse_ui_lang_cookie(headers):
                    couchers_context.set_cookies(create_lang_cookie(ui_language_preference))

            couchers_context._send_cookies()

            return res

        return grpc.unary_unary_rpc_method_handler(
            function_without_couchers_stuff,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
class MediaInterceptor(grpc.ServerInterceptor):
    """
    Extracts an "Authorization: Bearer <hex>" header and calls the
    is_authorized function. Terminates the call with an UNAUTHENTICATED
    error code if not authorized.

    Also adds a session to called APIs.
    """

    def __init__(self, is_authorized):
        # callable that takes the token and returns whether the call is allowed
        self._is_authorized = is_authorized

    def intercept_service(self, continuation, handler_call_details):
        """
        Checks the token from the call's metadata and, if authorized, wraps the real handler so that it gets a media
        context and a database session.
        """
        handler = continuation(handler_call_details)
        metadata = dict(handler_call_details.invocation_metadata)

        token = parse_api_key(metadata)

        if not token or not self._is_authorized(token):
            return unauthenticated_handler()

        if handler is None:
            # the continuation returns None when no handler is registered for this method; pass that on so
            # that gRPC rejects the call instead of this interceptor crashing on handler.unary_unary
            return None
        prev_func = handler.unary_unary

        def function_without_session(request, grpc_context):
            with session_scope() as session:
                return prev_func(request, make_media_context(grpc_context), session)

        return grpc.unary_unary_rpc_method_handler(
            function_without_session,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
class OTelInterceptor(grpc.ServerInterceptor):
    """
    OpenTelemetry tracing

    Runs each call inside a "handler" span that's annotated with the method, the thread and process ids, and the
    caller's user agent and IP address.
    """

    def __init__(self):
        self.tracer = trace.get_tracer(__name__)

    def intercept_service(self, continuation, handler_call_details):
        """
        Wraps the real handler so that it runs inside a tracing span.
        """
        handler = continuation(handler_call_details)
        if handler is None:
            # the continuation returns None when no handler is registered for this method; pass that on so
            # that gRPC rejects the call instead of this interceptor crashing on handler.unary_unary
            return None
        prev_func = handler.unary_unary
        method = handler_call_details.method

        # method is of the form "/org.couchers.api.core.API/GetUser"
        _, service_name, method_name = method.split("/")

        headers = dict(handler_call_details.invocation_metadata)

        def tracing_function(request, context):
            with self.tracer.start_as_current_span("handler") as span:
                span.set_attribute("rpc.method_full", method)
                span.set_attribute("rpc.service", service_name)
                span.set_attribute("rpc.method", method_name)

                span.set_attribute("rpc.thread", get_ident())
                span.set_attribute("rpc.pid", getpid())

                # set these before running the handler, so that they're also recorded if the handler raises
                span.set_attribute("web.user_agent", headers.get("user-agent") or "")
                span.set_attribute("web.ip_address", headers.get("x-couchers-real-ip") or "")

                return prev_func(request, context)

        return grpc.unary_unary_rpc_method_handler(
            tracing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
class ErrorSanitizationInterceptor(grpc.ServerInterceptor):
    """
    If the call resulted in a non-gRPC error, this strips away the error details.

    It's important to put this first, so that it does not interfere with other interceptors.
    """

    def intercept_service(self, continuation, handler_call_details):
        """
        Wraps the real handler so that exceptions without a gRPC status code are logged and replaced with a generic
        INTERNAL error, while errors that already carry a status code are passed through.
        """
        handler = continuation(handler_call_details)
        if handler is None:
            # the continuation returns None when no handler is registered for this method; pass that on so
            # that gRPC rejects the call instead of this interceptor crashing on handler.unary_unary
            return None
        prev_func = handler.unary_unary

        def sanitizing_function(req, context):
            try:
                return prev_func(req, context)
            except Exception as e:
                code = context.code()
                # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None
                if not code:
                    logger.exception(e)
                    logger.info("Probably an unknown error! Sanitizing...")
                    # abort() raises, so the original exception doesn't propagate past this point
                    context.abort(grpc.StatusCode.INTERNAL, errors.UNKNOWN_ERROR)
                else:
                    logger.warning(f"RPC error: {code} in method {handler_call_details.method}")
                raise

        return grpc.unary_unary_rpc_method_handler(
            sanitizing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )