Coverage for src/couchers/interceptors.py: 90%

234 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-04-25 03:06 +0000

1import logging 

2from copy import deepcopy 

3from datetime import timedelta 

4from os import getpid 

5from threading import get_ident 

6from time import perf_counter_ns 

7from traceback import format_exception 

8 

9import grpc 

10import sentry_sdk 

11from opentelemetry import trace 

12from sqlalchemy.sql import and_, func 

13 

14from couchers import errors 

15from couchers.db import session_scope 

16from couchers.descriptor_pool import get_descriptor_pool 

17from couchers.metrics import observe_in_servicer_duration_histogram 

18from couchers.models import APICall, User, UserActivity, UserSession 

19from couchers.sql import couchers_select as select 

20from couchers.utils import ( 

21 create_lang_cookie, 

22 create_session_cookies, 

23 now, 

24 parse_api_key, 

25 parse_session_cookie, 

26 parse_ui_lang_cookie, 

27 parse_user_id_cookie, 

28) 

29from proto import annotations_pb2 

30 

31logger = logging.getLogger(__name__) 

32 

33 

def _binned_now():
    """
    Builds the SQL expression for the database's current time, bucketed into one hour bins.

    The bins are anchored at 2000-01-01; the expression is evaluated by the database, not in Python.
    """
    bin_width = "1 hour"
    bin_origin = "2000-01-01"
    current_db_time = func.now()
    return func.date_bin(bin_width, current_db_time, bin_origin)

36 

37 

38def _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent): 

39 """ 

40 Tries to get session and user info corresponding to this token. 

41 

42 Also updates the user last active time, token last active time, and increments API call count. 

43 """ 

44 if not token: 

45 return None 

46 

47 with session_scope() as session: 

48 result = session.execute( 

49 select(User, UserSession, UserActivity) 

50 .join(User, User.id == UserSession.user_id) 

51 .outerjoin( 

52 UserActivity, 

53 and_( 

54 UserActivity.user_id == User.id, 

55 UserActivity.period == _binned_now(), 

56 UserActivity.ip_address == ip_address, 

57 UserActivity.user_agent == user_agent, 

58 ), 

59 ) 

60 .where(User.is_visible) 

61 .where(UserSession.token == token) 

62 .where(UserSession.is_valid) 

63 .where(UserSession.is_api_key == is_api_key) 

64 ).one_or_none() 

65 

66 if not result: 

67 return None 

68 else: 

69 user, user_session, user_activity = result 

70 

71 # update user last active time if it's been a while 

72 if now() - user.last_active > timedelta(minutes=5): 

73 user.last_active = func.now() 

74 

75 # let's update the token 

76 user_session.last_seen = func.now() 

77 user_session.api_calls += 1 

78 

79 if user_activity: 

80 user_activity.api_calls += 1 

81 else: 

82 session.add( 

83 UserActivity( 

84 user_id=user.id, 

85 period=_binned_now(), 

86 ip_address=ip_address, 

87 user_agent=user_agent, 

88 api_calls=1, 

89 ) 

90 ) 

91 

92 session.commit() 

93 

94 return user.id, user.is_jailed, user.is_superuser, user_session.expiry, user.ui_language_preference 

95 

96 

def abort_handler(message, status_code):
    """
    Builds a unary-unary RPC method handler that aborts every call with the given status code and message.
    """

    def _abort_call(request, context):
        context.abort(status_code, message)

    aborting_handler = grpc.unary_unary_rpc_method_handler(_abort_call)
    return aborting_handler

102 

103 

def unauthenticated_handler(message="Unauthorized", status_code=grpc.StatusCode.UNAUTHENTICATED):
    """
    Builds a handler that rejects the call, by default with the UNAUTHENTICATED status and the message "Unauthorized".
    """
    rejecting_handler = abort_handler(message, status_code)
    return rejecting_handler

106 

107 

class AuthValidatorInterceptor(grpc.ServerInterceptor):
    """
    Extracts a session token from the `cookie` header or the `authorization` header, and authenticates a user with
    that.

    Sets context.user_id, context.token, context.is_api_key, and context.ui_language_preference before the servicer
    runs. Without a valid session these are left empty for open services; for every other service the call is
    terminated with an UNAUTHENTICATED error code.
    """

    def __init__(self):
        self._pool = get_descriptor_pool()

    @staticmethod
    def _nonexistent_call_handler():
        """
        Handler returned whenever the requested service or method can't be resolved.
        """
        return abort_handler("API call does not exist. Please refresh and try again.", grpc.StatusCode.UNIMPLEMENTED)

    def intercept_service(self, continuation, handler_call_details):
        """
        Authenticates the caller and enforces the auth level of the called service.

        Returns an aborting handler if the call is rejected or doesn't exist, otherwise a handler that fills in the
        authentication attributes on the context and then delegates to the next handler in the chain.
        """
        method = handler_call_details.method
        # method is of the form "/org.couchers.api.core.API/GetUser"; the path is supplied by the client, so
        # anything that doesn't have exactly this shape is treated as a call that doesn't exist instead of crashing
        try:
            _, service_name, _ = method.split("/")
        except ValueError:
            return self._nonexistent_call_handler()

        try:
            service_options = self._pool.FindServiceByName(service_name).GetOptions()
        except KeyError:
            return self._nonexistent_call_handler()

        auth_level = service_options.Extensions[annotations_pb2.auth_level]

        # if the auth level is unknown (or not one we enforce below), then it wasn't set and something's wrong; this
        # is an explicit check rather than an assert so that it's still enforced when running with `python -O`
        if auth_level not in (
            annotations_pb2.AUTH_LEVEL_OPEN,
            annotations_pb2.AUTH_LEVEL_JAILED,
            annotations_pb2.AUTH_LEVEL_SECURE,
            annotations_pb2.AUTH_LEVEL_ADMIN,
        ):
            return abort_handler("Internal authentication error.", grpc.StatusCode.INTERNAL)

        headers = dict(handler_call_details.invocation_metadata)

        if "cookie" in headers and "authorization" in headers:
            # for security reasons, only one of "cookie" or "authorization" can be present
            return unauthenticated_handler('Both "cookie" and "authorization" in request')
        elif "cookie" in headers:
            # the session token is passed in cookies, i.e. in the `cookie` header
            token, is_api_key = parse_session_cookie(headers), False
        elif "authorization" in headers:
            # the session token is passed in the `authorization` header
            token, is_api_key = parse_api_key(headers), True
        else:
            # no session found
            token, is_api_key = None, False

        ip_address = headers.get("x-couchers-real-ip")
        user_agent = headers.get("user-agent")

        auth_info = _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent)
        # auth_info is now filled if and only if this is a valid session
        if auth_info:
            # a valid user session was found
            user_id, is_jailed, is_superuser, token_expiry, ui_language_preference = auth_info

            if auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not is_superuser:
                return unauthenticated_handler("Permission denied", grpc.StatusCode.PERMISSION_DENIED)

            # if the user is jailed and this is isn't an open or jailed service, fail
            if is_jailed and auth_level not in [annotations_pb2.AUTH_LEVEL_OPEN, annotations_pb2.AUTH_LEVEL_JAILED]:
                return unauthenticated_handler("Permission denied")
        else:
            # if no session was found and this isn't an open service, fail
            if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
                return unauthenticated_handler()

            # an invalid token must not leak through to the servicer
            token, is_api_key = None, False
            user_id = token_expiry = ui_language_preference = None

        handler = continuation(handler_call_details)
        if handler is None:
            # the service is known, but nothing further down the chain serves this method
            return self._nonexistent_call_handler()

        user_aware_function = handler.unary_unary

        def user_unaware_function(req, context):
            context.user_id = user_id
            context.token = (token, token_expiry)
            context.is_api_key = is_api_key
            context.ui_language_preference = ui_language_preference
            return user_aware_function(req, context)

        return grpc.unary_unary_rpc_method_handler(
            user_unaware_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

201 

202 

class CookieInterceptor(grpc.ServerInterceptor):
    """
    Syncs up the couchers-sesh and couchers-user-id cookies & sets lang cookie
    """

    def intercept_service(self, continuation, handler_call_details):
        """
        Wraps the next handler so that, after it has run, any out-of-date cookies are re-sent as `set-cookie`
        initial metadata. Only applies to authenticated calls that weren't made with an API key.
        """
        headers = dict(handler_call_details.invocation_metadata)
        cookie_user_id = parse_user_id_cookie(headers)
        cookie_ui_lang = parse_ui_lang_cookie(headers)

        handler = continuation(handler_call_details)
        if handler is None:
            # nothing serves this method, so there's nothing to wrap; pass the "not serviced" result along
            return None

        user_aware_function = handler.unary_unary

        def user_unaware_function(req, context):
            res = user_aware_function(req, context)

            if context.user_id and not context.is_api_key:
                cookies = []

                # check the two cookies are in sync & that language preference cookie is correct
                token, expiry = context.token
                if cookie_user_id != str(context.user_id):
                    cookies.extend(
                        ("set-cookie", cookie) for cookie in create_session_cookies(token, context.user_id, expiry)
                    )
                if context.ui_language_preference and context.ui_language_preference != cookie_ui_lang:
                    cookies.extend(
                        ("set-cookie", cookie) for cookie in create_lang_cookie(context.ui_language_preference)
                    )

                if cookies:
                    try:
                        context.send_initial_metadata(cookies)
                    except ValueError:
                        # best effort only: the call still succeeds without the refreshed cookies
                        logger.info("Tried to send initial metadata but wasn't allowed to")

            return res

        return grpc.unary_unary_rpc_method_handler(
            user_unaware_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

246 

247 

class ManualAuthValidatorInterceptor(grpc.ServerInterceptor):
    """
    Authorizes calls with a caller-supplied predicate instead of the session table.

    The API key is read from the request metadata with parse_api_key and handed to the is_authorized function. If
    there is no key, or the function doesn't accept it, the call is terminated with an UNAUTHENTICATED error code.
    """

    def __init__(self, is_authorized):
        self._is_authorized = is_authorized

    def intercept_service(self, continuation, handler_call_details):
        request_metadata = dict(handler_call_details.invocation_metadata)
        api_key = parse_api_key(request_metadata)

        # only a present *and* accepted key gets through to the rest of the chain
        if api_key and self._is_authorized(api_key):
            return continuation(handler_call_details)

        return unauthenticated_handler()

267 

268 

class OTelInterceptor(grpc.ServerInterceptor):
    """
    OpenTelemetry tracing
    """

    def __init__(self):
        self.tracer = trace.get_tracer(__name__)

    def intercept_service(self, continuation, handler_call_details):
        """
        Wraps the next handler so that it runs inside a "handler" span annotated with the method, the thread and
        process ids, and the caller's user agent and IP address.
        """
        handler = continuation(handler_call_details)
        if handler is None:
            # nothing serves this method, so there's nothing to trace; pass the "not serviced" result along
            return None

        prev_func = handler.unary_unary
        method = handler_call_details.method

        # method is of the form "/org.couchers.api.core.API/GetUser"; the path is supplied by the client, so fall
        # back to recording the raw path as the method name instead of crashing if it has any other shape
        segments = method.split("/")
        if len(segments) == 3:
            _, service_name, method_name = segments
        else:
            service_name, method_name = "", method

        headers = dict(handler_call_details.invocation_metadata)

        def tracing_function(request, context):
            with self.tracer.start_as_current_span("handler") as rollspan:
                rollspan.set_attribute("rpc.method_full", method)
                rollspan.set_attribute("rpc.service", service_name)
                rollspan.set_attribute("rpc.method", method_name)

                rollspan.set_attribute("rpc.thread", get_ident())
                rollspan.set_attribute("rpc.pid", getpid())

                # these only depend on the request, so set them before running the handler: that way they are
                # also present on the spans of calls that fail with an exception
                rollspan.set_attribute("web.user_agent", headers.get("user-agent") or "")
                rollspan.set_attribute("web.ip_address", headers.get("x-couchers-real-ip") or "")

                return prev_func(request, context)

        return grpc.unary_unary_rpc_method_handler(
            tracing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

308 

309 

class SessionInterceptor(grpc.ServerInterceptor):
    """
    Adds a session from session_scope() as the last argument. This needs to be the last interceptor since it changes the
    function signature by adding another argument.
    """

    def intercept_service(self, continuation, handler_call_details):
        """
        Wraps the next handler so that it's called as handler(request, context, session), with the session only
        open for the duration of that call.
        """
        handler = continuation(handler_call_details)
        if handler is None:
            # no servicer is registered for this method, so there's nothing to wrap; returning None lets gRPC
            # answer the call as not serviced instead of failing on `None.unary_unary`
            return None

        prev_func = handler.unary_unary

        def function_without_session(request, context):
            with session_scope() as session:
                return prev_func(request, context, session)

        return grpc.unary_unary_rpc_method_handler(
            function_without_session,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

329 

330 

class TracingInterceptor(grpc.ServerInterceptor):
    """
    Measures and logs the time it takes to service each incoming call.
    """

    def _sanitized_bytes(self, proto):
        """
        Remove fields marked sensitive and return serialized bytes

        Works on a deep copy, so the message that's passed in is never modified. Returns None if there's no message.
        """
        if not proto:
            return None

        new_proto = deepcopy(proto)

        def _sanitize_message(message):
            for name, descriptor in message.DESCRIPTOR.fields_by_name.items():
                if descriptor.GetOptions().Extensions[annotations_pb2.sensitive]:
                    message.ClearField(name)
                    # the whole field is gone, there's nothing left in it to descend into
                    continue

                if not descriptor.message_type:
                    continue

                if descriptor.label == descriptor.LABEL_REPEATED:
                    if descriptor.message_type.GetOptions().map_entry:
                        # iterating a map field yields its keys, which are never messages; only the values can be
                        if descriptor.message_type.fields_by_name["value"].message_type:
                            for value in getattr(message, name).values():
                                _sanitize_message(value)
                    else:
                        for msg in getattr(message, name):
                            _sanitize_message(msg)
                elif message.HasField(name):
                    # only descend into submessages that are actually set: reading an unset one just hands back an
                    # empty default message, which would recurse without end for self-referential message types
                    _sanitize_message(getattr(message, name))

        _sanitize_message(new_proto)

        return new_proto.SerializeToString()

    def _store_log(
        self,
        method,
        status_code,
        duration,
        user_id,
        is_api_key,
        request,
        response,
        traceback,
        perf_report,
        ip_address,
        user_agent,
    ):
        """
        Stores one APICall row describing the call.

        The request and response are stored as sanitized, serialized bytes; the response is cut off at 16 kB, with
        response_truncated recording whether that happened. `duration` is in milliseconds.
        """
        req_bytes = self._sanitized_bytes(request)
        res_bytes = self._sanitized_bytes(response)

        truncate_res_bytes_length = 16 * 1024  # 16 kB
        response_truncated = bool(res_bytes) and len(res_bytes) > truncate_res_bytes_length
        if response_truncated:
            res_bytes = res_bytes[:truncate_res_bytes_length]

        with session_scope() as session:
            session.add(
                APICall(
                    is_api_key=is_api_key,
                    method=method,
                    status_code=status_code,
                    duration=duration,
                    user_id=user_id,
                    request=req_bytes,
                    response=res_bytes,
                    response_truncated=response_truncated,
                    traceback=traceback,
                    perf_report=perf_report,
                    ip_address=ip_address,
                    user_agent=user_agent,
                )
            )
        logger.debug(f"{user_id=}, {method=}, {duration=} ms")

    def intercept_service(self, continuation, handler_call_details):
        """
        Wraps the next handler so that every call is timed, stored through _store_log, and observed in the servicer
        duration histogram. Exceptions are logged together with their traceback and then re-raised; those that
        didn't set a status code on the context are also reported to Sentry.
        """
        handler = continuation(handler_call_details)
        if handler is None:
            # nothing serves this method, so there's nothing to measure; pass the "not serviced" result along
            return None

        prev_func = handler.unary_unary
        method = handler_call_details.method

        headers = dict(handler_call_details.invocation_metadata)
        ip_address = headers.get("x-couchers-real-ip")
        user_agent = headers.get("user-agent")

        def tracing_function(request, context):
            # started outside the try block, so the except branch below can always rely on it
            start = perf_counter_ns()
            try:
                res = prev_func(request, context)
                finished = perf_counter_ns()
                duration = (finished - start) / 1e6  # ms
                user_id = getattr(context, "user_id", None)
                is_api_key = getattr(context, "is_api_key", None)
                self._store_log(
                    method, None, duration, user_id, is_api_key, request, res, None, None, ip_address, user_agent
                )
                observe_in_servicer_duration_histogram(method, user_id, "", "", duration / 1000)
            except Exception as e:
                finished = perf_counter_ns()
                duration = (finished - start) / 1e6  # ms
                code = getattr(context.code(), "name", None)
                traceback = "".join(format_exception(type(e), e, e.__traceback__))
                user_id = getattr(context, "user_id", None)
                is_api_key = getattr(context, "is_api_key", None)
                self._store_log(
                    method, code, duration, user_id, is_api_key, request, None, traceback, None, ip_address, user_agent
                )
                observe_in_servicer_duration_histogram(method, user_id, code or "", type(e).__name__, duration / 1000)

                # no status code on the context means this wasn't a deliberate abort
                if not code:
                    sentry_sdk.set_tag("context", "servicer")
                    sentry_sdk.set_tag("method", method)
                    sentry_sdk.capture_exception(e)

                # bare raise: re-raises the same exception without adding this frame to its traceback again
                raise
            return res

        return grpc.unary_unary_rpc_method_handler(
            tracing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

449 

450 

class ErrorSanitizationInterceptor(grpc.ServerInterceptor):
    """
    If the call resulted in a non-gRPC error, this strips away the error details.

    It's important to put this first, so that it does not interfere with other interceptors.
    """

    def intercept_service(self, continuation, handler_call_details):
        """
        Wraps the next handler so that exceptions which didn't set a status code on the context are logged and
        replaced by an INTERNAL error with a generic message. Errors that did set a status code are re-raised as is.
        """
        handler = continuation(handler_call_details)
        if handler is None:
            # nothing serves this method, so there's nothing to sanitize; pass the "not serviced" result along
            return None

        prev_func = handler.unary_unary

        def sanitizing_function(req, context):
            try:
                res = prev_func(req, context)
            except Exception as e:
                code = context.code()
                # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None
                if not code:
                    logger.exception(e)
                    logger.info("Probably an unknown error! Sanitizing...")
                    context.abort(grpc.StatusCode.INTERNAL, errors.UNKNOWN_ERROR)
                else:
                    logger.warning(f"RPC error: {code} in method {handler_call_details.method}")
                    # bare raise: re-raises the same exception without adding this frame to its traceback again
                    raise
            return res

        return grpc.unary_unary_rpc_method_handler(
            sanitizing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )