Coverage for src/couchers/interceptors.py: 89%
225 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-22 06:42 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-22 06:42 +0000
1import logging
2from copy import deepcopy
3from datetime import timedelta
4from os import getpid
5from threading import get_ident
6from time import perf_counter_ns
7from traceback import format_exception
9import grpc
10import sentry_sdk
11from opentelemetry import trace
12from sqlalchemy.sql import and_, func
14from couchers import errors
15from couchers.db import session_scope
16from couchers.descriptor_pool import get_descriptor_pool
17from couchers.metrics import observe_in_servicer_duration_histogram
18from couchers.models import APICall, User, UserActivity, UserSession
19from couchers.sql import couchers_select as select
20from couchers.utils import create_session_cookies, now, parse_api_key, parse_session_cookie, parse_user_id_cookie
21from proto import annotations_pb2
23logger = logging.getLogger(__name__)
def _binned_now():
    """
    Builds the SQL expression for the current database time, binned with date_bin into 1 hour wide bins
    that are counted from the origin "2000-01-01".
    """
    bin_width = "1 hour"
    bin_origin = "2000-01-01"
    return func.date_bin(bin_width, func.now(), bin_origin)
30def _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent):
31 """
32 Tries to get session and user info corresponding to this token.
34 Also updates the user last active time, token last active time, and increments API call count.
35 """
36 if not token:
37 return None
39 with session_scope() as session:
40 result = session.execute(
41 select(User, UserSession, UserActivity)
42 .join(User, User.id == UserSession.user_id)
43 .outerjoin(
44 UserActivity,
45 and_(
46 UserActivity.user_id == User.id,
47 UserActivity.period == _binned_now(),
48 UserActivity.ip_address == ip_address,
49 UserActivity.user_agent == user_agent,
50 ),
51 )
52 .where(User.is_visible)
53 .where(UserSession.token == token)
54 .where(UserSession.is_valid)
55 .where(UserSession.is_api_key == is_api_key)
56 ).one_or_none()
58 if not result:
59 return None
60 else:
61 user, user_session, user_activity = result
63 # update user last active time if it's been a while
64 if now() - user.last_active > timedelta(minutes=5):
65 user.last_active = func.now()
67 # let's update the token
68 user_session.last_seen = func.now()
69 user_session.api_calls += 1
71 if user_activity:
72 user_activity.api_calls += 1
73 else:
74 session.add(
75 UserActivity(
76 user_id=user.id,
77 period=_binned_now(),
78 ip_address=ip_address,
79 user_agent=user_agent,
80 api_calls=1,
81 )
82 )
84 session.commit()
86 return user.id, user.is_jailed, user.is_superuser, user_session.expiry
def abort_handler(message, status_code):
    """
    Creates a unary-unary RPC method handler that aborts every call it receives with the given status code
    and message.
    """

    def _abort_call(request, context):
        context.abort(status_code, message)

    return grpc.unary_unary_rpc_method_handler(_abort_call)
def unauthenticated_handler(message="Unauthorized", status_code=grpc.StatusCode.UNAUTHENTICATED):
    """
    Shorthand for abort_handler that defaults to an "Unauthorized" message with the UNAUTHENTICATED status code.
    """
    return abort_handler(message=message, status_code=status_code)
class AuthValidatorInterceptor(grpc.ServerInterceptor):
    """
    Authenticates the caller from a session cookie or an API key, and enforces the auth level of the service.

    Sets context.user_id, context.token (a (token, expiry) tuple), and context.is_api_key before invoking the
    wrapped handler. For anonymous calls to open services these are None, (None, None), and False.

    Calls that fail authentication are terminated with an UNAUTHENTICATED or PERMISSION_DENIED error code.
    """

    def __init__(self):
        self._pool = get_descriptor_pool()

    def intercept_service(self, continuation, handler_call_details):
        # method is of the form "/org.couchers.api.core.API/GetUser"
        _, service_name, _ = handler_call_details.method.split("/")

        try:
            service_options = self._pool.FindServiceByName(service_name).GetOptions()
        except KeyError:
            return abort_handler(
                "API call does not exist. Please refresh and try again.", grpc.StatusCode.UNIMPLEMENTED
            )

        auth_level = service_options.Extensions[annotations_pb2.auth_level]

        # This is an explicit check rather than an assert, since asserts are stripped when running with -O.
        # It rejects AUTH_LEVEL_UNKNOWN (the level wasn't set, so something's wrong) as well as any level that
        # this interceptor doesn't know how to enforce.
        enforceable_auth_levels = (
            annotations_pb2.AUTH_LEVEL_OPEN,
            annotations_pb2.AUTH_LEVEL_JAILED,
            annotations_pb2.AUTH_LEVEL_SECURE,
            annotations_pb2.AUTH_LEVEL_ADMIN,
        )
        if auth_level not in enforceable_auth_levels:
            return abort_handler("Internal authentication error.", grpc.StatusCode.INTERNAL)

        headers = dict(handler_call_details.invocation_metadata)

        if "cookie" in headers and "authorization" in headers:
            # for security reasons, only one of "cookie" or "authorization" can be present
            return unauthenticated_handler('Both "cookie" and "authorization" in request')
        elif "cookie" in headers:
            # the session token is passed in cookies, i.e. in the `cookie` header
            token, is_api_key = parse_session_cookie(headers), False
        elif "authorization" in headers:
            # the session token is passed in the `authorization` header
            token, is_api_key = parse_api_key(headers), True
        else:
            # no session found
            token, is_api_key = None, False

        ip_address = headers.get("x-couchers-real-ip")
        user_agent = headers.get("user-agent")

        auth_info = _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent)
        # auth_info is now filled if and only if this is a valid session
        if not auth_info:
            # if no session was found and this isn't an open service, fail
            if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
                return unauthenticated_handler()

            # anonymous call to an open service: don't pass on whatever token was supplied
            token = None
            is_api_key = False
            token_expiry = None
            user_id = None
        else:
            # a valid user session was found
            user_id, is_jailed, is_superuser, token_expiry = auth_info

            if auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not is_superuser:
                return unauthenticated_handler("Permission denied", grpc.StatusCode.PERMISSION_DENIED)

            # if the user is jailed and this isn't an open or jailed service, fail
            if is_jailed and auth_level not in [annotations_pb2.AUTH_LEVEL_OPEN, annotations_pb2.AUTH_LEVEL_JAILED]:
                return unauthenticated_handler("Permission denied")

        handler = continuation(handler_call_details)
        if handler is None:
            # the service exists but nothing is registered for this method, so there's nothing to wrap; treat
            # it the same way as an unknown service instead of failing on handler.unary_unary below
            return abort_handler(
                "API call does not exist. Please refresh and try again.", grpc.StatusCode.UNIMPLEMENTED
            )

        user_aware_function = handler.unary_unary

        def user_unaware_function(req, context):
            context.user_id = user_id
            context.token = (token, token_expiry)
            context.is_api_key = is_api_key
            return user_aware_function(req, context)

        return grpc.unary_unary_rpc_method_handler(
            user_unaware_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
class CookieInterceptor(grpc.ServerInterceptor):
    """
    Syncs up the couchers-sesh and couchers-user-id cookies

    After the wrapped handler has run, if the call was authenticated with a session (not an API key) and the
    user ID cookie sent by the client doesn't match the authenticated user, fresh session cookies are sent
    back as "set-cookie" initial metadata.
    """

    def intercept_service(self, continuation, handler_call_details):
        headers = dict(handler_call_details.invocation_metadata)
        cookie_user_id = parse_user_id_cookie(headers)

        handler = continuation(handler_call_details)
        if handler is None:
            # nothing is registered for this method, so there's nothing to wrap; pass that on instead of
            # failing on handler.unary_unary below
            return None

        user_aware_function = handler.unary_unary

        def user_unaware_function(req, context):
            res = user_aware_function(req, context)

            # check the two cookies are in sync
            if context.user_id and not context.is_api_key and cookie_user_id != str(context.user_id):
                try:
                    token, expiry = context.token
                    context.send_initial_metadata(
                        [("set-cookie", cookie) for cookie in create_session_cookies(token, context.user_id, expiry)]
                    )
                except ValueError:
                    # best effort only: the response itself is still returned to the caller
                    logger.info("Tried to send initial metadata but wasn't allowed to")

            return res

        return grpc.unary_unary_rpc_method_handler(
            user_unaware_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
class ManualAuthValidatorInterceptor(grpc.ServerInterceptor):
    """
    Reads an API key from the request metadata (via parse_api_key) and hands it to the is_authorized function
    given at construction time.

    The call only proceeds if a key was found and is_authorized accepts it, otherwise it's terminated with an
    UNAUTHENTICATED error code.
    """

    def __init__(self, is_authorized):
        self._is_authorized = is_authorized

    def intercept_service(self, continuation, handler_call_details):
        request_metadata = dict(handler_call_details.invocation_metadata)
        api_key = parse_api_key(request_metadata)

        if api_key and self._is_authorized(api_key):
            return continuation(handler_call_details)

        return unauthenticated_handler()
class OTelInterceptor(grpc.ServerInterceptor):
    """
    OpenTelemetry tracing

    Runs each call inside a span named "handler", annotated with the RPC method, the serving thread and
    process, and the user agent and IP address taken from the request headers.
    """

    def __init__(self):
        self.tracer = trace.get_tracer(__name__)

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            # nothing is registered for this method, so there's nothing to trace; pass that on instead of
            # failing on handler.unary_unary below
            return None

        prev_func = handler.unary_unary
        method = handler_call_details.method

        # method is of the form "/org.couchers.api.core.API/GetUser"
        _, service_name, method_name = method.split("/")

        headers = dict(handler_call_details.invocation_metadata)

        def tracing_function(request, context):
            with self.tracer.start_as_current_span("handler") as span:
                span.set_attribute("rpc.method_full", method)
                span.set_attribute("rpc.service", service_name)
                span.set_attribute("rpc.method", method_name)

                span.set_attribute("rpc.thread", get_ident())
                span.set_attribute("rpc.pid", getpid())

                res = prev_func(request, context)

                # missing headers are recorded as empty strings
                span.set_attribute("web.user_agent", headers.get("user-agent") or "")
                span.set_attribute("web.ip_address", headers.get("x-couchers-real-ip") or "")

            return res

        return grpc.unary_unary_rpc_method_handler(
            tracing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
class SessionInterceptor(grpc.ServerInterceptor):
    """
    Adds a session from session_scope() as the last argument. This needs to be the last interceptor since it changes the
    function signature by adding another argument.
    """

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            # nothing is registered for this method, so there's nothing to wrap; pass that on instead of
            # failing on handler.unary_unary below
            return None

        prev_func = handler.unary_unary

        def function_without_session(request, context):
            # the session scope stays open for the whole duration of the wrapped call
            with session_scope() as session:
                return prev_func(request, context, session)

        return grpc.unary_unary_rpc_method_handler(
            function_without_session,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
class TracingInterceptor(grpc.ServerInterceptor):
    """
    Measures and logs the time it takes to service each incoming call.

    Every call is stored as an APICall row (with sanitized request/response bytes) and observed in the
    servicer duration histogram. Failed calls additionally store the formatted traceback, and are reported
    to Sentry if no RPC status code was set on the context.
    """

    def _sanitized_bytes(self, proto):
        """
        Remove fields marked sensitive and return serialized bytes

        Works on a deep copy, so the message that was passed in is not modified. Returns None for a falsy
        proto.
        """
        if not proto:
            return None

        new_proto = deepcopy(proto)

        def _sanitize_message(message):
            for name, descriptor in message.DESCRIPTOR.fields_by_name.items():
                if descriptor.GetOptions().Extensions[annotations_pb2.sensitive]:
                    message.ClearField(name)
                if not descriptor.message_type:
                    # not a message-typed field, so there's nothing to recurse into
                    continue
                submessage = getattr(message, name)
                if not submessage:
                    continue
                if descriptor.label == descriptor.LABEL_REPEATED:
                    for msg in submessage:
                        _sanitize_message(msg)
                else:
                    _sanitize_message(submessage)

        _sanitize_message(new_proto)

        return new_proto.SerializeToString()

    def _store_log(
        self,
        method,
        status_code,
        duration,
        user_id,
        is_api_key,
        request,
        response,
        traceback,
        perf_report,
        ip_address,
        user_agent,
    ):
        """
        Stores one APICall row for a serviced call.

        The request and response are sanitized before being stored, and the serialized response is cut off
        at 16 kB, in which case response_truncated is set on the row.
        """
        req_bytes = self._sanitized_bytes(request)
        res_bytes = self._sanitized_bytes(response)
        with session_scope() as session:
            response_truncated = False
            truncate_res_bytes_length = 16 * 1024  # 16 kB
            if res_bytes and len(res_bytes) > truncate_res_bytes_length:
                res_bytes = res_bytes[:truncate_res_bytes_length]
                response_truncated = True
            session.add(
                APICall(
                    is_api_key=is_api_key,
                    method=method,
                    status_code=status_code,
                    duration=duration,
                    user_id=user_id,
                    request=req_bytes,
                    response=res_bytes,
                    response_truncated=response_truncated,
                    traceback=traceback,
                    perf_report=perf_report,
                    ip_address=ip_address,
                    user_agent=user_agent,
                )
            )
        logger.debug(f"{user_id=}, {method=}, {duration=} ms")

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            # nothing is registered for this method, so there's nothing to time; pass that on instead of
            # failing on handler.unary_unary below
            return None

        prev_func = handler.unary_unary
        method = handler_call_details.method

        headers = dict(handler_call_details.invocation_metadata)
        ip_address = headers.get("x-couchers-real-ip")
        user_agent = headers.get("user-agent")

        def tracing_function(request, context):
            # taken before entering the try block, so that it's always bound in the except branch
            start = perf_counter_ns()
            try:
                res = prev_func(request, context)
                finished = perf_counter_ns()
                duration = (finished - start) / 1e6  # ms
                user_id = getattr(context, "user_id", None)
                is_api_key = getattr(context, "is_api_key", None)
                self._store_log(
                    method, None, duration, user_id, is_api_key, request, res, None, None, ip_address, user_agent
                )
                # the duration is passed to the histogram divided by 1000, i.e. in seconds
                observe_in_servicer_duration_histogram(method, user_id, "", "", duration / 1000)
            except Exception as e:
                finished = perf_counter_ns()
                duration = (finished - start) / 1e6  # ms
                code = getattr(context.code(), "name", None)
                traceback = "".join(format_exception(type(e), e, e.__traceback__))
                user_id = getattr(context, "user_id", None)
                is_api_key = getattr(context, "is_api_key", None)
                self._store_log(
                    method, code, duration, user_id, is_api_key, request, None, traceback, None, ip_address, user_agent
                )
                observe_in_servicer_duration_histogram(method, user_id, code or "", type(e).__name__, duration / 1000)

                # only failures without an RPC status code on the context are reported to Sentry
                if not code:
                    sentry_sdk.set_tag("context", "servicer")
                    sentry_sdk.set_tag("method", method)
                    sentry_sdk.capture_exception(e)

                # re-raise the active exception for the interceptors further out
                raise
            return res

        return grpc.unary_unary_rpc_method_handler(
            tracing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
class ErrorSanitizationInterceptor(grpc.ServerInterceptor):
    """
    If the call resulted in a non-gRPC error, this strips away the error details.

    It's important to put this first, so that it does not interfere with other interceptors.
    """

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            # nothing is registered for this method, so there's nothing to sanitize; pass that on instead of
            # failing on handler.unary_unary below
            return None

        prev_func = handler.unary_unary

        def sanitizing_function(req, context):
            try:
                res = prev_func(req, context)
            except Exception as e:
                code = context.code()
                # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None
                if not code:
                    logger.exception(e)
                    logger.info("Probably an unknown error! Sanitizing...")
                    # replace whatever went wrong with a generic internal error
                    context.abort(grpc.StatusCode.INTERNAL, errors.UNKNOWN_ERROR)
                else:
                    logger.warning(f"RPC error: {code} in method {handler_call_details.method}")
                    # already a proper RPC error, so let it through unchanged
                    raise
            return res

        return grpc.unary_unary_rpc_method_handler(
            sanitizing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )