Coverage for src/couchers/interceptors.py: 89%

227 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-11-04 02:51 +0000

1import logging 

2from copy import deepcopy 

3from datetime import timedelta 

4from os import getpid 

5from threading import get_ident 

6from time import perf_counter_ns 

7from traceback import format_exception 

8 

9import grpc 

10import sentry_sdk 

11from opentelemetry import trace 

12from sqlalchemy.sql import and_, func 

13 

14from couchers import errors 

15from couchers.db import session_scope 

16from couchers.descriptor_pool import get_descriptor_pool 

17from couchers.metrics import observe_in_servicer_duration_histogram 

18from couchers.models import APICall, User, UserActivity, UserSession 

19from couchers.profiler import CouchersProfiler 

20from couchers.sql import couchers_select as select 

21from couchers.utils import create_session_cookies, now, parse_api_key, parse_session_cookie, parse_user_id_cookie 

22from proto import annotations_pb2 

23 

# module-level logger used by the interceptors in this file
logger = logging.getLogger(__name__)

25 

26 

def _binned_now():
    """
    Return a SQL expression for the current time truncated to its 1-hour bucket.

    Uses ``date_bin`` with a fixed origin of 2000-01-01 and ``func.now()``, so the value is computed by the
    database when the statement runs, not in Python.
    """
    bucket_width = "1 hour"
    bucket_origin = "2000-01-01"
    return func.date_bin(bucket_width, func.now(), bucket_origin)

29 

30 

def _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent):
    """
    Resolve ``token`` to a user and record this API call against it.

    Returns ``None`` when ``token`` is falsy or when no row matches. A match requires a visible user, a valid
    session with this token, and a session ``is_api_key`` flag equal to ``is_api_key``. On a match, returns
    ``(user_id, is_jailed, is_superuser, session_expiry)``.

    On a match, these updates are committed before returning:
      - ``User.last_active`` is set to now if it is more than 5 minutes old;
      - ``UserSession.last_seen`` is set to now and ``UserSession.api_calls`` is incremented;
      - the ``UserActivity`` row for (user, current hourly bin, ip_address, user_agent) gets ``api_calls``
        incremented, or is created with ``api_calls=1`` if it does not exist yet.
    """
    if not token:
        return None

    with session_scope() as session:
        # outer join: the user/session row is still returned when there is no activity row for this bin yet
        activity_condition = and_(
            UserActivity.user_id == User.id,
            UserActivity.period == _binned_now(),
            UserActivity.ip_address == ip_address,
            UserActivity.user_agent == user_agent,
        )
        query = (
            select(User, UserSession, UserActivity)
            .join(User, User.id == UserSession.user_id)
            .outerjoin(UserActivity, activity_condition)
            .where(User.is_visible)
            .where(UserSession.token == token)
            .where(UserSession.is_valid)
            .where(UserSession.is_api_key == is_api_key)
        )
        row = session.execute(query).one_or_none()
        if not row:
            return None

        user, user_session, user_activity = row

        # last_active is only refreshed once it is more than 5 minutes stale
        if now() - user.last_active > timedelta(minutes=5):
            user.last_active = func.now()

        # touch the session/token
        user_session.last_seen = func.now()
        user_session.api_calls += 1

        if not user_activity:
            session.add(
                UserActivity(
                    user_id=user.id,
                    period=_binned_now(),
                    ip_address=ip_address,
                    user_agent=user_agent,
                    api_calls=1,
                )
            )
        else:
            user_activity.api_calls += 1

        session.commit()

        return user.id, user.is_jailed, user.is_superuser, user_session.expiry

88 

89 

def abort_handler(message, status_code):
    """
    Build a unary-unary RPC method handler that rejects every call.

    The returned handler ignores the request and calls ``context.abort(status_code, message)``.
    """

    def _reject(request, context):
        context.abort(status_code, message)

    return grpc.unary_unary_rpc_method_handler(_reject)

95 

96 

def unauthenticated_handler(message="Unauthorized", status_code=grpc.StatusCode.UNAUTHENTICATED):
    """
    Shorthand for ``abort_handler`` that defaults to an UNAUTHENTICATED "Unauthorized" rejection.

    ``message`` and ``status_code`` may be overridden, e.g. to reject with PERMISSION_DENIED instead.
    """
    return abort_handler(message=message, status_code=status_code)

99 

100 

class AuthValidatorInterceptor(grpc.ServerInterceptor):
    """
    Extracts a session token from a cookie (or an API key from the "authorization" header), and authenticates a
    user with that.

    The required auth level comes from the ``auth_level`` option on the called service's descriptor. If the
    service is not open and no valid session was found, the call is terminated with an UNAUTHENTICATED error code.

    Otherwise the wrapped handler runs with these context attributes set:
      - ``context.user_id``;
      - ``context.token``, a ``(token, expiry)`` pair;
      - ``context.is_api_key``.

    For anonymous calls to open services these are ``None``, ``(None, None)`` and ``False``.
    """

    def __init__(self):
        self._pool = get_descriptor_pool()

    def intercept_service(self, continuation, handler_call_details):
        method = handler_call_details.method
        # method is of the form "/org.couchers.api.core.API/GetUser"
        _, service_name, _ = method.split("/")

        try:
            service_options = self._pool.FindServiceByName(service_name).GetOptions()
        except KeyError:
            return abort_handler(
                "API call does not exist. Please refresh and try again.", grpc.StatusCode.UNIMPLEMENTED
            )

        auth_level = service_options.Extensions[annotations_pb2.auth_level]

        # if unknown auth level, then it wasn't set and something's wrong
        if auth_level == annotations_pb2.AUTH_LEVEL_UNKNOWN:
            return abort_handler("Internal authentication error.", grpc.StatusCode.INTERNAL)

        # explicit check rather than `assert`, which is stripped when Python runs with -O
        if auth_level not in (
            annotations_pb2.AUTH_LEVEL_OPEN,
            annotations_pb2.AUTH_LEVEL_JAILED,
            annotations_pb2.AUTH_LEVEL_SECURE,
            annotations_pb2.AUTH_LEVEL_ADMIN,
        ):
            return abort_handler("Internal authentication error.", grpc.StatusCode.INTERNAL)

        headers = dict(handler_call_details.invocation_metadata)

        if "cookie" in headers and "authorization" in headers:
            # for security reasons, only one of "cookie" or "authorization" can be present
            return unauthenticated_handler('Both "cookie" and "authorization" in request')
        elif "cookie" in headers:
            # the session token is passed in cookies, i.e. in the `cookie` header
            token, is_api_key = parse_session_cookie(headers), False
        elif "authorization" in headers:
            # the session token is passed in the `authorization` header
            token, is_api_key = parse_api_key(headers), True
        else:
            # no session found
            token, is_api_key = None, False

        ip_address = headers.get("x-couchers-real-ip")
        user_agent = headers.get("user-agent")

        auth_info = _try_get_and_update_user_details(token, is_api_key, ip_address, user_agent)
        # auth_info is now filled if and only if this is a valid session
        if not auth_info:
            # if no session was found and this isn't an open service, fail
            if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
                return unauthenticated_handler()
            # anonymous call to an open service: don't pass on a token that didn't validate
            token, is_api_key, token_expiry, user_id = None, False, None, None
        else:
            # a valid user session was found
            user_id, is_jailed, is_superuser, token_expiry = auth_info

            if auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not is_superuser:
                return unauthenticated_handler("Permission denied", grpc.StatusCode.PERMISSION_DENIED)

            # if the user is jailed and this isn't an open or jailed service, fail
            if is_jailed and auth_level not in (annotations_pb2.AUTH_LEVEL_OPEN, annotations_pb2.AUTH_LEVEL_JAILED):
                return unauthenticated_handler("Permission denied")

        handler = continuation(handler_call_details)
        if handler is None:
            # continuation returns None when no handler was found, e.g. the service exists but this method doesn't
            return abort_handler(
                "API call does not exist. Please refresh and try again.", grpc.StatusCode.UNIMPLEMENTED
            )
        user_aware_function = handler.unary_unary

        def user_unaware_function(req, context):
            context.user_id = user_id
            context.token = (token, token_expiry)
            context.is_api_key = is_api_key
            return user_aware_function(req, context)

        return grpc.unary_unary_rpc_method_handler(
            user_unaware_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

class CookieInterceptor(grpc.ServerInterceptor):
    """
    Syncs up the couchers-sesh and couchers-user-id cookies.

    After a handler returns successfully, this checks whether the call was authenticated with a cookie session and
    whether the client's user id cookie differs from ``context.user_id``. If both hold, fresh session cookies are
    sent back as "set-cookie" initial metadata.

    This relies on ``context.user_id``, ``context.is_api_key`` and ``context.token`` having been set by another
    interceptor in the chain (AuthValidatorInterceptor sets them).
    """

    def intercept_service(self, continuation, handler_call_details):
        headers = dict(handler_call_details.invocation_metadata)
        cookie_user_id = parse_user_id_cookie(headers)

        handler = continuation(handler_call_details)
        if handler is None:
            # no handler for this method: pass None through so gRPC replies with UNIMPLEMENTED
            return None
        user_aware_function = handler.unary_unary

        def user_unaware_function(req, context):
            res = user_aware_function(req, context)

            # check the two cookies are in sync
            if context.user_id and not context.is_api_key and cookie_user_id != str(context.user_id):
                try:
                    token, expiry = context.token
                    context.send_initial_metadata(
                        [("set-cookie", cookie) for cookie in create_session_cookies(token, context.user_id, expiry)]
                    )
                except ValueError:
                    # gRPC raises ValueError once initial metadata may no longer be sent; cookie sync is best-effort
                    logger.info("Tried to send initial metadata but wasn't allowed to")

            return res

        return grpc.unary_unary_rpc_method_handler(
            user_unaware_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

226 

227 

class ManualAuthValidatorInterceptor(grpc.ServerInterceptor):
    """
    Gate every call behind a caller-supplied ``is_authorized(token)`` predicate.

    The token is extracted from the request metadata with ``parse_api_key``, which is meant to parse an
    "Authorization: Bearer <hex>" header.

    Calls are answered by ``unauthenticated_handler()`` (gRPC status UNAUTHENTICATED) when there is no token or
    the predicate rejects it. All other calls continue down the interceptor chain unchanged.
    """

    def __init__(self, is_authorized):
        # predicate: token -> truthy if the caller may proceed
        self._is_authorized = is_authorized

    def intercept_service(self, continuation, handler_call_details):
        api_key = parse_api_key(dict(handler_call_details.invocation_metadata))

        if api_key and self._is_authorized(api_key):
            return continuation(handler_call_details)

        return unauthenticated_handler()

247 

248 

class OTelInterceptor(grpc.ServerInterceptor):
    """
    OpenTelemetry tracing

    Wraps each unary-unary handler in a "handler" span annotated with:
      - the full method path, service name and method name;
      - the serving thread id and process id;
      - the client user agent and IP address, taken from the "user-agent" and "x-couchers-real-ip" headers
        (empty string when absent).
    """

    def __init__(self):
        self.tracer = trace.get_tracer(__name__)

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            # no handler for this method: pass None through so gRPC replies with UNIMPLEMENTED
            return None
        prev_func = handler.unary_unary
        method = handler_call_details.method

        # method is of the form "/org.couchers.api.core.API/GetUser"
        _, service_name, method_name = method.split("/")

        headers = dict(handler_call_details.invocation_metadata)
        user_agent = headers.get("user-agent") or ""
        ip_address = headers.get("x-couchers-real-ip") or ""

        def tracing_function(request, context):
            with self.tracer.start_as_current_span("handler") as rollspan:
                rollspan.set_attribute("rpc.method_full", method)
                rollspan.set_attribute("rpc.service", service_name)
                rollspan.set_attribute("rpc.method", method_name)

                rollspan.set_attribute("rpc.thread", get_ident())
                rollspan.set_attribute("rpc.pid", getpid())

                # set before invoking the handler so spans of failed calls carry them too
                rollspan.set_attribute("web.user_agent", user_agent)
                rollspan.set_attribute("web.ip_address", ip_address)

                return prev_func(request, context)

        return grpc.unary_unary_rpc_method_handler(
            tracing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

288 

289 

class SessionInterceptor(grpc.ServerInterceptor):
    """
    Adds a session from session_scope() as the last argument. This needs to be the last interceptor since it changes the
    function signature by adding another argument.

    The wrapped servicer method is called as ``method(request, context, session)``. The session scope is exited
    when that call returns or raises.
    """

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            # unknown method (no servicer handler): pass None through so gRPC replies with UNIMPLEMENTED
            return None
        prev_func = handler.unary_unary

        def function_without_session(request, context):
            with session_scope() as session:
                return prev_func(request, context, session)

        return grpc.unary_unary_rpc_method_handler(
            function_without_session,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

309 

310 

class TracingInterceptor(grpc.ServerInterceptor):
    """
    Measures and logs the time it takes to service each incoming call.

    Every call is stored as an APICall row, with fields marked sensitive stripped from the request and response and
    the stored response truncated to 16 kB. Its duration is also reported via
    observe_in_servicer_duration_histogram. Calls that raise without a gRPC status code having been set (i.e. not
    via context.abort()) are additionally reported to Sentry.
    """

    def _sanitized_bytes(self, proto):
        """
        Remove fields marked sensitive and return serialized bytes

        Works on a deep copy, so ``proto`` itself is not modified. Clears every field annotated with the
        ``sensitive`` option, recursing into singular, repeated and map-valued submessages. Returns None when
        ``proto`` is falsy (e.g. None).
        """
        if not proto:
            return None

        new_proto = deepcopy(proto)

        def _sanitize_message(message):
            for name, descriptor in message.DESCRIPTOR.fields_by_name.items():
                if descriptor.GetOptions().Extensions[annotations_pb2.sensitive]:
                    message.ClearField(name)
                    # field is gone, nothing left to recurse into
                    continue
                if not descriptor.message_type:
                    continue
                submessage = getattr(message, name)
                if not submessage:
                    continue
                if descriptor.message_type.GetOptions().map_entry:
                    # map field: iterating the container yields keys, so walk the values instead, and only
                    # when the map's values are themselves messages
                    if descriptor.message_type.fields_by_name["value"].message_type:
                        for value in submessage.values():
                            _sanitize_message(value)
                elif descriptor.label == descriptor.LABEL_REPEATED:
                    for msg in submessage:
                        _sanitize_message(msg)
                else:
                    _sanitize_message(submessage)

        _sanitize_message(new_proto)

        return new_proto.SerializeToString()

    def _store_log(
        self,
        method,
        status_code,
        duration,
        user_id,
        is_api_key,
        request,
        response,
        traceback,
        perf_report,
        ip_address,
        user_agent,
    ):
        """
        Persist one APICall row describing a serviced call.

        ``request`` and ``response`` are passed through _sanitized_bytes. The serialized response is truncated to
        16 kB, and ``response_truncated`` records whether that happened. ``duration`` is in milliseconds.
        """
        req_bytes = self._sanitized_bytes(request)
        res_bytes = self._sanitized_bytes(response)

        response_truncated = False
        truncate_res_bytes_length = 16 * 1024  # 16 kB
        if res_bytes and len(res_bytes) > truncate_res_bytes_length:
            res_bytes = res_bytes[:truncate_res_bytes_length]
            response_truncated = True

        with session_scope() as session:
            session.add(
                APICall(
                    is_api_key=is_api_key,
                    method=method,
                    status_code=status_code,
                    duration=duration,
                    user_id=user_id,
                    request=req_bytes,
                    response=res_bytes,
                    response_truncated=response_truncated,
                    traceback=traceback,
                    perf_report=perf_report,
                    ip_address=ip_address,
                    user_agent=user_agent,
                )
            )
        logger.debug("user_id=%r, method=%r, duration=%r ms", user_id, method, duration)

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            # no handler for this method: pass None through so gRPC replies with UNIMPLEMENTED
            return None
        prev_func = handler.unary_unary
        method = handler_call_details.method

        headers = dict(handler_call_details.invocation_metadata)
        ip_address = headers.get("x-couchers-real-ip")
        user_agent = headers.get("user-agent")

        def tracing_function(request, context):
            # initialised up front so the failure path always has a start time, even if entering the profiler raises
            start = perf_counter_ns()
            try:
                with CouchersProfiler(do_profile=False) as prof:
                    # re-read so that only the handler itself is timed
                    start = perf_counter_ns()
                    res = prev_func(request, context)
                    finished = perf_counter_ns()
            except Exception as e:
                finished = perf_counter_ns()
                duration = (finished - start) / 1e6  # ms
                # set if the call was failed through context.abort(), otherwise None
                code = getattr(context.code(), "name", None)
                traceback = "".join(format_exception(type(e), e, e.__traceback__))
                user_id = getattr(context, "user_id", None)
                is_api_key = getattr(context, "is_api_key", None)
                self._store_log(
                    method, code, duration, user_id, is_api_key, request, None, traceback, None, ip_address, user_agent
                )
                observe_in_servicer_duration_histogram(method, user_id, code or "", type(e).__name__, duration / 1000)

                if not code:
                    # not an abort(): an unexpected error, so report it
                    sentry_sdk.set_tag("context", "servicer")
                    sentry_sdk.set_tag("method", method)
                    sentry_sdk.capture_exception(e)

                raise

            # success path is deliberately outside the try: a failure while logging a successful call must not be
            # recorded (and reported) a second time as a failed call
            duration = (finished - start) / 1e6  # ms
            user_id = getattr(context, "user_id", None)
            is_api_key = getattr(context, "is_api_key", None)
            self._store_log(
                method, None, duration, user_id, is_api_key, request, res, None, prof.report, ip_address, user_agent
            )
            observe_in_servicer_duration_histogram(method, user_id, "", "", duration / 1000)
            return res

        return grpc.unary_unary_rpc_method_handler(
            tracing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

430 

431 

class ErrorSanitizationInterceptor(grpc.ServerInterceptor):
    """
    If the call resulted in a non-gRPC error, this strips away the error details.

    It's important to put this first, so that it does not interfere with other interceptors.

    When the handler raises after a status code was set on the context (e.g. via context.abort()), the exception is
    logged as a warning and re-raised unchanged. Any other exception is logged and replaced with an INTERNAL error
    whose details are errors.UNKNOWN_ERROR.
    """

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            # no handler for this method: pass None through so gRPC replies with UNIMPLEMENTED
            return None
        prev_func = handler.unary_unary

        def sanitizing_function(req, context):
            try:
                return prev_func(req, context)
            except Exception as e:
                code = context.code()
                # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None
                if not code:
                    logger.exception(e)
                    logger.info("Probably an unknown error! Sanitizing...")
                    # context.abort() raises, terminating the RPC with the sanitized error
                    context.abort(grpc.StatusCode.INTERNAL, errors.UNKNOWN_ERROR)
                else:
                    logger.warning("RPC error: %s in method %s", code, handler_call_details.method)
                    raise

        return grpc.unary_unary_rpc_method_handler(
            sanitizing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )