Coverage for src/tests/test_interceptors.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from concurrent import futures
2from contextlib import contextmanager
4import grpc
5import pytest
6from google.protobuf import empty_pb2
8from couchers import errors
9from couchers.crypto import random_hex
10from couchers.db import session_scope
11from couchers.interceptors import AuthValidatorInterceptor, ErrorSanitizationInterceptor, TracingInterceptor
12from couchers.metrics import CODE_LABEL, EXCEPTION_LABEL, METHOD_LABEL, servicer_duration_histogram
13from couchers.models import APICall, UserSession
14from couchers.servicers.account import Account
15from couchers.sql import couchers_select as select
16from proto import account_pb2, admin_pb2, auth_pb2
17from tests.test_fixtures import db, generate_user, real_admin_session, testconfig # noqa
@pytest.fixture(autouse=True)
def _(testconfig):
    """Autouse fixture whose only job is to request ``testconfig`` for every test in this module."""
@contextmanager
def interceptor_dummy_api(
    rpc,
    interceptors,
    service_name="testing.Test",
    method_name="TestRpc",
    request_type=empty_pb2.Empty,
    response_type=empty_pb2.Empty,
    creds=None,
):
    """
    Context manager that serves a single unary-unary RPC behind the given interceptors.

    A gRPC server backed by a one-worker thread pool is started on a local secure port, with ``rpc``
    registered by hand as ``/{service_name}/{method_name}``. The value yielded is a client-side callable
    for that method, created on a secure channel to the same port. ``creds`` are the channel credentials;
    when falsy, local channel credentials are used. The server is stopped when the context exits.
    """
    with futures.ThreadPoolExecutor(1) as pool:
        grpc_server = grpc.server(pool, interceptors=interceptors)
        bound_port = grpc_server.add_secure_port("localhost:0", grpc.local_server_credentials())

        # register the handler by hand instead of going through generated servicer code
        handler = grpc.unary_unary_rpc_method_handler(
            rpc,
            request_deserializer=request_type.FromString,
            response_serializer=response_type.SerializeToString,
        )
        service_handler = grpc.method_handlers_generic_handler(service_name, {method_name: handler})
        grpc_server.add_generic_rpc_handlers((service_handler,))
        grpc_server.start()

        try:
            channel_creds = creds or grpc.local_channel_credentials()
            with grpc.secure_channel(f"localhost:{bound_port}", channel_creds) as channel:
                yield channel.unary_unary(
                    f"/{service_name}/{method_name}",
                    request_serializer=request_type.SerializeToString,
                    response_deserializer=response_type.FromString,
                )
        finally:
            # stop() hands back an event; block on it so the server is down before we return
            grpc_server.stop(None).wait()
def _check_histogram_labels(method, exception, code, count):
    """
    Assert the ``servicer_duration_count`` sample for the given labels, then reset the histogram.

    Looks up the ``servicer_duration`` metric collected from ``servicer_duration_histogram``, picks the
    ``servicer_duration_count`` sample whose method/exception/code labels match the arguments, and asserts
    that its value equals ``count``.

    The histogram is cleared in a ``finally`` block so that a failed check does not leave stale samples
    behind for whichever test calls this helper next.
    """
    try:
        histograms = [m for m in servicer_duration_histogram.collect() if m.name == "servicer_duration"]
        assert histograms, "servicer_duration metric was not collected"

        matching = [
            s
            for s in histograms[0].samples
            if s.name == "servicer_duration_count"
            and s.labels[METHOD_LABEL] == method
            and s.labels[EXCEPTION_LABEL] == exception
            and s.labels[CODE_LABEL] == code
        ]
        assert matching, f"no servicer_duration_count sample for {method=}, {exception=}, {code=}"
        assert matching[0].value == count
    finally:
        servicer_duration_histogram.clear()
def test_logging_interceptor_ok():
    """A handler that returns normally can be called through ErrorSanitizationInterceptor without raising."""

    def TestRpc(request, context):
        return empty_pb2.Empty()

    sanitizing = [ErrorSanitizationInterceptor()]
    with interceptor_dummy_api(TestRpc, interceptors=sanitizing) as call_rpc:
        call_rpc(empty_pb2.Empty())
def test_logging_interceptor_all_ignored():
    """Explicit ``context.abort`` calls keep their status code and message under ErrorSanitizationInterceptor."""
    # status codes the interceptor must leave alone;
    # grpc.StatusCode.OK is left out on purpose because we can't abort with OK
    untouched_codes = (
        grpc.StatusCode.CANCELLED,
        grpc.StatusCode.UNKNOWN,
        grpc.StatusCode.INVALID_ARGUMENT,
        grpc.StatusCode.DEADLINE_EXCEEDED,
        grpc.StatusCode.NOT_FOUND,
        grpc.StatusCode.ALREADY_EXISTS,
        grpc.StatusCode.PERMISSION_DENIED,
        grpc.StatusCode.UNAUTHENTICATED,
        grpc.StatusCode.RESOURCE_EXHAUSTED,
        grpc.StatusCode.FAILED_PRECONDITION,
        grpc.StatusCode.ABORTED,
        grpc.StatusCode.OUT_OF_RANGE,
        grpc.StatusCode.UNIMPLEMENTED,
        grpc.StatusCode.INTERNAL,
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.DATA_LOSS,
    )

    for expected_code in untouched_codes:
        expected_details = random_hex()

        # the closure is only ever invoked within this same iteration
        def TestRpc(request, context):
            context.abort(expected_code, expected_details)

        with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
            with pytest.raises(grpc.RpcError) as excinfo:
                call_rpc(empty_pb2.Empty())
            rpc_error = excinfo.value
            assert rpc_error.code() == expected_code
            assert rpc_error.details() == expected_details
def test_logging_interceptor_assertion():
    """A failing ``assert`` in the handler reaches the client as INTERNAL with ``errors.UNKNOWN_ERROR``."""

    def TestRpc(request, context):
        assert False

    sanitizing = [ErrorSanitizationInterceptor()]
    with interceptor_dummy_api(TestRpc, interceptors=sanitizing) as call_rpc:
        with pytest.raises(grpc.RpcError) as excinfo:
            call_rpc(empty_pb2.Empty())
        rpc_error = excinfo.value
        assert rpc_error.code() == grpc.StatusCode.INTERNAL
        assert rpc_error.details() == errors.UNKNOWN_ERROR
def test_logging_interceptor_div0():
    """A division by zero in the handler reaches the client as INTERNAL with ``errors.UNKNOWN_ERROR``."""

    def TestRpc(request, context):
        1 / 0

    sanitizing = [ErrorSanitizationInterceptor()]
    with interceptor_dummy_api(TestRpc, interceptors=sanitizing) as call_rpc:
        with pytest.raises(grpc.RpcError) as excinfo:
            call_rpc(empty_pb2.Empty())
        rpc_error = excinfo.value
        assert rpc_error.code() == grpc.StatusCode.INTERNAL
        assert rpc_error.details() == errors.UNKNOWN_ERROR
def test_logging_interceptor_raise():
    """A bare ``Exception`` raised by the handler reaches the client as INTERNAL with ``errors.UNKNOWN_ERROR``."""

    def TestRpc(request, context):
        raise Exception()

    sanitizing = [ErrorSanitizationInterceptor()]
    with interceptor_dummy_api(TestRpc, interceptors=sanitizing) as call_rpc:
        with pytest.raises(grpc.RpcError) as excinfo:
            call_rpc(empty_pb2.Empty())
        rpc_error = excinfo.value
        assert rpc_error.code() == grpc.StatusCode.INTERNAL
        assert rpc_error.details() == errors.UNKNOWN_ERROR
def test_logging_interceptor_raise_custom():
    """A custom exception subclass raised by the handler is also reported as INTERNAL / ``errors.UNKNOWN_ERROR``."""

    class _TestingException(Exception):
        """Exception type that exists only inside this test."""

    def TestRpc(request, context):
        raise _TestingException("This is a custom exception")

    sanitizing = [ErrorSanitizationInterceptor()]
    with interceptor_dummy_api(TestRpc, interceptors=sanitizing) as call_rpc:
        with pytest.raises(grpc.RpcError) as excinfo:
            call_rpc(empty_pb2.Empty())
        rpc_error = excinfo.value
        assert rpc_error.code() == grpc.StatusCode.INTERNAL
        assert rpc_error.details() == errors.UNKNOWN_ERROR
def test_tracing_interceptor_ok_open(db):
    """A successful unauthenticated call leaves exactly one APICall row with no status, user or traceback."""

    def TestRpc(request, context):
        return empty_pb2.Empty()

    tracing = [TracingInterceptor()]
    with interceptor_dummy_api(TestRpc, interceptors=tracing) as call_rpc:
        call_rpc(empty_pb2.Empty())

    with session_scope() as session:
        # scalar_one() also checks that exactly one call was recorded
        api_call = session.execute(select(APICall)).scalar_one()
        assert api_call.method == "/testing.Test/TestRpc"
        assert not api_call.status_code
        assert not api_call.user_id
        assert len(api_call.request) == 0
        assert len(api_call.response) == 0
        assert not api_call.traceback

    _check_histogram_labels("/testing.Test/TestRpc", "", "", 1)
def test_tracing_interceptor_sensitive(db):
    """The stored request and response have their ``password`` fields empty while other fields are kept."""

    def TestRpc(request, context):
        return auth_pb2.AuthReq(user="this is not secret", password="this is secret")

    api_kwargs = {
        "interceptors": [TracingInterceptor()],
        "request_type": auth_pb2.SignupAccount,
        "response_type": auth_pb2.AuthReq,
    }
    with interceptor_dummy_api(TestRpc, **api_kwargs) as call_rpc:
        call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))

    with session_scope() as session:
        api_call = session.execute(select(APICall)).scalar_one()
        assert api_call.method == "/testing.Test/TestRpc"
        assert not api_call.status_code
        assert not api_call.user_id
        assert not api_call.traceback

        # decode the stored request bytes and inspect the fields
        stored_req = auth_pb2.SignupAccount.FromString(api_call.request)
        assert not stored_req.password
        assert stored_req.username == "not removed"

        # same for the stored response
        stored_res = auth_pb2.AuthReq.FromString(api_call.response)
        assert stored_res.user == "this is not secret"
        assert not stored_res.password

    _check_histogram_labels("/testing.Test/TestRpc", "", "", 1)
def test_tracing_interceptor_exception(db):
    """
    A handler that raises is traced with its traceback, a scrubbed request and no response.

    The client-side failure is asserted as ``grpc.RpcError`` (as in the other tests of this module) rather
    than a bare ``Exception``, so an unrelated client-side error cannot satisfy the check.
    """

    def TestRpc(request, context):
        raise Exception("Some error message")

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[TracingInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        with pytest.raises(grpc.RpcError):
            call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/testing.Test/TestRpc"
        assert not trace.status_code
        assert not trace.user_id
        assert "Some error message" in trace.traceback
        req = auth_pb2.SignupAccount.FromString(trace.request)
        assert not req.password
        assert req.username == "not removed"
        assert not trace.response

    _check_histogram_labels("/testing.Test/TestRpc", "Exception", "", 1)
def test_tracing_interceptor_abort(db):
    """
    A handler that calls ``context.abort`` is traced with the status code name and the abort message.

    The client-side failure is asserted as ``grpc.RpcError`` (as in the other tests of this module) rather
    than a bare ``Exception``, so an unrelated client-side error cannot satisfy the check.
    """

    def TestRpc(request, context):
        context.abort(grpc.StatusCode.FAILED_PRECONDITION, "now a grpc abort")

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[TracingInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        with pytest.raises(grpc.RpcError):
            call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/testing.Test/TestRpc"
        assert trace.status_code == "FAILED_PRECONDITION"
        assert not trace.user_id
        assert "now a grpc abort" in trace.traceback
        req = auth_pb2.SignupAccount.FromString(trace.request)
        assert not req.password
        assert req.username == "not removed"
        assert not trace.response

    _check_histogram_labels("/testing.Test/TestRpc", "Exception", "FAILED_PRECONDITION", 1)
def test_auth_interceptor(db):
    """
    Exercise AuthValidatorInterceptor on ``Account.GetAccountInfo`` with cookies, API keys and bad input.

    Each case spins up its own dummy server; the two local helpers below only factor out the repeated
    "call succeeds and returns our user" / "call fails as UNAUTHENTICATED" checks.
    """
    _, super_token = generate_user(is_superuser=True)
    user, token = generate_user()

    with real_admin_session(super_token) as api:
        api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        api_key = session.execute(select(UserSession).where(UserSession.is_api_key == True)).scalar_one().token

    rpc_def = dict(
        rpc=Account().GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[AuthValidatorInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
    )

    def assert_accepted(metadata=None, creds=None):
        channel_creds = creds if creds is not None else grpc.local_channel_credentials()
        with interceptor_dummy_api(**rpc_def, creds=channel_creds) as call_rpc:
            res = call_rpc(empty_pb2.Empty(), metadata=metadata)
            assert res.username == user.username

    def assert_rejected(metadata, details):
        with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
            with pytest.raises(grpc.RpcError) as excinfo:
                call_rpc(empty_pb2.Empty(), metadata=metadata)
            assert excinfo.value.code() == grpc.StatusCode.UNAUTHENTICATED
            assert excinfo.value.details() == details

    # no creds, no go for secure APIs
    assert_rejected(None, "Unauthorized")

    # can auth with cookie
    assert_accepted(metadata=(("cookie", f"couchers-sesh={token}"),))

    # can't auth with wrong cookie
    assert_rejected((("cookie", f"couchers-sesh={random_hex(32)}"),), "Unauthorized")

    # can auth with api key
    assert_accepted(metadata=(("authorization", f"Bearer {api_key}"),))

    # can't auth with wrong api key
    assert_rejected((("authorization", f"Bearer {random_hex(32)}"),), "Unauthorized")

    # can auth with grpc helper (they do the same as above)
    assert_accepted(
        creds=grpc.composite_channel_credentials(
            grpc.local_channel_credentials(), grpc.access_token_call_credentials(api_key)
        )
    )

    # can't auth with both
    assert_rejected(
        (
            ("cookie", f"couchers-sesh={token}"),
            ("authorization", f"Bearer {api_key}"),
        ),
        'Both "cookie" and "authorization" in request',
    )

    # malformed bearer
    assert_rejected((("authorization", f"bearer {api_key}"),), "Unauthorized")
def test_tracing_interceptor_auth_cookies(db):
    """A cookie-authenticated call is traced with the caller's user id and ``is_api_key`` falsy."""
    user, token = generate_user()

    rpc_def = dict(
        rpc=Account().GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[TracingInterceptor(), AuthValidatorInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
    )

    # with cookies
    session_cookie = (("cookie", f"couchers-sesh={token}"),)
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        account_info = call_rpc(empty_pb2.Empty(), metadata=session_cookie)
        assert account_info.username == user.username

    with session_scope() as session:
        api_call = session.execute(select(APICall)).scalar_one()
        assert api_call.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not api_call.status_code
        assert api_call.user_id == user.id
        assert not api_call.is_api_key
        assert len(api_call.request) == 0
        assert not api_call.traceback
def test_tracing_interceptor_auth_api_key(db):
    """A call authenticated with an API key is traced with the caller's user id and ``is_api_key`` truthy."""
    _, super_token = generate_user(is_superuser=True)
    user, _ = generate_user()

    with real_admin_session(super_token) as api:
        api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        api_key = session.execute(select(UserSession).where(UserSession.is_api_key == True)).scalar_one().token

    rpc_def = dict(
        rpc=Account().GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[TracingInterceptor(), AuthValidatorInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
    )

    # with api key
    bearer_header = (("authorization", f"Bearer {api_key}"),)
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        account_info = call_rpc(empty_pb2.Empty(), metadata=bearer_header)
        assert account_info.username == user.username

    with session_scope() as session:
        api_call = session.execute(select(APICall)).scalar_one()
        assert api_call.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not api_call.status_code
        assert api_call.user_id == user.id
        assert api_call.is_api_key
        assert len(api_call.request) == 0
        assert not api_call.traceback
def test_auth_levels(db):
    """
    Check which session kinds AuthValidatorInterceptor lets through for RPCs at each auth level.

    Four tokens (anonymous, jailed, normal, superuser) are tried against one RPC each from the open,
    jailed, secure and admin services. Combinations expected to fail are checked for their exact status
    code and message. Afterwards a non-existent RPC and an RPC without a service level are checked too.
    """

    def TestRpc(request, context):
        return empty_pb2.Empty()

    def gen_args(service, method):
        return {
            "rpc": TestRpc,
            "service_name": service,
            "method_name": method,
            "interceptors": [AuthValidatorInterceptor()],
            "request_type": empty_pb2.Empty,
            "response_type": empty_pb2.Empty,
        }

    # superuser
    _, super_token = generate_user(is_superuser=True)
    # normal user
    _, normal_token = generate_user()
    # jailed user
    _, jailed_token = generate_user(accepted_tos=0)
    # open user
    open_token = ""

    # pick some rpcs here with the right auth levels
    open_args = gen_args("org.couchers.resources.Resources", "GetTermsOfService")
    jailed_args = gen_args("org.couchers.jail.Jail", "JailInfo")
    secure_args = gen_args("org.couchers.api.account.Account", "GetAccountInfo")
    admin_args = gen_args("org.couchers.admin.Admin", "GetUserDetails")

    # pairs to check
    checks = [
        # name, token, args, works?, code, message
        # (code and message are only read when works? is False, so they are None otherwise)
        # open token only works on open servicers
        ("open x open", open_token, open_args, True, None, None),
        ("open x jailed", open_token, jailed_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x secure", open_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x admin", open_token, admin_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        # jailed works on jailed and open
        ("jailed x open", jailed_token, open_args, True, None, None),
        ("jailed x jailed", jailed_token, jailed_args, True, None, None),
        # NOTE(review): UNAUTHENTICATED paired with "Permission denied" is kept exactly as the original
        # expectation; confirm this pairing is what the interceptor is meant to return
        ("jailed x secure", jailed_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Permission denied"),
        ("jailed x admin", jailed_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # normal works on all but admin
        ("normal x open", normal_token, open_args, True, None, None),
        ("normal x jailed", normal_token, jailed_args, True, None, None),
        ("normal x secure", normal_token, secure_args, True, None, None),
        ("normal x admin", normal_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # superuser works on all
        ("super x open", super_token, open_args, True, None, None),
        ("super x jailed", super_token, jailed_args, True, None, None),
        ("super x secure", super_token, secure_args, True, None, None),
        ("super x admin", super_token, admin_args, True, None, None),
    ]

    for name, token, args, should_work, code, message in checks:
        # printed so a failing combination can be identified from the captured output
        print(f"Testing (token x args) = ({name}), {should_work=}")
        metadata = (("cookie", f"couchers-sesh={token}"),)
        with interceptor_dummy_api(**args) as call_rpc:
            if should_work:
                call_rpc(empty_pb2.Empty(), metadata=metadata)
            else:
                with pytest.raises(grpc.RpcError) as e:
                    call_rpc(empty_pb2.Empty(), metadata=metadata)
                assert e.value.code() == code
                assert e.value.details() == message

    # a non-existent RPC
    nonexistent = gen_args("org.couchers.nonexistent.NA", "GetNothing")

    with interceptor_dummy_api(**nonexistent) as call_rpc:
        # .code()/.details() below only exist on grpc.RpcError, so assert that type directly
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.UNIMPLEMENTED
        assert e.value.details() == "API call does not exist. Please refresh and try again."

    # an RPC without a service level
    invalid_args = gen_args("org.couchers.media.Media", "UploadConfirmation")

    with interceptor_dummy_api(**invalid_args) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.INTERNAL
        assert e.value.details() == "Internal authentication error."