Coverage for src/tests/test_interceptors.py: 100%
280 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-08-28 14:55 +0000
1from concurrent import futures
2from contextlib import contextmanager
4import grpc
5import pytest
6from google.protobuf import empty_pb2
8from couchers import errors
9from couchers.crypto import random_hex
10from couchers.db import session_scope
11from couchers.interceptors import (
12 CouchersMiddlewareInterceptor,
13 ErrorSanitizationInterceptor,
14)
15from couchers.metrics import servicer_duration_histogram
16from couchers.models import APICall, UserSession
17from couchers.servicers.account import Account
18from couchers.servicers.api import API
19from couchers.sql import couchers_select as select
20from proto import account_pb2, admin_pb2, api_pb2, auth_pb2
21from tests.test_fixtures import db, generate_user, real_admin_session, testconfig # noqa
@pytest.fixture(autouse=True)
def _(testconfig):
    """Request the ``testconfig`` fixture for every test in this module.

    Because this fixture is ``autouse``, pytest sets up ``testconfig`` before each
    test here without each test having to name it as a parameter.
    """
@contextmanager
def interceptor_dummy_api(
    rpc,
    interceptors,
    service_name="org.couchers.auth.Auth",
    method_name="SignupFlow",
    request_type=empty_pb2.Empty,
    response_type=empty_pb2.Empty,
    creds=None,
):
    """Serve a single unary-unary ``rpc`` behind ``interceptors`` and yield a client stub for it.

    A gRPC server backed by a one-thread executor is bound to an ephemeral localhost port
    with local server credentials. ``rpc`` is registered by hand under
    ``/{service_name}/{method_name}``, so the interceptors see it as that method.

    Args:
        rpc: handler invoked by gRPC for the registered method.
        interceptors: server interceptors to install on the server.
        service_name: fully-qualified service name to register the handler under.
        method_name: method name within ``service_name``.
        request_type: protobuf message class used to (de)serialize requests.
        response_type: protobuf message class used to (de)serialize responses.
        creds: client channel credentials; ``grpc.local_channel_credentials()`` when falsy.

    Yields:
        A unary-unary callable bound to ``/{service_name}/{method_name}`` on the server.
    """
    with futures.ThreadPoolExecutor(max_workers=1) as pool:
        server = grpc.server(pool, interceptors=interceptors)
        bound_port = server.add_secure_port("localhost:0", grpc.local_server_credentials())

        # register the handler manually under the requested service/method name
        handler = grpc.unary_unary_rpc_method_handler(
            rpc,
            request_deserializer=request_type.FromString,
            response_serializer=response_type.SerializeToString,
        )
        server.add_generic_rpc_handlers(
            (grpc.method_handlers_generic_handler(service_name, {method_name: handler}),)
        )
        server.start()

        try:
            channel_creds = creds or grpc.local_channel_credentials()
            with grpc.secure_channel(f"localhost:{bound_port}", channel_creds) as channel:
                yield channel.unary_unary(
                    f"/{service_name}/{method_name}",
                    request_serializer=request_type.SerializeToString,
                    response_deserializer=response_type.FromString,
                )
        finally:
            # no grace period; block until the server has fully stopped
            server.stop(None).wait()
def _get_histogram_labels_value(method, logged_in, exception, code):
    """Return the call count recorded by the servicer duration histogram for one label set.

    Collects ``servicer_duration_histogram``, picks the ``couchers_servicer_duration_seconds``
    metric and returns the value of its ``couchers_servicer_duration_seconds_count`` sample
    whose labels match all four arguments exactly.

    Args:
        method: value of the ``method`` label, e.g. ``"/org.couchers.auth.Auth/SignupFlow"``.
        logged_in: value of the ``logged_in`` label (tests pass strings such as ``"False"``).
        exception: value of the ``exception`` label (``""`` when none was recorded).
        code: value of the ``code`` label (``""`` when none was recorded).

    Returns:
        The matching count, or 0 when no sample with these labels exists yet.
    """
    wanted = {"method": method, "logged_in": logged_in, "code": code, "exception": exception}
    servicer_histogram = next(
        (m for m in servicer_duration_histogram.collect() if m.name == "couchers_servicer_duration_seconds"),
        None,
    )
    # fail with a readable message rather than a bare IndexError if the metric is absent
    assert servicer_histogram is not None, "couchers_servicer_duration_seconds metric not found"
    return next(
        (
            sample.value
            for sample in servicer_histogram.samples
            if sample.name == "couchers_servicer_duration_seconds_count"
            and all(sample.labels[label] == value for label, value in wanted.items())
        ),
        0,
    )
def test_logging_interceptor_ok():
    """A handler that returns normally goes through ErrorSanitizationInterceptor without error."""

    def ok_rpc(request, context):
        return empty_pb2.Empty()

    sanitizer = ErrorSanitizationInterceptor()
    with interceptor_dummy_api(ok_rpc, interceptors=[sanitizer]) as call_rpc:
        call_rpc(empty_pb2.Empty())
def test_logging_interceptor_all_ignored():
    """Explicit aborts pass through ErrorSanitizationInterceptor untouched.

    For each non-OK status code the handler aborts with that code and a random message.
    The client must receive exactly that code and message.
    """
    # StatusCode.OK is deliberately absent: a handler cannot abort with OK.
    pass_through_status_codes = (
        grpc.StatusCode.CANCELLED,
        grpc.StatusCode.UNKNOWN,
        grpc.StatusCode.INVALID_ARGUMENT,
        grpc.StatusCode.DEADLINE_EXCEEDED,
        grpc.StatusCode.NOT_FOUND,
        grpc.StatusCode.ALREADY_EXISTS,
        grpc.StatusCode.PERMISSION_DENIED,
        grpc.StatusCode.UNAUTHENTICATED,
        grpc.StatusCode.RESOURCE_EXHAUSTED,
        grpc.StatusCode.FAILED_PRECONDITION,
        grpc.StatusCode.ABORTED,
        grpc.StatusCode.OUT_OF_RANGE,
        grpc.StatusCode.UNIMPLEMENTED,
        grpc.StatusCode.INTERNAL,
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.DATA_LOSS,
    )

    def make_aborting_rpc(abort_code, abort_details):
        # bind code/details eagerly so each handler aborts with its own values
        def aborting_rpc(request, context):
            context.abort(abort_code, abort_details)

        return aborting_rpc

    for status_code in pass_through_status_codes:
        message = random_hex()
        handler = make_aborting_rpc(status_code, message)
        with interceptor_dummy_api(handler, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
            with pytest.raises(grpc.RpcError) as e:
                call_rpc(empty_pb2.Empty())
            assert e.value.code() == status_code
            assert e.value.details() == message
def test_logging_interceptor_assertion():
    """A handler's AssertionError is reported to the client as INTERNAL with errors.UNKNOWN_ERROR."""

    def failing_rpc(request, context):
        raise AssertionError()

    with interceptor_dummy_api(failing_rpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert (e.value.code(), e.value.details()) == (grpc.StatusCode.INTERNAL, errors.UNKNOWN_ERROR)
def test_logging_interceptor_div0():
    """A ZeroDivisionError inside a handler is reported as INTERNAL with errors.UNKNOWN_ERROR."""

    def dividing_rpc(request, context):
        return 1 / 0

    with interceptor_dummy_api(dividing_rpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert (e.value.code(), e.value.details()) == (grpc.StatusCode.INTERNAL, errors.UNKNOWN_ERROR)
def test_logging_interceptor_raise():
    """A bare Exception raised by a handler is reported as INTERNAL with errors.UNKNOWN_ERROR."""

    def raising_rpc(request, context):
        raise Exception()

    with interceptor_dummy_api(raising_rpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert (e.value.code(), e.value.details()) == (grpc.StatusCode.INTERNAL, errors.UNKNOWN_ERROR)
def test_logging_interceptor_raise_custom():
    """A custom exception's message is hidden; the client sees INTERNAL with errors.UNKNOWN_ERROR."""

    class _TestingException(Exception):
        """Test-local exception type, raised to check it is sanitized like any other."""

    def raising_rpc(request, context):
        raise _TestingException("This is a custom exception")

    with interceptor_dummy_api(raising_rpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert (e.value.code(), e.value.details()) == (grpc.StatusCode.INTERNAL, errors.UNKNOWN_ERROR)
def test_tracing_interceptor_ok_open(db):
    """A successful call made without credentials is traced and counted.

    The single APICall row records the method with no status code, no user, empty
    request/response payloads and no traceback. The histogram count for
    (method, logged_in="False", no exception, no code) goes up by exactly one.
    """
    method = "/org.couchers.auth.Auth/SignupFlow"
    count_before = _get_histogram_labels_value(method, "False", "", "")

    def ok_rpc(request, context, session):
        return empty_pb2.Empty()

    with interceptor_dummy_api(ok_rpc, interceptors=[CouchersMiddlewareInterceptor()]) as call_rpc:
        call_rpc(empty_pb2.Empty())

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert not trace.status_code
        assert not trace.user_id
        assert len(trace.request) == 0
        assert len(trace.response) == 0
        assert not trace.traceback

    assert _get_histogram_labels_value(method, "False", "", "") == count_before + 1
def test_tracing_interceptor_sensitive(db):
    """Password fields are stripped from both the traced request and the traced response.

    The client sends a SignupFlowReq whose account carries a password. The handler returns
    an AuthReq with a password. The stored APICall payloads keep the non-secret fields and
    lose both passwords. The call is also counted once in the duration histogram.
    """
    method = "/org.couchers.auth.Auth/SignupFlow"
    count_before = _get_histogram_labels_value(method, "False", "", "")

    def leaky_rpc(request, context, session):
        return auth_pb2.AuthReq(user="this is not secret", password="this is secret")

    signup_request = auth_pb2.SignupFlowReq(
        account=auth_pb2.SignupAccount(password="should be removed", username="not removed")
    )

    with interceptor_dummy_api(
        leaky_rpc,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=auth_pb2.SignupFlowReq,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        call_rpc(signup_request)

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert not trace.status_code
        assert not trace.user_id
        assert not trace.traceback

        logged_request = auth_pb2.SignupFlowReq.FromString(trace.request)
        assert not logged_request.account.password
        assert logged_request.account.username == "not removed"

        logged_response = auth_pb2.AuthReq.FromString(trace.response)
        assert logged_response.user == "this is not secret"
        assert not logged_response.password

    assert _get_histogram_labels_value(method, "False", "", "") == count_before + 1
def test_tracing_interceptor_sensitive_ping(db):
    """A real authenticated handler (API.GetUser) runs through the middleware and is traced.

    The user calls GetUser on themselves with a session cookie. The response must carry
    the requested username. The APICall row for the method must record the caller's id,
    no status code and no traceback.
    """
    user, token = generate_user()
    method = "/org.couchers.api.core.API/GetUser"

    with interceptor_dummy_api(
        API().GetUser,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=api_pb2.GetUserReq,
        response_type=api_pb2.User,
        service_name="org.couchers.api.core.API",
        method_name="GetUser",
    ) as call_rpc:
        res = call_rpc(api_pb2.GetUserReq(user=user.username), metadata=(("cookie", f"couchers-sesh={token}"),))

    # the handler must actually have run and returned the requested user
    assert res.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall).where(APICall.method == method)).scalar_one()
        assert not trace.status_code
        assert trace.user_id == user.id
        assert not trace.traceback
def test_tracing_interceptor_exception(db):
    """An unhandled handler exception is traced with its traceback and counted under "Exception".

    The client receives a gRPC error whose text contains the exception message. The
    APICall row stores:
    - the password-stripped request,
    - no status code, no user and no response,
    - a traceback mentioning the message.
    The histogram count for exception="Exception" goes up by exactly one.
    """
    method = "/org.couchers.auth.Auth/SignupFlow"
    count_before = _get_histogram_labels_value(method, "False", "Exception", "")

    def TestRpc(request, context, session):
        raise Exception("Some error message")

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        # expect the client-side RpcError specifically; a bare Exception would also accept
        # unrelated failures (e.g. a broken harness) whose text happens to match
        with pytest.raises(grpc.RpcError, match="Some error message"):
            call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert not trace.status_code
        assert not trace.user_id
        assert "Some error message" in trace.traceback
        req = auth_pb2.SignupAccount.FromString(trace.request)
        assert not req.password
        assert req.username == "not removed"
        assert not trace.response

    assert _get_histogram_labels_value(method, "False", "Exception", "") == count_before + 1
def test_tracing_interceptor_abort(db):
    """A handler abort reaches the client unchanged and is traced with its status code.

    The client must receive FAILED_PRECONDITION with the abort message. The APICall row
    stores:
    - the status code and a traceback containing the message,
    - the password-stripped request,
    - no response.
    The histogram count for (exception="Exception", code="FAILED_PRECONDITION") goes up by one.
    """
    method = "/org.couchers.auth.Auth/SignupFlow"
    count_before = _get_histogram_labels_value(method, "False", "Exception", "FAILED_PRECONDITION")

    def TestRpc(request, context, session):
        context.abort(grpc.StatusCode.FAILED_PRECONDITION, "now a grpc abort")

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        with pytest.raises(grpc.RpcError, match="now a grpc abort") as e:
            call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))
        # the abort's code and details must be passed through to the client as-is
        assert e.value.code() == grpc.StatusCode.FAILED_PRECONDITION
        assert e.value.details() == "now a grpc abort"

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert trace.status_code == "FAILED_PRECONDITION"
        assert not trace.user_id
        assert "now a grpc abort" in trace.traceback
        req = auth_pb2.SignupAccount.FromString(trace.request)
        assert not req.password
        assert req.username == "not removed"
        assert not trace.response

    assert (
        _get_histogram_labels_value(method, "False", "Exception", "FAILED_PRECONDITION")
        == count_before + 1
    )
def test_auth_interceptor(db):
    """CouchersMiddlewareInterceptor authenticates a secure method by session cookie or API key.

    Scenarios, each against a freshly started server:
    - no credentials;
    - a valid or an unknown session cookie;
    - a valid or an unknown bearer API key;
    - grpc's access-token call credentials;
    - sending a cookie and an API key together;
    - a lowercase "bearer" scheme.
    """
    _, super_token = generate_user(is_superuser=True)
    user, token = generate_user()

    # an admin mints an API key for the normal user; read it back from the sessions table
    with real_admin_session(super_token) as api:
        api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        api_key = session.execute(select(UserSession).where(UserSession.is_api_key == True)).scalar_one().token

    rpc_def = {
        "rpc": Account().GetAccountInfo,
        "service_name": "org.couchers.api.account.Account",
        "method_name": "GetAccountInfo",
        "interceptors": [CouchersMiddlewareInterceptor()],
        "request_type": empty_pb2.Empty,
        "response_type": account_pb2.GetAccountInfoRes,
    }

    def call_account_info(metadata=None, creds=None):
        # every scenario runs against its own freshly started server
        with interceptor_dummy_api(**rpc_def, creds=creds or grpc.local_channel_credentials()) as call_rpc:
            return call_rpc(empty_pb2.Empty(), metadata=metadata)

    def assert_unauthenticated(metadata, details="Unauthorized"):
        with pytest.raises(grpc.RpcError) as e:
            call_account_info(metadata)
        assert e.value.code() == grpc.StatusCode.UNAUTHENTICATED
        assert e.value.details() == details

    # no credentials: rejected by a secure API
    assert_unauthenticated(None)

    # a valid session cookie authenticates
    assert call_account_info((("cookie", f"couchers-sesh={token}"),)).username == user.username

    # an unknown session cookie is rejected
    assert_unauthenticated((("cookie", f"couchers-sesh={random_hex(32)}"),))

    # a valid API key sent as a bearer token authenticates
    assert call_account_info((("authorization", f"Bearer {api_key}"),)).username == user.username

    # an unknown API key is rejected
    assert_unauthenticated((("authorization", f"Bearer {random_hex(32)}"),))

    # grpc's access-token helper does the same as sending the bearer header by hand
    comp_creds = grpc.composite_channel_credentials(
        grpc.local_channel_credentials(), grpc.access_token_call_credentials(api_key)
    )
    assert call_account_info(creds=comp_creds).username == user.username

    # a cookie and an API key together are rejected
    assert_unauthenticated(
        (
            ("cookie", f"couchers-sesh={token}"),
            ("authorization", f"Bearer {api_key}"),
        ),
        details='Both "cookie" and "authorization" in request',
    )

    # a lowercase "bearer" scheme is treated as malformed and rejected
    assert_unauthenticated((("authorization", f"bearer {api_key}"),))
def test_tracing_interceptor_auth_cookies(db):
    """A cookie-authenticated call is traced with the caller's user id and is_api_key unset."""
    user, token = generate_user()
    method = "/org.couchers.api.account.Account/GetAccountInfo"

    with interceptor_dummy_api(
        Account().GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
        creds=grpc.local_channel_credentials(),
    ) as call_rpc:
        account_info = call_rpc(empty_pb2.Empty(), metadata=(("cookie", f"couchers-sesh={token}"),))
    assert account_info.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert not trace.status_code
        assert trace.user_id == user.id
        assert not trace.is_api_key
        assert len(trace.request) == 0
        assert not trace.traceback
def test_tracing_interceptor_auth_api_key(db):
    """An API-key-authenticated call is traced with the key owner's user id and is_api_key set."""
    _, super_token = generate_user(is_superuser=True)
    user, _ = generate_user()

    # an admin mints an API key for the user; read it back from the sessions table
    with real_admin_session(super_token) as api:
        api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        api_key = session.execute(select(UserSession).where(UserSession.is_api_key == True)).scalar_one().token

    method = "/org.couchers.api.account.Account/GetAccountInfo"
    with interceptor_dummy_api(
        Account().GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
        creds=grpc.local_channel_credentials(),
    ) as call_rpc:
        account_info = call_rpc(empty_pb2.Empty(), metadata=(("authorization", f"Bearer {api_key}"),))
    assert account_info.username == user.username

    with session_scope() as session:
        # filtered by method; presumably other APICall rows (e.g. from the admin setup) may exist
        trace = session.execute(select(APICall).where(APICall.method == method)).scalar_one()
        assert trace.method == method
        assert not trace.status_code
        assert trace.user_id == user.id
        assert trace.is_api_key
        assert len(trace.request) == 0
        assert not trace.traceback
def test_auth_levels(db):
    """Check which session types the open, jailed, secure and admin auth levels admit.

    A dummy handler is mounted under real method names chosen for their auth level. Every
    (token type x method) pair is called with a session cookie, and must either succeed or
    fail with the listed status code and details. Also checks that:
    - an unknown method fails with UNIMPLEMENTED;
    - a known method with no auth level configured fails with INTERNAL.
    """

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    def gen_args(service, method):
        # keyword arguments for interceptor_dummy_api mounting TestRpc at service/method
        return {
            "rpc": TestRpc,
            "service_name": service,
            "method_name": method,
            "interceptors": [CouchersMiddlewareInterceptor()],
            "request_type": empty_pb2.Empty,
            "response_type": empty_pb2.Empty,
        }

    # superuser
    _, super_token = generate_user(is_superuser=True)
    # normal user
    _, normal_token = generate_user()
    # jailed user (has not accepted the terms of service)
    _, jailed_token = generate_user(accepted_tos=0)
    # open user: empty session cookie
    open_token = ""

    # pick some rpcs here with the right auth levels
    open_args = gen_args("org.couchers.resources.Resources", "GetTermsOfService")
    jailed_args = gen_args("org.couchers.jail.Jail", "JailInfo")
    secure_args = gen_args("org.couchers.api.account.Account", "GetAccountInfo")
    admin_args = gen_args("org.couchers.admin.Admin", "GetUserDetails")

    # pairs to check
    checks = [
        # columns: name, token, args, works?, expected code, expected details
        # (code and details are only checked when works? is False)
        # open token only works on open servicers
        ("open x open", open_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x jailed", open_token, jailed_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x secure", open_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x admin", open_token, admin_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        # jailed works on jailed and open
        ("jailed x open", jailed_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("jailed x jailed", jailed_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("jailed x secure", jailed_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Permission denied"),
        ("jailed x admin", jailed_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # normal works on all but admin
        ("normal x open", normal_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x jailed", normal_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x secure", normal_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x admin", normal_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # superuser works on all
        ("super x open", super_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x jailed", super_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x secure", super_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x admin", super_token, admin_args, True, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
    ]

    for name, token, args, should_work, code, message in checks:
        print(f"Testing (token x args) = ({name}), {should_work=}")
        metadata = (("cookie", f"couchers-sesh={token}"),)
        with interceptor_dummy_api(**args) as call_rpc:
            if should_work:
                call_rpc(empty_pb2.Empty(), metadata=metadata)
            else:
                with pytest.raises(grpc.RpcError) as e:
                    call_rpc(empty_pb2.Empty(), metadata=metadata)
                assert e.value.code() == code
                assert e.value.details() == message

    # a non-existent RPC
    nonexistent = gen_args("org.couchers.nonexistent.NA", "GetNothing")

    with interceptor_dummy_api(**nonexistent) as call_rpc:
        # RpcError specifically: .code()/.details() below only exist on it
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.UNIMPLEMENTED
        assert e.value.details() == "API call does not exist. Please refresh and try again."

    # an RPC without a service level
    invalid_args = gen_args("org.couchers.media.Media", "UploadConfirmation")

    with interceptor_dummy_api(**invalid_args) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.INTERNAL
        assert e.value.details() == "Internal authentication error."