Coverage for src/couchers/interceptors.py: 98%

165 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-07-22 16:44 +0000

1import logging 

2from copy import deepcopy 

3from datetime import timedelta 

4from time import perf_counter_ns 

5from traceback import format_exception 

6 

7import grpc 

8import sentry_sdk 

9from sqlalchemy.sql import func 

10 

11from couchers import errors 

12from couchers.db import session_scope 

13from couchers.descriptor_pool import get_descriptor_pool 

14from couchers.metrics import servicer_duration_histogram 

15from couchers.models import APICall, User, UserSession 

16from couchers.profiler import CouchersProfiler 

17from couchers.sql import couchers_select as select 

18from couchers.utils import now, parse_api_key, parse_session_cookie 

19from proto import annotations_pb2 

20 

21logger = logging.getLogger(__name__) 

22 

23 

24def _try_get_and_update_user_details(token, is_api_key): 

25 """ 

26 Tries to get session and user info corresponding to this token. 

27 

28 Also updates the user last active time, token last active time, and increments API call count. 

29 """ 

30 if not token: 

31 return None 

32 

33 with session_scope() as session: 

34 result = session.execute( 

35 select(User, UserSession) 

36 .join(User, User.id == UserSession.user_id) 

37 .where(User.is_visible) 

38 .where(UserSession.token == token) 

39 .where(UserSession.is_valid) 

40 .where(UserSession.is_api_key == is_api_key) 

41 ).one_or_none() 

42 

43 if not result: 

44 return None 

45 else: 

46 user, user_session = result 

47 

48 # update user last active time if it's been a while 

49 if now() - user.last_active > timedelta(minutes=5): 

50 user.last_active = func.now() 

51 

52 # let's update the token 

53 user_session.last_seen = func.now() 

54 user_session.api_calls += 1 

55 session.flush() 

56 

57 return user.id, user.is_jailed, user.is_superuser 

58 

59 

def abort_handler(message, status_code):
    """
    Builds a unary-unary RPC method handler whose only behaviour is to abort the call with the given status
    code and message.
    """

    def _abort(request, context):
        context.abort(status_code, message)

    aborting_handler = grpc.unary_unary_rpc_method_handler(_abort)
    return aborting_handler

65 

66 

def unauthenticated_handler(message="Unauthorized", status_code=grpc.StatusCode.UNAUTHENTICATED):
    """
    Returns a handler that aborts the call, by default with the UNAUTHENTICATED status and the message
    "Unauthorized".
    """
    rejecting_handler = abort_handler(message=message, status_code=status_code)
    return rejecting_handler

69 

70 

class AuthValidatorInterceptor(grpc.ServerInterceptor):
    """
    Extracts a session token from the `cookie` header or an API key from the `authorization` header, and
    authenticates a user with that.

    The required auth level is read from the `auth_level` option of the called service. If the call may
    proceed, sets context.user_id, context.token, and context.is_api_key before invoking the servicer
    (user_id is None for unauthenticated calls to open services); otherwise terminates the call with an
    error status.
    """

    def __init__(self):
        self._pool = get_descriptor_pool()

    @staticmethod
    def _unknown_call_handler():
        """
        Handler used whenever the requested service or method cannot be resolved.
        """
        return abort_handler("API call does not exist. Please refresh and try again.", grpc.StatusCode.UNIMPLEMENTED)

    def intercept_service(self, continuation, handler_call_details):
        """
        Checks the caller against the service's auth level and wraps the servicer so it sees the user info.

        Returns an aborting handler instead of the real one when the call is unknown or not permitted.
        """
        method = handler_call_details.method
        # method is of the form "/org.couchers.api.core.API/GetUser"
        parts = method.split("/")
        if len(parts) != 3:
            # a malformed path is reported like any other unknown call instead of raising in the interceptor
            return self._unknown_call_handler()
        service_name = parts[1]

        try:
            service_options = self._pool.FindServiceByName(service_name).GetOptions()
        except KeyError:
            return self._unknown_call_handler()

        auth_level = service_options.Extensions[annotations_pb2.auth_level]

        # if unknown auth level, then it wasn't set and something's wrong
        if auth_level == annotations_pb2.AUTH_LEVEL_UNKNOWN:
            return abort_handler("Internal authentication error.", grpc.StatusCode.INTERNAL)

        assert auth_level in [
            annotations_pb2.AUTH_LEVEL_OPEN,
            annotations_pb2.AUTH_LEVEL_JAILED,
            annotations_pb2.AUTH_LEVEL_SECURE,
            annotations_pb2.AUTH_LEVEL_ADMIN,
        ]

        headers = dict(handler_call_details.invocation_metadata)

        if "cookie" in headers and "authorization" in headers:
            # for security reasons, only one of "cookie" or "authorization" can be present
            return unauthenticated_handler('Both "cookie" and "authorization" in request')
        elif "cookie" in headers:
            # the session token is passed in cookies, i.e. in the `cookie` header
            token = parse_session_cookie(headers)
            is_api_key = False
            res = _try_get_and_update_user_details(token, is_api_key)
        elif "authorization" in headers:
            # the session token is passed in the `authorization` header
            token = parse_api_key(headers)
            is_api_key = True
            res = _try_get_and_update_user_details(token, is_api_key)
        else:
            # no session found
            token = None
            is_api_key = False
            res = None

        # if no session was found and this isn't an open service, fail
        if not token or not res:
            if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
                return unauthenticated_handler()
            user_id = None
        else:
            # a valid user session was found
            user_id, is_jailed, is_superuser = res

            if auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not is_superuser:
                return unauthenticated_handler("Permission denied", grpc.StatusCode.PERMISSION_DENIED)

            # if the user is jailed and this isn't an open or jailed service, fail
            if is_jailed and auth_level not in [annotations_pb2.AUTH_LEVEL_OPEN, annotations_pb2.AUTH_LEVEL_JAILED]:
                return unauthenticated_handler("Permission denied")

        handler = continuation(handler_call_details)
        if handler is None:
            # the continuation returns None when nothing serves this method; without this check the
            # attribute access below would raise inside the interceptor
            return self._unknown_call_handler()
        user_aware_function = handler.unary_unary

        def user_unaware_function(req, context):
            # attach the authentication result to the context for the servicer to read
            context.user_id = user_id
            context.token = token
            context.is_api_key = is_api_key
            return user_aware_function(req, context)

        return grpc.unary_unary_rpc_method_handler(
            user_unaware_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

158 

159 

class ManualAuthValidatorInterceptor(grpc.ServerInterceptor):
    """
    Extracts an "Authorization: Bearer <hex>" header and passes the token to the is_authorized function
    given at construction. Terminates the call with the UNAUTHENTICATED status unless a token is present
    and is_authorized accepts it.
    """

    def __init__(self, is_authorized):
        self._is_authorized = is_authorized

    def intercept_service(self, continuation, handler_call_details):
        headers = dict(handler_call_details.invocation_metadata)
        api_key = parse_api_key(headers)

        if api_key and self._is_authorized(api_key):
            # authorized: hand over to the rest of the chain untouched
            return continuation(handler_call_details)

        return unauthenticated_handler()

179 

180 

class TracingInterceptor(grpc.ServerInterceptor):
    """
    Measures and logs the time it takes to service each incoming call.

    Every call is stored as an APICall row (with sanitized request/response bytes) and observed in the
    servicer duration histogram. Exceptions raised without a gRPC status code set on the context are also
    reported to Sentry.
    """

    def _sanitized_bytes(self, proto):
        """
        Remove fields marked sensitive and return serialized bytes

        Works on a deep copy, so the given message is left untouched. Returns None for a falsy input.
        """
        if not proto:
            return None
        new_proto = deepcopy(proto)
        for name, descriptor in new_proto.DESCRIPTOR.fields_by_name.items():
            if descriptor.GetOptions().Extensions[annotations_pb2.sensitive]:
                new_proto.ClearField(name)
        return new_proto.SerializeToString()

    def _observe_in_histogram(self, method, status_code, exception_type, duration):
        """
        Records the call duration in the histogram, labelled by method, status code, and exception type.
        """
        servicer_duration_histogram.labels(method, status_code, exception_type).observe(duration)

    def _store_log(
        self,
        method,
        status_code,
        duration,
        user_id,
        is_api_key,
        request,
        response,
        traceback,
        perf_report,
        ip_address,
        user_agent,
    ):
        """
        Persists one APICall row describing the call and emits a debug log line.

        The serialized response is cut to 16 kB, with response_truncated set when that happens.
        """
        req_bytes = self._sanitized_bytes(request)
        res_bytes = self._sanitized_bytes(response)

        response_truncated = False
        truncate_res_bytes_length = 16 * 1024  # 16 kB
        if res_bytes and len(res_bytes) > truncate_res_bytes_length:
            res_bytes = res_bytes[:truncate_res_bytes_length]
            response_truncated = True

        with session_scope() as session:
            session.add(
                APICall(
                    is_api_key=is_api_key,
                    method=method,
                    status_code=status_code,
                    duration=duration,
                    user_id=user_id,
                    request=req_bytes,
                    response=res_bytes,
                    response_truncated=response_truncated,
                    traceback=traceback,
                    perf_report=perf_report,
                    ip_address=ip_address,
                    user_agent=user_agent,
                )
            )
        logger.debug(f"{user_id=}, {method=}, {duration=} ms")

    def intercept_service(self, continuation, handler_call_details):
        """
        Wraps the next handler so that every call is timed, logged, and observed in the histogram.
        """
        handler = continuation(handler_call_details)
        if handler is None:
            # nothing serves this method: pass that on so gRPC reports it, instead of failing on None below
            return None
        prev_func = handler.unary_unary
        method = handler_call_details.method

        headers = dict(handler_call_details.invocation_metadata)
        ip_address = headers.get("x-couchers-real-ip")
        user_agent = headers.get("user-agent")

        def tracing_function(request, context):
            # bound before the try block so the except branch can always compute a duration, even if
            # entering the profiler fails; reset below so the measurement itself is unchanged
            start = perf_counter_ns()
            try:
                with CouchersProfiler(do_profile=False) as prof:
                    start = perf_counter_ns()
                    res = prev_func(request, context)
                    finished = perf_counter_ns()
                duration = (finished - start) / 1e6  # ms
                user_id = getattr(context, "user_id", None)
                is_api_key = getattr(context, "is_api_key", None)
                self._store_log(
                    method, None, duration, user_id, is_api_key, request, res, None, prof.report, ip_address, user_agent
                )
                self._observe_in_histogram(method, "", "", duration)
            except Exception as e:
                finished = perf_counter_ns()
                duration = (finished - start) / 1e6  # ms
                # the status code's name if one was set on the context, otherwise None
                code = getattr(context.code(), "name", None)
                traceback = "".join(format_exception(type(e), e, e.__traceback__))
                user_id = getattr(context, "user_id", None)
                is_api_key = getattr(context, "is_api_key", None)
                self._store_log(
                    method, code, duration, user_id, is_api_key, request, None, traceback, None, ip_address, user_agent
                )
                self._observe_in_histogram(method, code or "", type(e).__name__, duration)

                # no status code means this was not a deliberate abort, so report it
                if not code:
                    sentry_sdk.set_tag("context", "servicer")
                    sentry_sdk.set_tag("method", method)
                    sentry_sdk.capture_exception(e)

                raise
            return res

        return grpc.unary_unary_rpc_method_handler(
            tracing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

288 

289 

class ErrorSanitizationInterceptor(grpc.ServerInterceptor):
    """
    If the call resulted in a non-gRPC error, this strips away the error details.

    It's important to put this first, so that it does not interfere with other interceptors.
    """

    def intercept_service(self, continuation, handler_call_details):
        """
        Wraps the next handler so that exceptions without a gRPC status code are replaced by a generic
        INTERNAL error.
        """
        handler = continuation(handler_call_details)
        if handler is None:
            # nothing serves this method: pass that on so gRPC reports it, instead of failing on None below
            return None
        prev_func = handler.unary_unary

        def sanitizing_function(req, context):
            try:
                res = prev_func(req, context)
            except Exception as e:
                code = context.code()
                # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None
                if not code:
                    logger.exception(e)
                    logger.info("Probably an unknown error! Sanitizing...")
                    context.abort(grpc.StatusCode.INTERNAL, errors.UNKNOWN_ERROR)
                else:
                    logger.warning(f"RPC error: {code}")
                    raise
            return res

        return grpc.unary_unary_rpc_method_handler(
            sanitizing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )