Coverage for src/couchers/interceptors.py: 89%

225 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-22 06:42 +0000

1import logging 

2from copy import deepcopy 

3from datetime import timedelta 

4from os import getpid 

5from threading import get_ident 

6from time import perf_counter_ns 

7from traceback import format_exception 

8 

9import grpc 

10import sentry_sdk 

11from opentelemetry import trace 

12from sqlalchemy.sql import and_, func 

13 

14from couchers import errors 

15from couchers.db import session_scope 

16from couchers.descriptor_pool import get_descriptor_pool 

17from couchers.metrics import observe_in_servicer_duration_histogram 

18from couchers.models import APICall, User, UserActivity, UserSession 

19from couchers.sql import couchers_select as select 

20from couchers.utils import create_session_cookies, now, parse_api_key, parse_session_cookie, parse_user_id_cookie 

21from proto import annotations_pb2 

22 

# Module-level logger, named after this module (``__name__``).
logger = logging.getLogger(__name__)

24 

25 

def _binned_now():
    """
    SQL expression for the current database time, truncated to the start of its one-hour bin.

    Bins are aligned to a fixed origin of 2000-01-01, so every call maps "now" onto the same hourly boundaries.
    """
    bin_width = "1 hour"
    bin_origin = "2000-01-01"
    return func.date_bin(bin_width, func.now(), bin_origin)

28 

29 

def _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent):
    """
    Look up the user and session owning ``token`` and record this API call against them.

    Only a session that is valid (``UserSession.is_valid``), belongs to a visible user, and whose ``is_api_key`` flag
    equals the argument is accepted. When one is found this also:
      * bumps ``User.last_active`` if it is more than five minutes old,
      * refreshes ``UserSession.last_seen`` and increments ``UserSession.api_calls``,
      * increments ``api_calls`` on the ``UserActivity`` row for the current hourly bin / IP address / user agent,
        creating that row (with ``api_calls=1``) if it does not exist yet.

    Returns ``(user_id, is_jailed, is_superuser, session_expiry)``, or ``None`` when ``token`` is empty or no
    matching session exists.
    """
    if not token:
        return None

    with session_scope() as session:
        # activity row for this user in the current hourly bin, from this IP address and user agent
        same_activity_bucket = and_(
            UserActivity.user_id == User.id,
            UserActivity.period == _binned_now(),
            UserActivity.ip_address == ip_address,
            UserActivity.user_agent == user_agent,
        )
        lookup = (
            select(User, UserSession, UserActivity)
            .join(User, User.id == UserSession.user_id)
            .outerjoin(UserActivity, same_activity_bucket)
            .where(User.is_visible)
            .where(UserSession.token == token)
            .where(UserSession.is_valid)
            .where(UserSession.is_api_key == is_api_key)
        )
        row = session.execute(lookup).one_or_none()

        if not row:
            return None

        user, user_session, activity = row

        # only bump last_active once the stored value is more than five minutes old
        last_active_granularity = timedelta(minutes=5)
        if now() - user.last_active > last_active_granularity:
            user.last_active = func.now()

        # record usage of this session/token
        user_session.last_seen = func.now()
        user_session.api_calls += 1

        if not activity:
            session.add(
                UserActivity(
                    user_id=user.id,
                    period=_binned_now(),
                    ip_address=ip_address,
                    user_agent=user_agent,
                    api_calls=1,
                )
            )
        else:
            activity.api_calls += 1

        session.commit()

        return user.id, user.is_jailed, user.is_superuser, user_session.expiry

87 

88 

def abort_handler(message, status_code):
    """
    Build a unary-unary gRPC method handler that rejects every call.

    The handler ignores the request and calls ``context.abort(status_code, message)``.
    """

    def _reject(request, context):
        context.abort(status_code, message)

    rejecting_handler = grpc.unary_unary_rpc_method_handler(_reject)
    return rejecting_handler

94 

95 

def unauthenticated_handler(message="Unauthorized", status_code=grpc.StatusCode.UNAUTHENTICATED):
    """
    Shorthand for ``abort_handler`` whose defaults reject the call as UNAUTHENTICATED with "Unauthorized".
    """
    return abort_handler(status_code=status_code, message=message)

98 

99 

class AuthValidatorInterceptor(grpc.ServerInterceptor):
    """
    Extracts a session token from a cookie (or an API key from the ``authorization`` header), and authenticates a user
    with that.

    The required auth level comes from the ``auth_level`` option on the called service's descriptor.

    Calls that do not satisfy it are terminated:
      * UNAUTHENTICATED when there is no valid session for a non-open service, or the user is jailed and the service
        is neither open nor jailed;
      * PERMISSION_DENIED for non-superusers calling an admin service.

    Otherwise the wrapped handler runs with these attributes set on the context:
      * ``context.user_id``;
      * ``context.token``, a ``(token, expiry)`` tuple;
      * ``context.is_api_key``.

    For unauthenticated calls to open services these are ``None``, ``(None, None)`` and ``False`` respectively.
    """

    def __init__(self):
        self._pool = get_descriptor_pool()

    def intercept_service(self, continuation, handler_call_details):
        method = handler_call_details.method
        # method is of the form "/org.couchers.api.core.API/GetUser"
        _, service_name, _ = method.split("/")

        try:
            service_options = self._pool.FindServiceByName(service_name).GetOptions()
        except KeyError:
            return abort_handler(
                "API call does not exist. Please refresh and try again.", grpc.StatusCode.UNIMPLEMENTED
            )

        auth_level = service_options.Extensions[annotations_pb2.auth_level]

        # if unknown auth level, then it wasn't set and something's wrong
        if auth_level == annotations_pb2.AUTH_LEVEL_UNKNOWN:
            return abort_handler("Internal authentication error.", grpc.StatusCode.INTERNAL)

        # Explicit check instead of `assert`: asserts are stripped under `python -O`, which would let an unrecognised
        # auth level slip through the permission checks below.
        if auth_level not in (
            annotations_pb2.AUTH_LEVEL_OPEN,
            annotations_pb2.AUTH_LEVEL_JAILED,
            annotations_pb2.AUTH_LEVEL_SECURE,
            annotations_pb2.AUTH_LEVEL_ADMIN,
        ):
            return abort_handler("Internal authentication error.", grpc.StatusCode.INTERNAL)

        headers = dict(handler_call_details.invocation_metadata)

        if "cookie" in headers and "authorization" in headers:
            # for security reasons, only one of "cookie" or "authorization" can be present
            return unauthenticated_handler('Both "cookie" and "authorization" in request')
        elif "cookie" in headers:
            # the session token is passed in cookies, i.e. in the `cookie` header
            token, is_api_key = parse_session_cookie(headers), False
        elif "authorization" in headers:
            # the session token is passed in the `authorization` header
            token, is_api_key = parse_api_key(headers), True
        else:
            # no session found
            token, is_api_key = None, False

        ip_address = headers.get("x-couchers-real-ip")
        user_agent = headers.get("user-agent")

        # auth_info is filled if and only if this is a valid session
        auth_info = _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent)

        if auth_info:
            # a valid user session was found
            user_id, is_jailed, is_superuser, token_expiry = auth_info

            if auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not is_superuser:
                return unauthenticated_handler("Permission denied", grpc.StatusCode.PERMISSION_DENIED)

            # if the user is jailed and this isn't an open or jailed service, fail
            if is_jailed and auth_level not in (annotations_pb2.AUTH_LEVEL_OPEN, annotations_pb2.AUTH_LEVEL_JAILED):
                return unauthenticated_handler("Permission denied")
        else:
            # no valid session: only open services may be called
            if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
                return unauthenticated_handler()
            token, is_api_key, token_expiry, user_id = None, False, None, None

        handler = continuation(handler_call_details)
        if handler is None:
            # the service exists but this method doesn't (nothing further down the chain serves it)
            return abort_handler(
                "API call does not exist. Please refresh and try again.", grpc.StatusCode.UNIMPLEMENTED
            )
        user_aware_function = handler.unary_unary

        def user_unaware_function(req, context):
            context.user_id = user_id
            context.token = (token, token_expiry)
            context.is_api_key = is_api_key
            return user_aware_function(req, context)

        return grpc.unary_unary_rpc_method_handler(
            user_unaware_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

191 

192 

class CookieInterceptor(grpc.ServerInterceptor):
    """
    Syncs up the couchers-sesh and couchers-user-id cookies.

    This runs after the wrapped handler returns. If all of the following hold, fresh session cookies are sent to the
    client as ``set-cookie`` initial metadata:
      * the call is authenticated (``context.user_id`` is set);
      * it was not made with an API key;
      * the user id parsed from the incoming cookies differs from ``context.user_id``.

    It reads ``context.user_id``, ``context.is_api_key`` and ``context.token``. These are expected to have been set by
    an earlier interceptor (presumably ``AuthValidatorInterceptor``).
    """

    def intercept_service(self, continuation, handler_call_details):
        headers = dict(handler_call_details.invocation_metadata)
        cookie_user_id = parse_user_id_cookie(headers)

        handler = continuation(handler_call_details)
        if handler is None:
            # nothing further down the chain serves this method; let gRPC report it as unimplemented
            return None
        user_aware_function = handler.unary_unary

        def user_unaware_function(req, context):
            res = user_aware_function(req, context)

            # check the two cookies are in sync
            if context.user_id and not context.is_api_key and cookie_user_id != str(context.user_id):
                token, expiry = context.token
                cookies = create_session_cookies(token, context.user_id, expiry)
                try:
                    context.send_initial_metadata([("set-cookie", cookie) for cookie in cookies])
                except ValueError:
                    # gRPC raises ValueError once initial metadata may no longer be sent for this call
                    logger.info("Tried to send initial metadata but wasn't allowed to")

            return res

        return grpc.unary_unary_rpc_method_handler(
            user_unaware_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

225 

226 

class ManualAuthValidatorInterceptor(grpc.ServerInterceptor):
    """
    Gates every call behind an API key checked by a caller-supplied predicate.

    The key is taken from the "Authorization: Bearer <hex>" header via ``parse_api_key``. Calls are rejected with the
    default UNAUTHENTICATED handler in either case:
      * the key is missing or empty;
      * ``is_authorized(key)`` returns a falsy value.

    Authorized calls are passed straight through to the rest of the chain.
    """

    def __init__(self, is_authorized):
        # callable taking the parsed API key; a truthy result lets the call through
        self._is_authorized = is_authorized

    def intercept_service(self, continuation, handler_call_details):
        api_key = parse_api_key(dict(handler_call_details.invocation_metadata))

        # the predicate is only consulted when a key was actually supplied
        if api_key and self._is_authorized(api_key):
            return continuation(handler_call_details)

        return unauthenticated_handler()

246 

247 

class OTelInterceptor(grpc.ServerInterceptor):
    """
    OpenTelemetry tracing.

    Wraps each unary-unary call in a span named "handler". The span carries these attributes:
      * the full method, service and method names;
      * the serving thread id and process id;
      * the client's user agent and IP address, taken from the ``user-agent`` and ``x-couchers-real-ip`` metadata
        (empty string when absent).
    """

    def __init__(self):
        self.tracer = trace.get_tracer(__name__)

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            # nothing further down the chain serves this method; let gRPC report it as unimplemented
            return None
        prev_func = handler.unary_unary
        method = handler_call_details.method

        # method is of the form "/org.couchers.api.core.API/GetUser"
        _, service_name, method_name = method.split("/")

        headers = dict(handler_call_details.invocation_metadata)
        user_agent = headers.get("user-agent") or ""
        ip_address = headers.get("x-couchers-real-ip") or ""

        def tracing_function(request, context):
            with self.tracer.start_as_current_span("handler") as rollspan:
                rollspan.set_attribute("rpc.method_full", method)
                rollspan.set_attribute("rpc.service", service_name)
                rollspan.set_attribute("rpc.method", method_name)

                rollspan.set_attribute("rpc.thread", get_ident())
                rollspan.set_attribute("rpc.pid", getpid())

                # set before invoking the handler so these are also recorded on spans of calls that raise
                rollspan.set_attribute("web.user_agent", user_agent)
                rollspan.set_attribute("web.ip_address", ip_address)

                return prev_func(request, context)

        return grpc.unary_unary_rpc_method_handler(
            tracing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

287 

288 

class SessionInterceptor(grpc.ServerInterceptor):
    """
    Adds a session from session_scope() as the last argument. This needs to be the last interceptor since it changes the
    function signature by adding another argument.

    Each call gets its own session. The session lives for the duration of the ``with session_scope()`` block around
    the handler call.
    """

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            # no servicer method matches this call; let gRPC report it as unimplemented
            return None
        prev_func = handler.unary_unary

        def function_without_session(request, context):
            with session_scope() as session:
                return prev_func(request, context, session)

        return grpc.unary_unary_rpc_method_handler(
            function_without_session,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

308 

309 

class TracingInterceptor(grpc.ServerInterceptor):
    """
    Measures and logs the time it takes to service each incoming call.

    For every unary-unary call it does the following:
      * stores an ``APICall`` row with the request and response (sensitive fields stripped, response truncated to
        16 kB), the duration, status code, traceback, user and client details;
      * records the duration in the servicer duration histogram.

    Exceptions from the servicer for which ``context.code()`` has no status (i.e. the call was not failed through
    ``context.abort()``) are additionally reported to Sentry before being re-raised.
    """

    def _sanitized_bytes(self, proto):
        """
        Remove fields marked sensitive and return serialized bytes.

        Works on a deep copy, so ``proto`` itself is left untouched. Returns ``None`` when ``proto`` is falsy (e.g. no
        response because the call raised).
        """
        if not proto:
            return None

        new_proto = deepcopy(proto)

        def _sanitize_message(message):
            for name, descriptor in message.DESCRIPTOR.fields_by_name.items():
                if descriptor.GetOptions().Extensions[annotations_pb2.sensitive]:
                    message.ClearField(name)
                    # the field is now empty; there is nothing left to sanitize inside it
                    continue
                if not descriptor.message_type:
                    continue
                if descriptor.label == descriptor.LABEL_REPEATED:
                    container = getattr(message, name)
                    if descriptor.message_type.GetOptions().map_entry:
                        # Iterating a map field yields its keys, not messages. Only maps with message values can
                        # contain sensitive fields.
                        if descriptor.message_type.fields_by_name["value"].message_type:
                            for submessage in container.values():
                                _sanitize_message(submessage)
                    else:
                        for submessage in container:
                            _sanitize_message(submessage)
                elif message.HasField(name):
                    # Only descend into submessages that are set: even a no-op ClearField on an unset submessage
                    # would mark it as present in its parent and change the serialized bytes.
                    _sanitize_message(getattr(message, name))

        _sanitize_message(new_proto)

        return new_proto.SerializeToString()

    def _store_log(
        self,
        method,
        status_code,
        duration,
        user_id,
        is_api_key,
        request,
        response,
        traceback,
        perf_report,
        ip_address,
        user_agent,
    ):
        """
        Persist one ``APICall`` row describing a serviced call.

        ``request`` and ``response`` are stored with sensitive fields removed. The serialized response is cut to
        16 kB, and ``response_truncated`` records whether that happened. ``duration`` is in milliseconds.
        """
        req_bytes = self._sanitized_bytes(request)
        res_bytes = self._sanitized_bytes(response)

        response_truncated = False
        truncate_res_bytes_length = 16 * 1024  # 16 kB
        if res_bytes and len(res_bytes) > truncate_res_bytes_length:
            res_bytes = res_bytes[:truncate_res_bytes_length]
            response_truncated = True

        with session_scope() as session:
            session.add(
                APICall(
                    is_api_key=is_api_key,
                    method=method,
                    status_code=status_code,
                    duration=duration,
                    user_id=user_id,
                    request=req_bytes,
                    response=res_bytes,
                    response_truncated=response_truncated,
                    traceback=traceback,
                    perf_report=perf_report,
                    ip_address=ip_address,
                    user_agent=user_agent,
                )
            )
        # lazy %-formatting (runs on every call); %r matches the repr output of the former f"{x=}" form
        logger.debug("user_id=%r, method=%r, duration=%r ms", user_id, method, duration)

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            # nothing further down the chain serves this method; let gRPC report it as unimplemented
            return None
        prev_func = handler.unary_unary
        method = handler_call_details.method

        headers = dict(handler_call_details.invocation_metadata)
        ip_address = headers.get("x-couchers-real-ip")
        user_agent = headers.get("user-agent")

        def tracing_function(request, context):
            start = perf_counter_ns()
            # Only the servicer call is inside the try. A failure while storing the log of a *successful* call
            # must not be re-recorded as a servicer error, double-counted in the histogram, or sent to Sentry.
            try:
                res = prev_func(request, context)
            except Exception as e:
                duration = (perf_counter_ns() - start) / 1e6  # ms
                # name of the status code if the call was failed through abort(), otherwise None
                code = getattr(context.code(), "name", None)
                traceback = "".join(format_exception(type(e), e, e.__traceback__))
                user_id = getattr(context, "user_id", None)
                is_api_key = getattr(context, "is_api_key", None)
                self._store_log(
                    method, code, duration, user_id, is_api_key, request, None, traceback, None, ip_address, user_agent
                )
                observe_in_servicer_duration_histogram(method, user_id, code or "", type(e).__name__, duration / 1000)

                if not code:
                    sentry_sdk.set_tag("context", "servicer")
                    sentry_sdk.set_tag("method", method)
                    sentry_sdk.capture_exception(e)

                raise

            duration = (perf_counter_ns() - start) / 1e6  # ms
            user_id = getattr(context, "user_id", None)
            is_api_key = getattr(context, "is_api_key", None)
            self._store_log(method, None, duration, user_id, is_api_key, request, res, None, None, ip_address, user_agent)
            observe_in_servicer_duration_histogram(method, user_id, "", "", duration / 1000)
            return res

        return grpc.unary_unary_rpc_method_handler(
            tracing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

428 

429 

class ErrorSanitizationInterceptor(grpc.ServerInterceptor):
    """
    If the call resulted in a non-gRPC error, this strips away the error details.

    It's important to put this first, so that it does not interfere with other interceptors.

    Two cases are distinguished:
      * An exception raised after the call was failed through ``context.abort()`` (so ``context.code()`` is set) is
        logged as a warning and re-raised unchanged.
      * Any other exception is logged and replaced by an INTERNAL error carrying ``errors.UNKNOWN_ERROR``.
    """

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            # nothing further down the chain serves this method; let gRPC report it as unimplemented
            return None
        prev_func = handler.unary_unary
        method = handler_call_details.method

        def sanitizing_function(req, context):
            try:
                return prev_func(req, context)
            except Exception as e:
                code = context.code()
                # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None
                if code:
                    logger.warning("RPC error: %s in method %s", code, method)
                    raise
                logger.exception(e)
                logger.info("Probably an unknown error! Sanitizing...")
                # context.abort() raises, terminating the call with only the generic error message
                context.abort(grpc.StatusCode.INTERNAL, errors.UNKNOWN_ERROR)

        return grpc.unary_unary_rpc_method_handler(
            sanitizing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )