Coverage for src/couchers/interceptors.py: 89%
227 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-11-04 02:51 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-11-04 02:51 +0000
1import logging
2from copy import deepcopy
3from datetime import timedelta
4from os import getpid
5from threading import get_ident
6from time import perf_counter_ns
7from traceback import format_exception
9import grpc
10import sentry_sdk
11from opentelemetry import trace
12from sqlalchemy.sql import and_, func
14from couchers import errors
15from couchers.db import session_scope
16from couchers.descriptor_pool import get_descriptor_pool
17from couchers.metrics import observe_in_servicer_duration_histogram
18from couchers.models import APICall, User, UserActivity, UserSession
19from couchers.profiler import CouchersProfiler
20from couchers.sql import couchers_select as select
21from couchers.utils import create_session_cookies, now, parse_api_key, parse_session_cookie, parse_user_id_cookie
22from proto import annotations_pb2
logger = logging.getLogger(__name__)


def _binned_now():
    """
    Build a SQL expression for the current database time, floored to its one-hour bucket.

    Uses SQL ``date_bin`` with buckets aligned to the origin ``2000-01-01``.
    """
    bucket_width = "1 hour"
    bucket_origin = "2000-01-01"
    return func.date_bin(bucket_width, func.now(), bucket_origin)
def _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent):
    """
    Look up the user and session owning ``token`` and record this API call against them.

    A match requires a visible user, a session with this token whose ``is_valid`` holds, and whose
    ``is_api_key`` equals the argument. On a match this also:
      * sets ``user.last_active`` to now if it is more than 5 minutes old,
      * refreshes ``user_session.last_seen`` and increments ``user_session.api_calls``,
      * increments (or creates with 1) the ``UserActivity`` row for this user, the current hour bucket,
        ``ip_address`` and ``user_agent``.

    Returns ``(user_id, is_jailed, is_superuser, session_expiry)``, or ``None`` when ``token`` is falsy
    or no matching session exists.
    """
    if not token:
        return None

    with session_scope() as session:
        # left-join the activity row for this user/hour/ip/user-agent so it can be updated in place
        activity_match = and_(
            UserActivity.user_id == User.id,
            UserActivity.period == _binned_now(),
            UserActivity.ip_address == ip_address,
            UserActivity.user_agent == user_agent,
        )
        query = (
            select(User, UserSession, UserActivity)
            .join(User, User.id == UserSession.user_id)
            .outerjoin(UserActivity, activity_match)
            .where(User.is_visible)
            .where(UserSession.token == token)
            .where(UserSession.is_valid)
            .where(UserSession.is_api_key == is_api_key)
        )
        row = session.execute(query).one_or_none()

        if not row:
            return None

        user, user_session, user_activity = row

        # only bump last_active once it's more than 5 minutes stale
        if now() - user.last_active > timedelta(minutes=5):
            user.last_active = func.now()

        user_session.last_seen = func.now()
        user_session.api_calls += 1

        if not user_activity:
            session.add(
                UserActivity(
                    user_id=user.id,
                    period=_binned_now(),
                    ip_address=ip_address,
                    user_agent=user_agent,
                    api_calls=1,
                )
            )
        else:
            user_activity.api_calls += 1

        session.commit()

        return user.id, user.is_jailed, user.is_superuser, user_session.expiry
def abort_handler(message, status_code):
    """
    Build a unary-unary RPC handler that terminates every call with ``status_code`` and ``message``.

    No request deserializer is registered, so the request is never decoded.
    """

    def _abort_call(request, context):
        context.abort(status_code, message)

    return grpc.unary_unary_rpc_method_handler(_abort_call)
def unauthenticated_handler(message="Unauthorized", status_code=grpc.StatusCode.UNAUTHENTICATED):
    """
    Shorthand for ``abort_handler`` that defaults to an UNAUTHENTICATED "Unauthorized" rejection.
    """
    return abort_handler(status_code=status_code, message=message)
class AuthValidatorInterceptor(grpc.ServerInterceptor):
    """
    Extracts a session token from a cookie (or an API key from the ``authorization`` header), and
    authenticates a user with that.

    Sets context.user_id, context.token and context.is_api_key if authenticated. Calls to services whose
    auth level is not OPEN are terminated with UNAUTHENTICATED when no valid session is found. ADMIN
    services additionally require a superuser (PERMISSION_DENIED otherwise). Jailed users may only reach
    OPEN or JAILED services.
    """

    _UNKNOWN_CALL_MESSAGE = "API call does not exist. Please refresh and try again."

    def __init__(self):
        self._pool = get_descriptor_pool()

    def intercept_service(self, continuation, handler_call_details):
        method = handler_call_details.method
        # method is of the form "/org.couchers.api.core.API/GetUser"
        _, service_name, _method_name = method.split("/")

        try:
            service_options = self._pool.FindServiceByName(service_name).GetOptions()
        except KeyError:
            return abort_handler(self._UNKNOWN_CALL_MESSAGE, grpc.StatusCode.UNIMPLEMENTED)

        auth_level = service_options.Extensions[annotations_pb2.auth_level]

        # AUTH_LEVEL_UNKNOWN means the annotation wasn't set; any other unexpected value is equally wrong.
        # An explicit check is used instead of `assert`, which is stripped under `python -O`.
        known_auth_levels = (
            annotations_pb2.AUTH_LEVEL_OPEN,
            annotations_pb2.AUTH_LEVEL_JAILED,
            annotations_pb2.AUTH_LEVEL_SECURE,
            annotations_pb2.AUTH_LEVEL_ADMIN,
        )
        if auth_level not in known_auth_levels:
            return abort_handler("Internal authentication error.", grpc.StatusCode.INTERNAL)

        headers = dict(handler_call_details.invocation_metadata)

        if "cookie" in headers and "authorization" in headers:
            # for security reasons, only one of "cookie" or "authorization" can be present
            return unauthenticated_handler('Both "cookie" and "authorization" in request')
        elif "cookie" in headers:
            # the session token is passed in cookies, i.e. in the `cookie` header
            token, is_api_key = parse_session_cookie(headers), False
        elif "authorization" in headers:
            # the session token is passed in the `authorization` header
            token, is_api_key = parse_api_key(headers), True
        else:
            # no session found
            token, is_api_key = None, False

        ip_address = headers.get("x-couchers-real-ip")
        user_agent = headers.get("user-agent")

        auth_info = _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent)
        # auth_info is filled if and only if this is a valid session
        if not auth_info:
            # if no session was found and this isn't an open service, fail
            if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
                return unauthenticated_handler()
            token, is_api_key, token_expiry, user_id = None, False, None, None
        else:
            # a valid user session was found
            user_id, is_jailed, is_superuser, token_expiry = auth_info

            if auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not is_superuser:
                return unauthenticated_handler("Permission denied", grpc.StatusCode.PERMISSION_DENIED)

            # if the user is jailed and this isn't an open or jailed service, fail
            if is_jailed and auth_level not in (annotations_pb2.AUTH_LEVEL_OPEN, annotations_pb2.AUTH_LEVEL_JAILED):
                return unauthenticated_handler("Permission denied")

        handler = continuation(handler_call_details)
        if handler is None:
            # the service exists but has no such method: answer like an unknown service instead of
            # crashing on `handler.unary_unary` (which gRPC would surface as UNKNOWN)
            return abort_handler(self._UNKNOWN_CALL_MESSAGE, grpc.StatusCode.UNIMPLEMENTED)
        user_aware_function = handler.unary_unary

        def user_unaware_function(req, context):
            context.user_id = user_id
            context.token = (token, token_expiry)
            context.is_api_key = is_api_key
            return user_aware_function(req, context)

        return grpc.unary_unary_rpc_method_handler(
            user_unaware_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
class CookieInterceptor(grpc.ServerInterceptor):
    """
    Syncs up the couchers-sesh and couchers-user-id cookies.

    The wrapped handler runs first. Afterwards, if the call has a truthy context.user_id, was not made with
    an API key, and the user-id cookie sent by the client differs from context.user_id, fresh session
    cookies are sent back as ``set-cookie`` initial metadata.
    """

    def intercept_service(self, continuation, handler_call_details):
        headers = dict(handler_call_details.invocation_metadata)
        cookie_user_id = parse_user_id_cookie(headers)

        handler = continuation(handler_call_details)
        if handler is None:
            # no handler for this method: pass that on (gRPC reports UNIMPLEMENTED) instead of crashing below
            return None
        user_aware_function = handler.unary_unary

        def user_unaware_function(req, context):
            res = user_aware_function(req, context)

            # check the two cookies are in sync
            if context.user_id and not context.is_api_key and cookie_user_id != str(context.user_id):
                try:
                    token, expiry = context.token
                    context.send_initial_metadata(
                        [("set-cookie", cookie) for cookie in create_session_cookies(token, context.user_id, expiry)]
                    )
                except ValueError:
                    # presumably raised when initial metadata has already been sent for this call
                    logger.info("Tried to send initial metadata but wasn't allowed to")

            return res

        return grpc.unary_unary_rpc_method_handler(
            user_unaware_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
class ManualAuthValidatorInterceptor(grpc.ServerInterceptor):
    """
    Guards a service with a caller-supplied authorization check.

    The API key is extracted from the call metadata with ``parse_api_key`` (an "Authorization: Bearer <hex>"
    header) and handed to the ``is_authorized`` callable given at construction. Calls with no key, or whose
    key is rejected, are terminated with UNAUTHENTICATED. All other calls continue down the chain unchanged.
    """

    def __init__(self, is_authorized):
        self._is_authorized = is_authorized

    def intercept_service(self, continuation, handler_call_details):
        api_key = parse_api_key(dict(handler_call_details.invocation_metadata))

        allowed = bool(api_key) and self._is_authorized(api_key)
        if allowed:
            return continuation(handler_call_details)
        return unauthenticated_handler()
class OTelInterceptor(grpc.ServerInterceptor):
    """
    OpenTelemetry tracing.

    Wraps each unary-unary call in a "handler" span carrying the following attributes: the RPC method,
    service and method name, the serving thread id and process id, and the caller's user agent and IP
    address taken from the request headers.
    """

    def __init__(self):
        self.tracer = trace.get_tracer(__name__)

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            # no handler for this method: pass that on (gRPC reports UNIMPLEMENTED) instead of crashing below
            return None
        prev_func = handler.unary_unary
        method = handler_call_details.method

        # method is of the form "/org.couchers.api.core.API/GetUser"
        _, service_name, method_name = method.split("/")

        headers = dict(handler_call_details.invocation_metadata)
        user_agent = headers.get("user-agent") or ""
        ip_address = headers.get("x-couchers-real-ip") or ""

        def tracing_function(request, context):
            with self.tracer.start_as_current_span("handler") as rollspan:
                rollspan.set_attribute("rpc.method_full", method)
                rollspan.set_attribute("rpc.service", service_name)
                rollspan.set_attribute("rpc.method", method_name)

                rollspan.set_attribute("rpc.thread", get_ident())
                rollspan.set_attribute("rpc.pid", getpid())

                # set before calling the handler so that spans of failed calls carry them as well
                rollspan.set_attribute("web.user_agent", user_agent)
                rollspan.set_attribute("web.ip_address", ip_address)

                return prev_func(request, context)

        return grpc.unary_unary_rpc_method_handler(
            tracing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
class SessionInterceptor(grpc.ServerInterceptor):
    """
    Adds a session from session_scope() as the last argument. This needs to be the last interceptor since it
    changes the function signature by adding another argument.

    The session is held open for the duration of the handler call.
    """

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            # no handler for this method: pass that on (gRPC reports UNIMPLEMENTED) instead of crashing below
            return None
        prev_func = handler.unary_unary

        def function_without_session(request, context):
            with session_scope() as session:
                return prev_func(request, context, session)

        return grpc.unary_unary_rpc_method_handler(
            function_without_session,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
class TracingInterceptor(grpc.ServerInterceptor):
    """
    Measures and logs the time it takes to service each incoming call.

    For every call:
      * an APICall row is stored, with sensitive fields stripped from the request/response and the
        response truncated to 16 kB;
      * the duration is observed in the servicer duration histogram.

    Exceptions not raised through context.abort() (i.e. with no status code set) are also reported to Sentry.
    """

    def _sanitized_bytes(self, proto):
        """
        Remove fields marked sensitive and return serialized bytes.

        Returns None if ``proto`` is falsy. Works on a deep copy, so the caller's message is not modified.
        Nested messages are sanitized recursively, including repeated message fields and message-valued map
        fields.
        """
        if not proto:
            return None

        new_proto = deepcopy(proto)

        def _sanitize_message(message):
            for name, descriptor in message.DESCRIPTOR.fields_by_name.items():
                if descriptor.GetOptions().Extensions[annotations_pb2.sensitive]:
                    message.ClearField(name)
                    # the field is now empty; nothing left to recurse into
                    continue
                if not descriptor.message_type:
                    continue
                if descriptor.label == descriptor.LABEL_REPEATED:
                    container = getattr(message, name)
                    if descriptor.message_type.GetOptions().map_entry:
                        # map fields are modelled as repeated MapEntry messages, but iterating the container
                        # yields keys; only maps with message values can contain sensitive fields
                        if descriptor.message_type.fields_by_name["value"].message_type:
                            for value in container.values():
                                _sanitize_message(value)
                    else:
                        for submessage in container:
                            _sanitize_message(submessage)
                elif message.HasField(name):
                    # unset submessages hold nothing to sanitize, so don't touch them
                    _sanitize_message(getattr(message, name))

        _sanitize_message(new_proto)

        return new_proto.SerializeToString()

    def _store_log(
        self,
        method,
        status_code,
        duration,
        user_id,
        is_api_key,
        request,
        response,
        traceback,
        perf_report,
        ip_address,
        user_agent,
    ):
        """
        Persist one APICall row for a finished call.

        Request and response are sanitized with ``_sanitized_bytes``. The response bytes are cut to 16 kB, and
        ``response_truncated`` records whether that happened. ``duration`` is in milliseconds.
        """
        req_bytes = self._sanitized_bytes(request)
        res_bytes = self._sanitized_bytes(response)
        with session_scope() as session:
            response_truncated = False
            truncate_res_bytes_length = 16 * 1024  # 16 kB
            if res_bytes and len(res_bytes) > truncate_res_bytes_length:
                res_bytes = res_bytes[:truncate_res_bytes_length]
                response_truncated = True
            session.add(
                APICall(
                    is_api_key=is_api_key,
                    method=method,
                    status_code=status_code,
                    duration=duration,
                    user_id=user_id,
                    request=req_bytes,
                    response=res_bytes,
                    response_truncated=response_truncated,
                    traceback=traceback,
                    perf_report=perf_report,
                    ip_address=ip_address,
                    user_agent=user_agent,
                )
            )
        logger.debug(f"{user_id=}, {method=}, {duration=} ms")

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            # no handler for this method: pass that on (gRPC reports UNIMPLEMENTED) instead of crashing below
            return None
        prev_func = handler.unary_unary
        method = handler_call_details.method

        headers = dict(handler_call_details.invocation_metadata)
        ip_address = headers.get("x-couchers-real-ip")
        user_agent = headers.get("user-agent")

        def tracing_function(request, context):
            # initialised before entering the profiler so the error path never sees an unbound `start`
            start = perf_counter_ns()
            try:
                with CouchersProfiler(do_profile=False) as prof:
                    start = perf_counter_ns()
                    res = prev_func(request, context)
                    finished = perf_counter_ns()
            except Exception as e:
                finished = perf_counter_ns()
                duration = (finished - start) / 1e6  # ms
                # None unless the handler failed via context.abort()/set_code()
                code = getattr(context.code(), "name", None)
                traceback = "".join(format_exception(type(e), e, e.__traceback__))
                user_id = getattr(context, "user_id", None)
                is_api_key = getattr(context, "is_api_key", None)
                self._store_log(
                    method, code, duration, user_id, is_api_key, request, None, traceback, None, ip_address, user_agent
                )
                observe_in_servicer_duration_histogram(method, user_id, code or "", type(e).__name__, duration / 1000)

                if not code:
                    sentry_sdk.set_tag("context", "servicer")
                    sentry_sdk.set_tag("method", method)
                    sentry_sdk.capture_exception(e)

                raise

            # success-path bookkeeping lives outside the `try` so that a failure while logging is not
            # misreported (and logged a second time) as a failure of the handler itself
            duration = (finished - start) / 1e6  # ms
            user_id = getattr(context, "user_id", None)
            is_api_key = getattr(context, "is_api_key", None)
            self._store_log(
                method, None, duration, user_id, is_api_key, request, res, None, prof.report, ip_address, user_agent
            )
            observe_in_servicer_duration_histogram(method, user_id, "", "", duration / 1000)
            return res

        return grpc.unary_unary_rpc_method_handler(
            tracing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
class ErrorSanitizationInterceptor(grpc.ServerInterceptor):
    """
    If the call resulted in a non-gRPC error, this strips away the error details.

    An exception raised while no status code is set on the context is logged and replaced with an INTERNAL
    abort carrying ``errors.UNKNOWN_ERROR``. Exceptions from context.abort() are logged as warnings and
    re-raised unchanged.

    It's important to put this first, so that it does not interfere with other interceptors.
    """

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            # no handler for this method: pass that on (gRPC reports UNIMPLEMENTED) instead of crashing below
            return None
        prev_func = handler.unary_unary

        def sanitizing_function(req, context):
            try:
                return prev_func(req, context)
            except Exception as e:
                code = context.code()
                # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None
                if not code:
                    logger.exception(e)
                    logger.info("Probably an unknown error! Sanitizing...")
                    # abort() raises, so the original exception details never reach the client
                    context.abort(grpc.StatusCode.INTERNAL, errors.UNKNOWN_ERROR)
                else:
                    logger.warning(f"RPC error: {code} in method {handler_call_details.method}")
                    raise

        return grpc.unary_unary_rpc_method_handler(
            sanitizing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )