Coverage for src/couchers/interceptors.py: 99%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

153 statements  

1import logging 

2from copy import deepcopy 

3from time import perf_counter_ns 

4from traceback import format_exception 

5 

6import grpc 

7import sentry_sdk 

8from sqlalchemy.sql import func 

9 

10from couchers import errors 

11from couchers.db import session_scope 

12from couchers.descriptor_pool import get_descriptor_pool 

13from couchers.metrics import servicer_duration_histogram 

14from couchers.models import APICall, User, UserSession 

15from couchers.sql import couchers_select as select 

16from couchers.utils import parse_api_key, parse_session_cookie 

17from proto import annotations_pb2 

18 

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

20 

21 

def _try_get_and_update_user_details(token, is_api_key):
    """
    Look up the user and session owning ``token`` and record that they were just used.

    Only a visible user with a valid session of the requested kind (``is_api_key``) matches.
    On a match, the user's ``last_active`` and the session's ``last_seen`` are set to the
    database's current time, the session's ``api_calls`` counter is incremented, and the
    session is flushed.

    Returns a ``(user_id, is_jailed, is_superuser)`` tuple, or ``None`` when ``token`` is
    falsy or no matching session exists.
    """
    if not token:
        return None

    query = (
        select(User, UserSession)
        .join(User, User.id == UserSession.user_id)
        .where(User.is_visible)
        .where(UserSession.token == token)
        .where(UserSession.is_valid)
        .where(UserSession.is_api_key == is_api_key)
    )

    with session_scope() as session:
        row = session.execute(query).one_or_none()
        if not row:
            return None

        user, user_session = row

        # bump activity timestamps and the per-session call counter
        user.last_active = func.now()
        user_session.last_seen = func.now()
        user_session.api_calls += 1
        session.flush()

        return user.id, user.is_jailed, user.is_superuser

55 

56 

def abort_handler(message, status_code):
    """
    Build a unary-unary RPC method handler that rejects every call.

    The returned handler ignores the request and calls ``context.abort`` with
    ``status_code`` and ``message``.
    """

    def _abort(request, context):
        context.abort(status_code, message)

    return grpc.unary_unary_rpc_method_handler(_abort)

62 

63 

def unauthenticated_handler(message="Unauthorized", status_code=grpc.StatusCode.UNAUTHENTICATED):
    """
    Build a handler that aborts every call, by default with UNAUTHENTICATED / "Unauthorized".

    Convenience wrapper around ``abort_handler`` with authentication-oriented defaults.
    """
    return abort_handler(message=message, status_code=status_code)

66 

67 

class AuthValidatorInterceptor(grpc.ServerInterceptor):
    """
    Authenticates each call against the ``auth_level`` option declared on its service.

    The session token is taken either from the ``cookie`` header (a session cookie) or from
    the ``authorization`` header (an API key); a request carrying both is rejected.

    On success the wrapped handler sees ``context.user_id``, ``context.token`` and
    ``context.is_api_key``. ``user_id`` is ``None`` when no valid session was found, which is
    only allowed for AUTH_LEVEL_OPEN services. Otherwise the call is terminated with
    UNAUTHENTICATED, or with PERMISSION_DENIED when a non-superuser calls an admin service.

    Unknown services or methods fail with UNIMPLEMENTED. Services without a recognised
    ``auth_level`` fail with INTERNAL.
    """

    # Every auth level a service may declare. AUTH_LEVEL_UNKNOWN (the option was never set) is
    # deliberately absent so such services fail closed.
    _VALID_AUTH_LEVELS = frozenset(
        [
            annotations_pb2.AUTH_LEVEL_OPEN,
            annotations_pb2.AUTH_LEVEL_JAILED,
            annotations_pb2.AUTH_LEVEL_SECURE,
            annotations_pb2.AUTH_LEVEL_ADMIN,
        ]
    )

    def __init__(self):
        self._pool = get_descriptor_pool()

    def intercept_service(self, continuation, handler_call_details):
        # method is of the form "/org.couchers.api.core.API/GetUser"
        try:
            _, service_name, _method_name = handler_call_details.method.split("/")
            service_options = self._pool.FindServiceByName(service_name).GetOptions()
        except (ValueError, KeyError):
            # malformed method path, or a service missing from our descriptor pool
            return abort_handler(
                "API call does not exist. Please refresh and try again.", grpc.StatusCode.UNIMPLEMENTED
            )

        auth_level = service_options.Extensions[annotations_pb2.auth_level]

        # An unset (AUTH_LEVEL_UNKNOWN) or unrecognised auth level is a server misconfiguration.
        # Fail closed with an explicit error instead of an `assert`, which `python -O` strips.
        if auth_level not in self._VALID_AUTH_LEVELS:
            return abort_handler("Internal authentication error.", grpc.StatusCode.INTERNAL)

        headers = dict(handler_call_details.invocation_metadata)

        if "cookie" in headers and "authorization" in headers:
            # for security reasons, only one of "cookie" or "authorization" can be present
            return unauthenticated_handler('Both "cookie" and "authorization" in request')
        elif "cookie" in headers:
            # the session token is passed in cookies, i.e. in the `cookie` header
            token = parse_session_cookie(headers)
            is_api_key = False
            res = _try_get_and_update_user_details(token, is_api_key)
        elif "authorization" in headers:
            # the session token is passed in the `authorization` header
            token = parse_api_key(headers)
            is_api_key = True
            res = _try_get_and_update_user_details(token, is_api_key)
        else:
            # no session found
            token = None
            is_api_key = False
            res = None

        if not token or not res:
            # no valid session: only open services may be called anonymously
            if auth_level != annotations_pb2.AUTH_LEVEL_OPEN:
                return unauthenticated_handler()
            user_id = None
        else:
            # a valid user session was found
            user_id, is_jailed, is_superuser = res

            if auth_level == annotations_pb2.AUTH_LEVEL_ADMIN and not is_superuser:
                return unauthenticated_handler("Permission denied", grpc.StatusCode.PERMISSION_DENIED)

            # if the user is jailed and this isn't an open or jailed service, fail
            if is_jailed and auth_level not in (annotations_pb2.AUTH_LEVEL_OPEN, annotations_pb2.AUTH_LEVEL_JAILED):
                return unauthenticated_handler("Permission denied")

        handler = continuation(handler_call_details)
        if handler is None:
            # the service exists but has no handler for this method
            return abort_handler(
                "API call does not exist. Please refresh and try again.", grpc.StatusCode.UNIMPLEMENTED
            )
        servicer_function = handler.unary_unary

        def authenticated_function(req, context):
            # expose the authentication result to the downstream servicer via the context
            context.user_id = user_id
            context.token = token
            context.is_api_key = is_api_key
            return servicer_function(req, context)

        return grpc.unary_unary_rpc_method_handler(
            authenticated_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

155 

156 

class ManualAuthValidatorInterceptor(grpc.ServerInterceptor):
    """
    Guards every call with a caller-supplied token check.

    The token is extracted from the request metadata with ``parse_api_key``. Presumably this
    is an ``Authorization: Bearer <token>`` header; see that helper for the exact format. The
    token is then passed to ``is_authorized``.

    Calls with no token, or whose token ``is_authorized`` rejects, are terminated with the
    gRPC status UNAUTHENTICATED (a gRPC status, not an HTTP error code). Authorized calls are
    forwarded unchanged.
    """

    def __init__(self, is_authorized):
        # predicate called as is_authorized(token); a truthy result lets the call through
        self._is_authorized = is_authorized

    def intercept_service(self, continuation, handler_call_details):
        metadata = dict(handler_call_details.invocation_metadata)
        token = parse_api_key(metadata)

        # short-circuit: is_authorized is never called with a missing/empty token
        if not token or not self._is_authorized(token):
            return unauthenticated_handler()

        return continuation(handler_call_details)

176 

177 

class TracingInterceptor(grpc.ServerInterceptor):
    """
    Measures and logs the time it takes to service each incoming call.

    For each call handled by this interceptor:
    - an ``APICall`` row is stored, with fields marked ``sensitive`` stripped from the
      request and response;
    - the duration is observed in ``servicer_duration_histogram``;
    - exceptions not raised via ``context.abort()`` (``context.code()`` unset) are reported
      to Sentry.
    """

    def _strip_sensitive_fields(self, message):
        """
        Clear, in place, every field of ``message`` marked with the ``sensitive`` option.

        Recurses into singular, repeated and map-valued sub-messages, so sensitive fields
        nested inside other messages are stripped too.
        """
        for field, value in message.ListFields():
            if field.is_extension:
                # extensions aren't in fields_by_name and can't be cleared by name; leave them
                continue
            if field.GetOptions().Extensions[annotations_pb2.sensitive]:
                message.ClearField(field.name)
            elif field.type == field.TYPE_MESSAGE:
                if field.message_type.GetOptions().map_entry:
                    # map<K, V>: only message-typed values can contain further sensitive fields
                    value_field = field.message_type.fields_by_name["value"]
                    if value_field.type == value_field.TYPE_MESSAGE:
                        for item in value.values():
                            self._strip_sensitive_fields(item)
                elif field.label == field.LABEL_REPEATED:
                    for item in value:
                        self._strip_sensitive_fields(item)
                else:
                    self._strip_sensitive_fields(value)

    def _sanitized_bytes(self, proto):
        """
        Return ``proto`` serialized with all sensitive fields (at any depth) removed.

        Returns ``None`` for a falsy ``proto`` (e.g. no response because the call raised).
        The caller's message is never modified; a deep copy is sanitized instead.
        """
        if not proto:
            return None
        new_proto = deepcopy(proto)
        self._strip_sensitive_fields(new_proto)
        return new_proto.SerializeToString()

    def _observe_in_histogram(self, method, status_code, exception_type, duration):
        servicer_duration_histogram.labels(method, status_code, exception_type).observe(duration)

    def _store_log(self, method, status_code, duration, user_id, is_api_key, request, response, traceback):
        req_bytes = self._sanitized_bytes(request)
        res_bytes = self._sanitized_bytes(response)
        with session_scope() as session:
            session.add(
                APICall(
                    is_api_key=is_api_key,
                    method=method,
                    status_code=status_code,
                    duration=duration,
                    user_id=user_id,
                    request=req_bytes,
                    response=res_bytes,
                    traceback=traceback,
                )
            )
        # same text as f"{user_id=}, ..." (which uses repr), but only formatted when DEBUG is enabled
        logger.debug("user_id=%r, method=%r, duration=%r ms", user_id, method, duration)

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            # no handler for this method: pass through so gRPC answers UNIMPLEMENTED
            return None
        prev_func = handler.unary_unary
        method = handler_call_details.method

        def tracing_function(request, context):
            start = perf_counter_ns()
            # Only the servicer call is guarded. A failure in the logging below must not be
            # recorded (or reported to Sentry) as if the servicer itself had failed.
            try:
                res = prev_func(request, context)
            except Exception as e:
                duration = (perf_counter_ns() - start) / 1e6  # ms
                # the status code's name if the servicer failed the call via context.abort(), else None
                code = getattr(context.code(), "name", None)
                traceback = "".join(format_exception(type(e), e, e.__traceback__))
                user_id = getattr(context, "user_id", None)
                is_api_key = getattr(context, "is_api_key", None)
                self._store_log(method, code, duration, user_id, is_api_key, request, None, traceback)
                self._observe_in_histogram(method, code or "", type(e).__name__, duration)

                if not code:
                    sentry_sdk.set_tag("context", "servicer")
                    sentry_sdk.set_tag("method", method)
                    sentry_sdk.capture_exception(e)

                raise

            duration = (perf_counter_ns() - start) / 1e6  # ms
            user_id = getattr(context, "user_id", None)
            is_api_key = getattr(context, "is_api_key", None)
            self._store_log(method, None, duration, user_id, is_api_key, request, res, None)
            self._observe_in_histogram(method, "", "", duration)
            return res

        return grpc.unary_unary_rpc_method_handler(
            tracing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )

254 

255 

class ErrorSanitizationInterceptor(grpc.ServerInterceptor):
    """
    If the call resulted in a non-gRPC error, this strips away the error details.

    Exceptions raised after the servicer set a status via ``context.abort()`` (so
    ``context.code()`` is set) are logged and re-raised unchanged. Any other exception is
    logged with its traceback and replaced by an INTERNAL abort carrying
    ``errors.UNKNOWN_ERROR``, so internal details don't reach the client.

    It's important to put this first, so that it does not interfere with other interceptors.
    """

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            # no handler for this method: pass through so gRPC answers UNIMPLEMENTED
            return None
        prev_func = handler.unary_unary

        def sanitizing_function(req, context):
            try:
                return prev_func(req, context)
            except Exception as e:
                code = context.code()
                # the code is one of the RPC error codes if this was failed through abort(), otherwise it's None
                if not code:
                    logger.exception(e)
                    logger.info("Probably an unknown error! Sanitizing...")
                    # abort() raises, replacing the original exception with a generic INTERNAL error
                    context.abort(grpc.StatusCode.INTERNAL, errors.UNKNOWN_ERROR)
                else:
                    logger.warning("RPC error: %s", code)
                    raise

        return grpc.unary_unary_rpc_method_handler(
            sanitizing_function,
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )