Coverage for src/couchers/interceptors.py: 90%

199 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-08-28 14:55 +0000

1import logging 

2from copy import deepcopy 

3from datetime import timedelta 

4from os import getpid 

5from threading import get_ident 

6from time import perf_counter_ns 

7from traceback import format_exception 

8 

9import grpc 

10import sentry_sdk 

11from opentelemetry import trace 

12from sqlalchemy.sql import and_, func 

13 

14from couchers import errors 

15from couchers.context import CouchersContext, make_interactive_user_context, make_media_context 

16from couchers.db import session_scope 

17from couchers.descriptor_pool import get_descriptor_pool 

18from couchers.metrics import observe_in_servicer_duration_histogram 

19from couchers.models import APICall, User, UserActivity, UserSession 

20from couchers.sql import couchers_select as select 

21from couchers.utils import ( 

22 create_lang_cookie, 

23 create_session_cookies, 

24 now, 

25 parse_api_key, 

26 parse_session_cookie, 

27 parse_ui_lang_cookie, 

28 parse_user_id_cookie, 

29) 

30from proto import annotations_pb2 

31 

logger: logging.Logger = logging.getLogger(__name__)


def _binned_now():
    """
    SQL expression bucketing the database's current time with Postgres ``date_bin``.

    Buckets are one hour wide and aligned to the origin 2000-01-01, so equal values
    identify the same hourly activity period.
    """
    bin_origin = "2000-01-01"
    return func.date_bin("1 hour", func.now(), bin_origin)

37 

38 

def _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent):
    """
    Look up the user and session for ``token`` and record this API call against them.

    Returns None if ``token`` is empty, or if it does not match a valid session of the
    requested kind (``is_api_key``) that belongs to a visible user.

    Otherwise, in one transaction it:
      * sets the user's last-active time to now if it is more than 5 minutes old,
      * sets the session's last-seen time and increments its API call counter,
      * increments (or creates) the UserActivity row for the current hourly bin,
        this IP address and this user agent.

    Returns a tuple
    ``(user_id, is_jailed, is_superuser, session_expiry, ui_language_preference)``.
    """
    if not token:
        return None

    with session_scope() as session:
        # the activity row for this user in the current hourly bin from this IP / user agent
        same_activity_bucket = and_(
            UserActivity.user_id == User.id,
            UserActivity.period == _binned_now(),
            UserActivity.ip_address == ip_address,
            UserActivity.user_agent == user_agent,
        )
        lookup = (
            select(User, UserSession, UserActivity)
            .join(User, User.id == UserSession.user_id)
            .outerjoin(UserActivity, same_activity_bucket)
            .where(User.is_visible)
            .where(UserSession.token == token)
            .where(UserSession.is_valid)
            .where(UserSession.is_api_key == is_api_key)
        )
        row = session.execute(lookup).one_or_none()

        if not row:
            return None

        user, user_session, user_activity = row

        # update user last active time if it's been a while
        if now() - user.last_active > timedelta(minutes=5):
            user.last_active = func.now()

        # record use of this session/token
        user_session.last_seen = func.now()
        user_session.api_calls += 1

        if not user_activity:
            # no activity row yet for this bin + IP + user agent: start one
            session.add(
                UserActivity(
                    user_id=user.id,
                    period=_binned_now(),
                    ip_address=ip_address,
                    user_agent=user_agent,
                    api_calls=1,
                )
            )
        else:
            user_activity.api_calls += 1

        session.commit()

        return user.id, user.is_jailed, user.is_superuser, user_session.expiry, user.ui_language_preference

96 

97 

def abort_handler(message, status_code):
    """
    Build a unary-unary gRPC method handler that rejects every call.

    The handler ignores the request and calls ``context.abort(status_code, message)``.
    """

    def _reject(request, context):
        context.abort(status_code, message)

    return grpc.unary_unary_rpc_method_handler(_reject)

103 

104 

def unauthenticated_handler(message="Unauthorized", status_code=grpc.StatusCode.UNAUTHENTICATED):
    """
    Handler that aborts the call, by default with UNAUTHENTICATED and "Unauthorized".

    A thin wrapper over ``abort_handler`` with auth-oriented defaults; callers may
    override the message and/or status code (e.g. PERMISSION_DENIED).
    """
    return abort_handler(message=message, status_code=status_code)

107 

108 

def _sanitized_bytes(proto):
    """
    Serialize ``proto`` after removing every field annotated as sensitive.

    The input message is left untouched: a deep copy is sanitized instead. A field
    whose options carry the ``annotations_pb2.sensitive`` extension is cleared. The
    walk covers every nesting level: singular submessages, repeated submessages, and
    the values of message-valued map fields.

    Returns the serialized bytes of the sanitized copy, or None if ``proto`` is falsy
    (e.g. None).
    """
    if not proto:
        return None

    new_proto = deepcopy(proto)

    def _sanitize_message(message):
        for name, descriptor in message.DESCRIPTOR.fields_by_name.items():
            if descriptor.GetOptions().Extensions[annotations_pb2.sensitive]:
                message.ClearField(name)
                # the field is now empty, so there is nothing nested left to sanitize
                continue
            if not descriptor.message_type:
                # scalar/enum/string/bytes fields have no nested fields
                continue
            submessage = getattr(message, name)
            if not submessage:
                continue
            if descriptor.message_type.GetOptions().map_entry:
                # Map fields iterate over their keys, not their entries, so walk the
                # values instead, and only when the values are themselves messages.
                value_descriptor = descriptor.message_type.fields_by_name["value"]
                if value_descriptor.message_type:
                    for msg in submessage.values():
                        _sanitize_message(msg)
            elif descriptor.label == descriptor.LABEL_REPEATED:
                for msg in submessage:
                    _sanitize_message(msg)
            else:
                _sanitize_message(submessage)

    _sanitize_message(new_proto)

    return new_proto.SerializeToString()

135 

136 

def _store_log(
    *,
    method,
    status_code,
    duration,
    user_id,
    is_api_key,
    request,
    response,
    traceback,
    perf_report,
    ip_address,
    user_agent,
):
    """
    Persist one APICall row describing a serviced gRPC call.

    Request and response are stored as serialized bytes with sensitive fields removed
    (via ``_sanitized_bytes``). The stored response is capped at 16 kB, and
    ``response_truncated`` records whether it was cut. A debug line with the user,
    method and duration is also logged.
    """
    req_bytes = _sanitized_bytes(request)
    res_bytes = _sanitized_bytes(response)
    with session_scope() as session:
        max_response_len = 16 * 1024  # 16 kB
        response_truncated = bool(res_bytes) and len(res_bytes) > max_response_len
        if response_truncated:
            res_bytes = res_bytes[:max_response_len]

        api_call = APICall(
            is_api_key=is_api_key,
            method=method,
            status_code=status_code,
            duration=duration,
            user_id=user_id,
            request=req_bytes,
            response=res_bytes,
            response_truncated=response_truncated,
            traceback=traceback,
            perf_report=perf_report,
            ip_address=ip_address,
            user_agent=user_agent,
        )
        session.add(api_call)
        logger.debug(f"{user_id=}, {method=}, {duration=} ms")

176 

177 

class CouchersMiddlewareInterceptor(grpc.ServerInterceptor):
    """
    Main middleware for the Couchers API servicers.

    1. Does auth: extracts a session token from the ``cookie`` header (or an API key
       from the ``authorization`` header), and authenticates a user with that.

       The user id, token and UI language preference are passed into the
       CouchersContext handed to the servicer. Calls without a valid session are
       terminated with an UNAUTHENTICATED error code unless the service is
       AUTH_LEVEL_OPEN. Admin-only services and the jailed-user restriction are
       enforced here too.

    2. Makes sure cookies are in sync.

    3. Injects a session to get a database transaction.

    4. Measures and logs the time it takes to service each incoming call.
    """

    def __init__(self):
        self._pool = get_descriptor_pool()

    def intercept_service(self, continuation, handler_call_details):
        """
        Authenticate the call, then wrap its unary-unary handler with logging,
        metrics, a DB session and cookie syncing.

        Returns an aborting handler when the call is rejected: unknown service or
        method, misconfigured auth level, or missing/insufficient credentials.
        """
        start = perf_counter_ns()

        method = handler_call_details.method
        # method is of the form "/org.couchers.api.core.API/GetUser"
        _, service_name, method_name = method.split("/")

        try:
            service_options = self._pool.FindServiceByName(service_name).GetOptions()
        except KeyError:
            return abort_handler(
                "API call does not exist. Please refresh and try again.", grpc.StatusCode.UNIMPLEMENTED
            )

        auth_level = service_options.Extensions[annotations_pb2.auth_level]

        # if unknown auth level, then it wasn't set and something's wrong
        if auth_level == annotations_pb2.AUTH_LEVEL_UNKNOWN:
            return abort_handler("Internal authentication error.", grpc.StatusCode.INTERNAL)

        assert auth_level in [
            annotations_pb2.AUTH_LEVEL_OPEN,
            annotations_pb2.AUTH_LEVEL_JAILED,
            annotations_pb2.AUTH_LEVEL_SECURE,
            annotations_pb2.AUTH_LEVEL_ADMIN,
        ]

        headers = dict(handler_call_details.invocation_metadata)

        if "cookie" in headers and "authorization" in headers:
            # for security reasons, only one of "cookie" or "authorization" can be present
            return unauthenticated_handler('Both "cookie" and "authorization" in request')
        elif "cookie" in headers:
            # the session token is passed in cookies, i.e. in the `cookie` header
            token, is_api_key = parse_session_cookie(headers), False
        elif "authorization" in headers:
            # the session token is passed in the `authorization` header
            token, is_api_key = parse_api_key(headers), True
        else:
            # no session found
            token, is_api_key = None, False

        ip_address = headers.get("x-couchers-real-ip")
        user_agent = headers.get("user-agent")

        auth_info = _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent)
        # auth_info is now filled if and only if this is a valid session
        if auth_info:
            user_id, is_jailed, is_superuser, token_expiry, ui_language_preference = auth_info

            if auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not is_superuser:
                return unauthenticated_handler("Permission denied", grpc.StatusCode.PERMISSION_DENIED)

            # if the user is jailed and this is isn't an open or jailed service, fail
            if is_jailed and auth_level not in [annotations_pb2.AUTH_LEVEL_OPEN, annotations_pb2.AUTH_LEVEL_JAILED]:
                return unauthenticated_handler("Permission denied")
        else:
            # no valid session: treat the call as anonymous
            token = None
            is_api_key = False
            token_expiry = None
            user_id = None
            ui_language_preference = None

            # anonymous calls are only allowed on open services
            if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
                return unauthenticated_handler()

        handler = continuation(handler_call_details)
        if handler is None:
            # the service exists but this method does not (e.g. an outdated client);
            # without this check `handler.unary_unary` would raise AttributeError
            return abort_handler(
                "API call does not exist. Please refresh and try again.", grpc.StatusCode.UNIMPLEMENTED
            )
        prev_function = handler.unary_unary

        def function_without_couchers_stuff(req, grpc_context):
            couchers_context: CouchersContext = make_interactive_user_context(
                grpc_context=grpc_context,
                user_id=user_id,
                is_api_key=is_api_key,
                token=token,
                ui_language_preference=ui_language_preference,
            )
            with session_scope() as session:
                try:
                    res = prev_function(req, couchers_context, session)
                    finished = perf_counter_ns()
                    duration = (finished - start) / 1e6  # ms
                    _store_log(
                        method=method,
                        status_code=None,
                        duration=duration,
                        user_id=couchers_context._user_id,
                        is_api_key=couchers_context._is_api_key,
                        request=req,
                        response=res,
                        traceback=None,
                        perf_report=None,
                        ip_address=ip_address,
                        user_agent=user_agent,
                    )
                    observe_in_servicer_duration_histogram(method, couchers_context._user_id, "", "", duration / 1000)
                except Exception as e:
                    finished = perf_counter_ns()
                    duration = (finished - start) / 1e6  # ms
                    # presumably set only when the servicer failed the call via context.abort()
                    code = getattr(couchers_context._grpc_context.code(), "name", None)
                    traceback = "".join(format_exception(type(e), e, e.__traceback__))
                    _store_log(
                        method=method,
                        status_code=code,
                        duration=duration,
                        user_id=couchers_context._user_id,
                        is_api_key=couchers_context._is_api_key,
                        request=req,
                        response=None,
                        traceback=traceback,
                        perf_report=None,
                        ip_address=ip_address,
                        user_agent=user_agent,
                    )
                    observe_in_servicer_duration_histogram(
                        method, couchers_context._user_id, code or "", type(e).__name__, duration / 1000
                    )

                    # only errors without a gRPC status code are reported to Sentry
                    if not code:
                        sentry_sdk.set_tag("context", "servicer")
                        sentry_sdk.set_tag("method", method)
                        sentry_sdk.capture_exception(e)

                    # bare raise re-raises with the original traceback
                    raise

            if user_id and not is_api_key:
                # check the two cookies are in sync & that language preference cookie is correct
                if parse_user_id_cookie(headers) != str(user_id):
                    couchers_context.set_cookies(create_session_cookies(token, user_id, token_expiry))
                if ui_language_preference and ui_language_preference != parse_ui_lang_cookie(headers):
                    couchers_context.set_cookies(create_lang_cookie(ui_language_preference))

            couchers_context._send_cookies()

            return res

        return grpc.unary_unary_rpc_method_handler(
            function_without_couchers_stuff,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

341 

342 

class MediaInterceptor(grpc.ServerInterceptor):
    """
    Extracts an "Authorization: Bearer <hex>" header and calls the
    is_authorized function. Terminates the call with a gRPC UNAUTHENTICATED
    status if not authorized.

    Also adds a session to called APIs.
    """

    def __init__(self, is_authorized):
        # callable(token) -> truthy if the token may use the media API
        self._is_authorized = is_authorized

    def intercept_service(self, continuation, handler_call_details):
        """
        Reject unauthorized calls; otherwise wrap the handler so it receives a media
        context and a DB session.

        Returns None for an unknown method, so the server treats it as not found
        instead of crashing on ``handler.unary_unary``.
        """
        metadata = dict(handler_call_details.invocation_metadata)

        token = parse_api_key(metadata)

        if not token or not self._is_authorized(token):
            return unauthenticated_handler()

        handler = continuation(handler_call_details)
        if handler is None:
            return None
        prev_func = handler.unary_unary

        def function_without_session(request, grpc_context):
            with session_scope() as session:
                return prev_func(request, make_media_context(grpc_context), session)

        return grpc.unary_unary_rpc_method_handler(
            function_without_session,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

374 

375 

class OTelInterceptor(grpc.ServerInterceptor):
    """
    OpenTelemetry tracing: wraps each unary-unary call in a "handler" span
    annotated with the RPC name, thread/process ids, user agent and client IP.
    """

    def __init__(self):
        self.tracer = trace.get_tracer(__name__)

    def intercept_service(self, continuation, handler_call_details):
        """
        Wrap the next handler in a tracing span.

        Returns None for an unknown method, so the server treats it as not found
        instead of crashing on ``handler.unary_unary``.
        """
        handler = continuation(handler_call_details)
        if handler is None:
            return None
        prev_func = handler.unary_unary
        method = handler_call_details.method

        # method is of the form "/org.couchers.api.core.API/GetUser"
        _, service_name, method_name = method.split("/")

        headers = dict(handler_call_details.invocation_metadata)
        user_agent = headers.get("user-agent") or ""
        ip_address = headers.get("x-couchers-real-ip") or ""

        def tracing_function(request, context):
            with self.tracer.start_as_current_span("handler") as rollspan:
                rollspan.set_attribute("rpc.method_full", method)
                rollspan.set_attribute("rpc.service", service_name)
                rollspan.set_attribute("rpc.method", method_name)

                rollspan.set_attribute("rpc.thread", get_ident())
                rollspan.set_attribute("rpc.pid", getpid())

                # set before the call so spans of failing calls carry them too
                rollspan.set_attribute("web.user_agent", user_agent)
                rollspan.set_attribute("web.ip_address", ip_address)

                return prev_func(request, context)

        return grpc.unary_unary_rpc_method_handler(
            tracing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

415 

416 

class ErrorSanitizationInterceptor(grpc.ServerInterceptor):
    """
    If the call resulted in a non-gRPC error, this strips away the error details.

    It's important to put this first, so that it does not interfere with other interceptors.
    """

    def intercept_service(self, continuation, handler_call_details):
        """
        Wrap the next handler so that unexpected exceptions become a generic INTERNAL error.

        Returns None for an unknown method, so the server treats it as not found
        instead of crashing on ``handler.unary_unary``.
        """
        handler = continuation(handler_call_details)
        if handler is None:
            return None
        prev_func = handler.unary_unary

        def sanitizing_function(req, context):
            try:
                return prev_func(req, context)
            except Exception as e:
                code = context.code()
                # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None
                if not code:
                    logger.exception(e)
                    logger.info("Probably an unknown error! Sanitizing...")
                    # abort() raises, replacing the original error with a generic one
                    context.abort(grpc.StatusCode.INTERNAL, errors.UNKNOWN_ERROR)
                else:
                    logger.warning(f"RPC error: {code} in method {handler_call_details.method}")
                    # bare raise re-raises with the original traceback
                    raise

        return grpc.unary_unary_rpc_method_handler(
            sanitizing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )