Coverage for src/tests/test_interceptors.py: 100%
281 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-04-16 15:13 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-04-16 15:13 +0000
1from concurrent import futures
2from contextlib import contextmanager
4import grpc
5import pytest
6from google.protobuf import empty_pb2
8from couchers import errors
9from couchers.crypto import random_hex
10from couchers.db import session_scope
11from couchers.interceptors import (
12 AuthValidatorInterceptor,
13 CookieInterceptor,
14 ErrorSanitizationInterceptor,
15 SessionInterceptor,
16 TracingInterceptor,
17)
18from couchers.metrics import servicer_duration_histogram
19from couchers.models import APICall, UserSession
20from couchers.servicers.account import Account
21from couchers.servicers.api import API
22from couchers.sql import couchers_select as select
23from proto import account_pb2, admin_pb2, api_pb2, auth_pb2
24from tests.test_fixtures import db, generate_user, real_admin_session, testconfig # noqa
@pytest.fixture(autouse=True)
def _(testconfig):
    """Request the ``testconfig`` fixture for every test in this module.

    The body is intentionally empty: listing ``testconfig`` as a parameter of
    an autouse fixture is enough for pytest to set it up for each test.
    """
@contextmanager
def interceptor_dummy_api(
    rpc,
    interceptors,
    service_name="testing.Test",
    method_name="TestRpc",
    request_type=empty_pb2.Empty,
    response_type=empty_pb2.Empty,
    creds=None,
):
    """Serve ``rpc`` behind ``interceptors`` on a throwaway local server and yield a client callable.

    The server listens on an ephemeral ``localhost`` port with local server
    credentials. ``rpc`` is registered as the unary-unary method
    ``/{service_name}/{method_name}``, using ``request_type`` and
    ``response_type`` for (de)serialization. The yielded callable invokes that
    method over a secure channel built from ``creds``, falling back to
    ``grpc.local_channel_credentials()`` when ``creds`` is falsy. The server is
    stopped when the context exits, even if the body raises.
    """
    full_method = f"/{service_name}/{method_name}"
    method_handler = grpc.unary_unary_rpc_method_handler(
        rpc,
        request_deserializer=request_type.FromString,
        response_serializer=response_type.SerializeToString,
    )

    with futures.ThreadPoolExecutor(1) as pool:
        server = grpc.server(pool, interceptors=interceptors)
        bound_port = server.add_secure_port("localhost:0", grpc.local_server_credentials())

        # register the handler by hand instead of going through generated servicer glue
        server.add_generic_rpc_handlers(
            (grpc.method_handlers_generic_handler(service_name, {method_name: method_handler}),)
        )
        server.start()

        try:
            channel_creds = creds or grpc.local_channel_credentials()
            with grpc.secure_channel(f"localhost:{bound_port}", channel_creds) as channel:
                yield channel.unary_unary(
                    full_method,
                    request_serializer=request_type.SerializeToString,
                    response_deserializer=response_type.FromString,
                )
        finally:
            server.stop(None).wait()
def _get_histogram_labels_value(method, logged_in, exception, code):
    """Return the observation count of ``couchers_servicer_duration_seconds`` for one label combination.

    Each argument is compared for equality against the sample label of the
    same name. Callers in this module pass strings (e.g. ``logged_in="False"``,
    ``code=""``). Returns 0 when no ``_count`` sample with exactly these labels
    has been recorded yet.
    """
    wanted_labels = {
        "method": method,
        "logged_in": logged_in,
        "code": code,
        "exception": exception,
    }
    servicer_histogram = [
        metric
        for metric in servicer_duration_histogram.collect()
        if metric.name == "couchers_servicer_duration_seconds"
    ][0]

    for sample in servicer_histogram.samples:
        if sample.name != "couchers_servicer_duration_seconds_count":
            continue
        if all(sample.labels[label] == value for label, value in wanted_labels.items()):
            return sample.value
    return 0
def test_logging_interceptor_ok():
    """A handler that returns normally completes without error through ErrorSanitizationInterceptor."""

    def TestRpc(request, context):
        return empty_pb2.Empty()

    sanitizer = ErrorSanitizationInterceptor()
    with interceptor_dummy_api(TestRpc, interceptors=[sanitizer]) as call_rpc:
        call_rpc(empty_pb2.Empty())
def test_logging_interceptor_all_ignored():
    """Deliberate ``context.abort`` calls reach the client with their status code and details intact.

    ErrorSanitizationInterceptor must leave these errors alone. Every status
    code is checked except OK, which cannot be used to abort a call.
    """
    # grpc.StatusCode.OK is intentionally absent: a call cannot be aborted with OK
    pass_through_status_codes = [
        grpc.StatusCode.CANCELLED,
        grpc.StatusCode.UNKNOWN,
        grpc.StatusCode.INVALID_ARGUMENT,
        grpc.StatusCode.DEADLINE_EXCEEDED,
        grpc.StatusCode.NOT_FOUND,
        grpc.StatusCode.ALREADY_EXISTS,
        grpc.StatusCode.PERMISSION_DENIED,
        grpc.StatusCode.UNAUTHENTICATED,
        grpc.StatusCode.RESOURCE_EXHAUSTED,
        grpc.StatusCode.FAILED_PRECONDITION,
        grpc.StatusCode.ABORTED,
        grpc.StatusCode.OUT_OF_RANGE,
        grpc.StatusCode.UNIMPLEMENTED,
        grpc.StatusCode.INTERNAL,
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.DATA_LOSS,
    ]

    for expected_code in pass_through_status_codes:
        expected_details = random_hex()

        # bind this iteration's values as defaults so the handler never late-binds the loop variables
        def TestRpc(request, context, _code=expected_code, _details=expected_details):
            context.abort(_code, _details)

        with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
            with pytest.raises(grpc.RpcError) as exc_info:
                call_rpc(empty_pb2.Empty())
            rpc_error = exc_info.value
            assert rpc_error.code() == expected_code
            assert rpc_error.details() == expected_details
def test_logging_interceptor_assertion():
    """An AssertionError escaping the handler is surfaced as INTERNAL with the generic unknown-error text."""

    def TestRpc(request, context):
        raise AssertionError()

    sanitizer = ErrorSanitizationInterceptor()
    with interceptor_dummy_api(TestRpc, interceptors=[sanitizer]) as call_rpc:
        with pytest.raises(grpc.RpcError) as exc_info:
            call_rpc(empty_pb2.Empty())
        rpc_error = exc_info.value
        assert rpc_error.code() == grpc.StatusCode.INTERNAL
        assert rpc_error.details() == errors.UNKNOWN_ERROR
def test_logging_interceptor_div0():
    """A ZeroDivisionError in the handler is surfaced as INTERNAL with the generic unknown-error text."""

    def TestRpc(request, context):
        return 1 / 0

    sanitizer = ErrorSanitizationInterceptor()
    with interceptor_dummy_api(TestRpc, interceptors=[sanitizer]) as call_rpc:
        with pytest.raises(grpc.RpcError) as exc_info:
            call_rpc(empty_pb2.Empty())
        rpc_error = exc_info.value
        assert rpc_error.code() == grpc.StatusCode.INTERNAL
        assert rpc_error.details() == errors.UNKNOWN_ERROR
def test_logging_interceptor_raise():
    """A bare ``Exception`` from the handler is surfaced as INTERNAL with the generic unknown-error text."""

    def TestRpc(request, context):
        raise Exception()

    sanitizer = ErrorSanitizationInterceptor()
    with interceptor_dummy_api(TestRpc, interceptors=[sanitizer]) as call_rpc:
        with pytest.raises(grpc.RpcError) as exc_info:
            call_rpc(empty_pb2.Empty())
        rpc_error = exc_info.value
        assert rpc_error.code() == grpc.StatusCode.INTERNAL
        assert rpc_error.details() == errors.UNKNOWN_ERROR
def test_logging_interceptor_raise_custom():
    """A user-defined exception's message is hidden; the client sees INTERNAL with the generic text."""

    class _CustomError(Exception):
        pass

    def TestRpc(request, context):
        raise _CustomError("This is a custom exception")

    sanitizer = ErrorSanitizationInterceptor()
    with interceptor_dummy_api(TestRpc, interceptors=[sanitizer]) as call_rpc:
        with pytest.raises(grpc.RpcError) as exc_info:
            call_rpc(empty_pb2.Empty())
        rpc_error = exc_info.value
        assert rpc_error.code() == grpc.StatusCode.INTERNAL
        assert rpc_error.details() == errors.UNKNOWN_ERROR
def test_tracing_interceptor_ok_open(db):
    """A successful anonymous call is stored as exactly one APICall and counted once in the histogram."""
    traced_method = "/testing.Test/TestRpc"
    count_before = _get_histogram_labels_value(traced_method, "False", "", "")

    def TestRpc(request, context):
        return empty_pb2.Empty()

    with interceptor_dummy_api(TestRpc, interceptors=[TracingInterceptor()]) as call_rpc:
        call_rpc(empty_pb2.Empty())

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == traced_method
        assert not trace.status_code
        assert not trace.user_id
        # an empty protobuf message serializes to zero bytes
        assert len(trace.request) == 0
        assert len(trace.response) == 0
        assert not trace.traceback

    assert _get_histogram_labels_value(traced_method, "False", "", "") == count_before + 1
def test_tracing_interceptor_sensitive(db):
    """Sensitive fields are stripped from both the traced request and the traced response.

    ``SignupFlowReq.account.password`` and ``AuthReq.password`` must come back
    empty from the stored trace. Their non-sensitive siblings (``username`` and
    ``user``) must be kept.
    """
    traced_method = "/testing.Test/TestRpc"
    count_before = _get_histogram_labels_value(traced_method, "False", "", "")

    def TestRpc(request, context):
        return auth_pb2.AuthReq(user="this is not secret", password="this is secret")

    outgoing = auth_pb2.SignupFlowReq(
        account=auth_pb2.SignupAccount(password="should be removed", username="not removed")
    )
    with interceptor_dummy_api(
        TestRpc,
        interceptors=[TracingInterceptor()],
        request_type=auth_pb2.SignupFlowReq,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        call_rpc(outgoing)

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == traced_method
        assert not trace.status_code
        assert not trace.user_id
        assert not trace.traceback

        traced_request = auth_pb2.SignupFlowReq.FromString(trace.request)
        assert not traced_request.account.password
        assert traced_request.account.username == "not removed"

        traced_response = auth_pb2.AuthReq.FromString(trace.response)
        assert traced_response.user == "this is not secret"
        assert not traced_response.password

    assert _get_histogram_labels_value(traced_method, "False", "", "") == count_before + 1
def test_tracing_interceptor_sensitive_ping(db):
    """Tracing a cookie-authenticated call to the real ``API().GetUser`` succeeds and is recorded.

    The servicer runs behind the tracing, auth-validation and session
    interceptors. The test checks that the call returns the requested user and
    that exactly one APICall row is stored, attributed to the caller.
    """
    user, token = generate_user()

    with interceptor_dummy_api(
        API().GetUser,
        interceptors=[TracingInterceptor(), AuthValidatorInterceptor(), SessionInterceptor()],
        request_type=api_pb2.GetUserReq,
        response_type=api_pb2.User,
        service_name="org.couchers.api.core.API",
        method_name="GetUser",
    ) as call_rpc:
        res = call_rpc(api_pb2.GetUserReq(user=user.username), metadata=(("cookie", f"couchers-sesh={token}"),))
    # previously the test made the call but asserted nothing about its outcome
    assert res.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.core.API/GetUser"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert not trace.traceback
def test_tracing_interceptor_exception(db):
    """An unhandled handler exception is traced with its traceback and counted under ``exception="Exception"``.

    The stored request has its ``password`` stripped but keeps ``username``.
    No response and no status code are recorded.
    """
    val = _get_histogram_labels_value("/testing.Test/TestRpc", "False", "Exception", "")

    def TestRpc(request, context):
        raise Exception("Some error message")

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[TracingInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        # the client only ever sees a server-side failure as grpc.RpcError; expecting that
        # type (not any Exception) keeps an unrelated client-side error from passing the test
        with pytest.raises(grpc.RpcError, match="Some error message"):
            call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/testing.Test/TestRpc"
        assert not trace.status_code
        assert not trace.user_id
        assert "Some error message" in trace.traceback
        req = auth_pb2.SignupAccount.FromString(trace.request)
        assert not req.password
        assert req.username == "not removed"
        assert not trace.response

    assert _get_histogram_labels_value("/testing.Test/TestRpc", "False", "Exception", "") == val + 1
def test_tracing_interceptor_abort(db):
    """A ``context.abort`` is traced with its status code and message.

    The histogram counts it under ``exception="Exception"`` and
    ``code="FAILED_PRECONDITION"``. The stored request has its ``password``
    stripped, and no response is recorded.
    """
    val = _get_histogram_labels_value("/testing.Test/TestRpc", "False", "Exception", "FAILED_PRECONDITION")

    def TestRpc(request, context):
        context.abort(grpc.StatusCode.FAILED_PRECONDITION, "now a grpc abort")

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[TracingInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        # expect the concrete client-side error type rather than any Exception
        with pytest.raises(grpc.RpcError, match="now a grpc abort"):
            call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/testing.Test/TestRpc"
        assert trace.status_code == "FAILED_PRECONDITION"
        assert not trace.user_id
        assert "now a grpc abort" in trace.traceback
        req = auth_pb2.SignupAccount.FromString(trace.request)
        assert not req.password
        assert req.username == "not removed"
        assert not trace.response

    assert _get_histogram_labels_value("/testing.Test/TestRpc", "False", "Exception", "FAILED_PRECONDITION") == val + 1
def test_auth_interceptor(db):
    """Exercise each way a caller can (or cannot) authenticate to a secure API.

    Covers:

    - no credentials;
    - a session cookie, valid and bogus;
    - an API key sent as ``authorization: Bearer ...`` metadata: valid, bogus,
      and with a lowercase ``bearer`` prefix;
    - an API key attached through gRPC access-token call credentials;
    - a cookie and an API key sent together.

    Every case starts its own server.
    """
    _, super_token = generate_user(is_superuser=True)
    user, token = generate_user()

    # mint an API key for the ordinary user through the admin API, then read it back from the DB
    with real_admin_session(super_token) as admin_api:
        admin_api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        key_session = session.execute(select(UserSession).where(UserSession.is_api_key == True)).scalar_one()
        api_key = key_session.token

    rpc_def = dict(
        rpc=Account().GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[AuthValidatorInterceptor(), CookieInterceptor(), SessionInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
    )

    def call(metadata=None, creds=None):
        # creds=None makes interceptor_dummy_api fall back to local channel credentials
        with interceptor_dummy_api(**rpc_def, creds=creds) as call_rpc:
            return call_rpc(empty_pb2.Empty(), metadata=metadata)

    def assert_unauthenticated(metadata, details="Unauthorized"):
        with pytest.raises(grpc.RpcError) as exc_info:
            call(metadata)
        assert exc_info.value.code() == grpc.StatusCode.UNAUTHENTICATED
        assert exc_info.value.details() == details

    # no credentials: secure APIs refuse the call
    assert_unauthenticated(None)

    # session cookie: accepted when valid, refused when bogus
    assert call((("cookie", f"couchers-sesh={token}"),)).username == user.username
    assert_unauthenticated((("cookie", f"couchers-sesh={random_hex(32)}"),))

    # API key as bearer metadata: accepted when valid, refused when bogus
    assert call((("authorization", f"Bearer {api_key}"),)).username == user.username
    assert_unauthenticated((("authorization", f"Bearer {random_hex(32)}"),))

    # grpc's access-token call credentials carry the same API key without hand-built metadata
    comp_creds = grpc.composite_channel_credentials(
        grpc.local_channel_credentials(), grpc.access_token_call_credentials(api_key)
    )
    assert call(creds=comp_creds).username == user.username

    # cookie and API key at once are rejected
    assert_unauthenticated(
        (
            ("cookie", f"couchers-sesh={token}"),
            ("authorization", f"Bearer {api_key}"),
        ),
        details='Both "cookie" and "authorization" in request',
    )

    # a lowercase "bearer" prefix is treated as malformed
    assert_unauthenticated((("authorization", f"bearer {api_key}"),))
def test_tracing_interceptor_auth_cookies(db):
    """A cookie-authenticated call is traced with the caller's user id and ``is_api_key`` unset."""
    user, token = generate_user()

    rpc_def = dict(
        rpc=Account().GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[TracingInterceptor(), AuthValidatorInterceptor(), SessionInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
    )
    cookie_metadata = (("cookie", f"couchers-sesh={token}"),)

    # authenticate with the session cookie
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        account_info = call_rpc(empty_pb2.Empty(), metadata=cookie_metadata)
        assert account_info.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert not trace.is_api_key
        assert len(trace.request) == 0
        assert not trace.traceback
def test_tracing_interceptor_auth_api_key(db):
    """An API-key-authenticated call is traced with the caller's user id and ``is_api_key`` set."""
    _, super_token = generate_user(is_superuser=True)
    user, _ = generate_user()

    # mint an API key for the ordinary user through the admin API, then read it back from the DB
    with real_admin_session(super_token) as admin_api:
        admin_api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        key_session = session.execute(select(UserSession).where(UserSession.is_api_key == True)).scalar_one()
        api_key = key_session.token

    rpc_def = dict(
        rpc=Account().GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[TracingInterceptor(), AuthValidatorInterceptor(), SessionInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
    )
    bearer_metadata = (("authorization", f"Bearer {api_key}"),)

    # authenticate with the API key
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        account_info = call_rpc(empty_pb2.Empty(), metadata=bearer_metadata)
        assert account_info.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert trace.is_api_key
        assert len(trace.request) == 0
        assert not trace.traceback
def test_auth_levels(db):
    """Check each caller type against RPCs at each auth level, using AuthValidatorInterceptor alone.

    Callers are anonymous, jailed, normal and superuser. Auth levels are open,
    jailed, secure and admin. Each pair must either succeed or fail with the
    expected status code and details. The test also checks that an unknown RPC
    gives UNIMPLEMENTED and that an RPC with no configured auth level gives
    INTERNAL.
    """

    def TestRpc(request, context):
        return empty_pb2.Empty()

    def gen_args(service, method):
        return {
            "rpc": TestRpc,
            "service_name": service,
            "method_name": method,
            "interceptors": [AuthValidatorInterceptor()],
            "request_type": empty_pb2.Empty,
            "response_type": empty_pb2.Empty,
        }

    # superuser
    _, super_token = generate_user(is_superuser=True)
    # normal user
    _, normal_token = generate_user()
    # jailed user (has not accepted the terms of service)
    _, jailed_token = generate_user(accepted_tos=0)
    # open (anonymous) caller: sends an empty session cookie
    open_token = ""

    # pick some rpcs here with the right auth levels
    open_args = gen_args("org.couchers.resources.Resources", "GetTermsOfService")
    jailed_args = gen_args("org.couchers.jail.Jail", "JailInfo")
    secure_args = gen_args("org.couchers.api.account.Account", "GetAccountInfo")
    admin_args = gen_args("org.couchers.admin.Admin", "GetUserDetails")

    # pairs to check
    checks = [
        # name, token, args, works?, code, message  (code/message are only checked when works? is False)
        # open token only works on open servicers
        ("open x open", open_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x jailed", open_token, jailed_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x secure", open_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x admin", open_token, admin_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        # jailed works on jailed and open
        ("jailed x open", jailed_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("jailed x jailed", jailed_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("jailed x secure", jailed_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Permission denied"),
        ("jailed x admin", jailed_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # normal works on all but admin
        ("normal x open", normal_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x jailed", normal_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x secure", normal_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x admin", normal_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # superuser works on all
        ("super x open", super_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x jailed", super_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x secure", super_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x admin", super_token, admin_args, True, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
    ]

    for name, token, args, should_work, code, message in checks:
        # printed so a failing iteration can be identified in pytest's captured output
        print(f"Testing (token x args) = ({name}), {should_work=}")
        metadata = (("cookie", f"couchers-sesh={token}"),)
        with interceptor_dummy_api(**args) as call_rpc:
            if should_work:
                call_rpc(empty_pb2.Empty(), metadata=metadata)
            else:
                with pytest.raises(grpc.RpcError) as e:
                    call_rpc(empty_pb2.Empty(), metadata=metadata)
                assert e.value.code() == code
                assert e.value.details() == message

    # a non-existent RPC
    nonexistent = gen_args("org.couchers.nonexistent.NA", "GetNothing")

    with interceptor_dummy_api(**nonexistent) as call_rpc:
        # expect the concrete gRPC error type: .code()/.details() below only exist on RpcError
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.UNIMPLEMENTED
        assert e.value.details() == "API call does not exist. Please refresh and try again."

    # an RPC without a service level
    invalid_args = gen_args("org.couchers.media.Media", "UploadConfirmation")

    with interceptor_dummy_api(**invalid_args) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.INTERNAL
        assert e.value.details() == "Internal authentication error."