Coverage for src/tests/test_interceptors.py: 100%
281 statements
coverage.py v7.11.0, created at 2025-12-06 23:17 +0000
1from concurrent import futures
2from contextlib import contextmanager
4import grpc
5import pytest
6from google.protobuf import empty_pb2
8from couchers.crypto import random_hex
9from couchers.db import session_scope
10from couchers.interceptors import (
11 CouchersMiddlewareInterceptor,
12 ErrorSanitizationInterceptor,
13)
14from couchers.metrics import servicer_duration_histogram
15from couchers.models import APICall, UserSession
16from couchers.proto import account_pb2, admin_pb2, api_pb2, auth_pb2
17from couchers.servicers.account import Account
18from couchers.servicers.api import API
19from couchers.sql import couchers_select as select
20from tests.test_fixtures import db, generate_user, real_admin_session, testconfig # noqa
@pytest.fixture(autouse=True)
def _(testconfig):
    """Activate the ``testconfig`` fixture for every test in this module.

    The body is intentionally empty: requesting ``testconfig`` as a parameter
    is all that is needed for pytest to set it up around each test.
    """
@contextmanager
def interceptor_dummy_api(
    rpc,
    interceptors,
    service_name="org.couchers.auth.Auth",
    method_name="SignupFlow",
    request_type=empty_pb2.Empty,
    response_type=empty_pb2.Empty,
    creds=None,
):
    """Serve ``rpc`` on a throwaway local gRPC server and yield a client callable for it.

    A single-worker server is started on ``localhost:0`` with the given server
    ``interceptors``; ``rpc`` is registered as a unary-unary handler for
    ``/{service_name}/{method_name}``, deserialising requests with
    ``request_type`` and serialising responses with ``response_type``.

    The yielded callable issues that RPC over a secure channel built from
    ``creds`` (falling back to ``grpc.local_channel_credentials()`` when
    ``creds`` is falsy). On exit the channel is closed and the server stopped.
    """
    unary_handler = grpc.unary_unary_rpc_method_handler(
        rpc,
        request_deserializer=request_type.FromString,
        response_serializer=response_type.SerializeToString,
    )

    with futures.ThreadPoolExecutor(1) as worker_pool:
        server = grpc.server(worker_pool, interceptors=interceptors)
        bound_port = server.add_secure_port("localhost:0", grpc.local_server_credentials())
        # register the handler manually through a generic handler for this service
        server.add_generic_rpc_handlers(
            (grpc.method_handlers_generic_handler(service_name, {method_name: unary_handler}),)
        )
        server.start()

        try:
            channel_creds = creds or grpc.local_channel_credentials()
            with grpc.secure_channel(f"localhost:{bound_port}", channel_creds) as channel:
                yield channel.unary_unary(
                    f"/{service_name}/{method_name}",
                    request_serializer=request_type.SerializeToString,
                    response_deserializer=response_type.FromString,
                )
        finally:
            server.stop(None).wait()
def _get_histogram_labels_value(method, logged_in, exception, code):
    """Return how many observations the servicer-duration histogram holds for one label set.

    Finds the ``couchers_servicer_duration_seconds_count`` sample whose
    ``method``, ``logged_in``, ``code`` and ``exception`` labels equal the
    arguments and returns its value, or 0 when no such sample exists yet.
    """
    servicer_histogram = [
        metric
        for metric in servicer_duration_histogram.collect()
        if metric.name == "couchers_servicer_duration_seconds"
    ][0]

    # checked in this order, stopping at the first mismatch
    wanted_labels = (
        ("method", method),
        ("logged_in", logged_in),
        ("code", code),
        ("exception", exception),
    )
    return next(
        (
            sample.value
            for sample in servicer_histogram.samples
            if sample.name == "couchers_servicer_duration_seconds_count"
            and all(sample.labels[key] == value for key, value in wanted_labels)
        ),
        0,
    )
def test_logging_interceptor_ok():
    """A handler that returns normally goes through ErrorSanitizationInterceptor without error."""

    def handler(request, context):
        return empty_pb2.Empty()

    sanitizer = ErrorSanitizationInterceptor()
    with interceptor_dummy_api(handler, interceptors=[sanitizer]) as call_rpc:
        call_rpc(empty_pb2.Empty())
def test_logging_interceptor_all_ignored():
    """Deliberate ``context.abort`` errors pass through ErrorSanitizationInterceptor unchanged.

    For each status code below, the client must receive exactly the code and
    details the handler aborted with.
    """
    # status codes the interceptor must not touch; OK is absent because abort() cannot be called with it
    pass_through_status_codes = [
        grpc.StatusCode.CANCELLED,
        grpc.StatusCode.UNKNOWN,
        grpc.StatusCode.INVALID_ARGUMENT,
        grpc.StatusCode.DEADLINE_EXCEEDED,
        grpc.StatusCode.NOT_FOUND,
        grpc.StatusCode.ALREADY_EXISTS,
        grpc.StatusCode.PERMISSION_DENIED,
        grpc.StatusCode.UNAUTHENTICATED,
        grpc.StatusCode.RESOURCE_EXHAUSTED,
        grpc.StatusCode.FAILED_PRECONDITION,
        grpc.StatusCode.ABORTED,
        grpc.StatusCode.OUT_OF_RANGE,
        grpc.StatusCode.UNIMPLEMENTED,
        grpc.StatusCode.INTERNAL,
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.DATA_LOSS,
    ]

    for expected_code in pass_through_status_codes:
        expected_details = random_hex()

        # bind this iteration's values as defaults so the handler doesn't depend on late binding
        def handler(request, context, abort_code=expected_code, abort_details=expected_details):
            context.abort(abort_code, abort_details)

        with interceptor_dummy_api(handler, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
            with pytest.raises(grpc.RpcError) as excinfo:
                call_rpc(empty_pb2.Empty())
            assert excinfo.value.code() == expected_code
            assert excinfo.value.details() == expected_details
# The generic message ErrorSanitizationInterceptor must substitute for unexpected server errors.
_SANITIZED_ERROR_DETAILS = "An unknown backend error occurred. Please consider filing a bug!"


def _assert_error_sanitized(rpc):
    """Serve ``rpc`` behind ErrorSanitizationInterceptor and check the client sees only a generic error.

    ``rpc`` is expected to raise an unexpected (non-abort) exception. The client
    must receive ``StatusCode.INTERNAL`` with ``_SANITIZED_ERROR_DETAILS``
    rather than anything derived from the original exception.
    """
    with interceptor_dummy_api(rpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.INTERNAL
        assert e.value.details() == _SANITIZED_ERROR_DETAILS


def test_logging_interceptor_assertion():
    """A failing ``assert`` inside a handler is reported as a sanitized INTERNAL error."""

    def TestRpc(request, context):
        raise AssertionError()

    _assert_error_sanitized(TestRpc)


def test_logging_interceptor_div0():
    """A ZeroDivisionError inside a handler is reported as a sanitized INTERNAL error."""

    def TestRpc(request, context):
        1 / 0  # noqa: B018

    _assert_error_sanitized(TestRpc)


def test_logging_interceptor_raise():
    """A bare ``Exception`` raised by a handler is reported as a sanitized INTERNAL error."""

    def TestRpc(request, context):
        raise Exception()

    _assert_error_sanitized(TestRpc)


def test_logging_interceptor_raise_custom():
    """A custom exception's message is not leaked; the client gets the sanitized INTERNAL error."""

    class _TestingException(Exception):
        pass

    def TestRpc(request, context):
        raise _TestingException("This is a custom exception")

    _assert_error_sanitized(TestRpc)
def test_tracing_interceptor_ok_open(db):
    """A successful unauthenticated call is stored as an APICall and increments the duration histogram."""
    method = "/org.couchers.auth.Auth/SignupFlow"
    histogram_labels = (method, "False", "", "")
    count_before = _get_histogram_labels_value(*histogram_labels)

    def handler(request, context, session):
        return empty_pb2.Empty()

    with interceptor_dummy_api(handler, interceptors=[CouchersMiddlewareInterceptor()]) as call_rpc:
        call_rpc(empty_pb2.Empty())

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert not trace.status_code
        assert not trace.user_id
        assert len(trace.request) == 0
        assert len(trace.response) == 0
        assert not trace.traceback

    assert _get_histogram_labels_value(*histogram_labels) == count_before + 1
def test_tracing_interceptor_sensitive(db):
    """Password fields are blanked in both the request and the response stored on the APICall."""
    method = "/org.couchers.auth.Auth/SignupFlow"
    histogram_labels = (method, "False", "", "")
    count_before = _get_histogram_labels_value(*histogram_labels)

    def handler(request, context, session):
        return auth_pb2.AuthReq(user="this is not secret", password="this is secret")

    outgoing = auth_pb2.SignupFlowReq(
        account=auth_pb2.SignupAccount(password="should be removed", username="not removed")
    )
    with interceptor_dummy_api(
        handler,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=auth_pb2.SignupFlowReq,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        call_rpc(outgoing)

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert not trace.status_code
        assert not trace.user_id
        assert not trace.traceback

        stored_request = auth_pb2.SignupFlowReq.FromString(trace.request)
        assert not stored_request.account.password
        assert stored_request.account.username == "not removed"

        stored_response = auth_pb2.AuthReq.FromString(trace.response)
        assert stored_response.user == "this is not secret"
        assert not stored_response.password

    assert _get_histogram_labels_value(*histogram_labels) == count_before + 1
def test_tracing_interceptor_sensitive_ping(db):
    """A real servicer (API.GetUser) works behind CouchersMiddlewareInterceptor and is traced.

    Authenticates with the session cookie and checks that the response is
    returned and that an error-free APICall attributed to the user is recorded.
    """
    user, token = generate_user()

    with interceptor_dummy_api(
        API().GetUser,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=api_pb2.GetUserReq,
        response_type=api_pb2.User,
        service_name="org.couchers.api.core.API",
        method_name="GetUser",
    ) as call_rpc:
        res = call_rpc(api_pb2.GetUserReq(user=user.username), metadata=(("cookie", f"couchers-sesh={token}"),))

    # previously this test made the call but asserted nothing
    assert res.username == user.username

    with session_scope() as session:
        trace = session.execute(
            select(APICall).where(APICall.method == "/org.couchers.api.core.API/GetUser")
        ).scalar_one()
        assert not trace.status_code
        assert trace.user_id == user.id
        assert not trace.traceback
def test_tracing_interceptor_exception(db):
    """An unexpected handler exception is traced with its traceback and counted under the "Exception" label."""
    method = "/org.couchers.auth.Auth/SignupFlow"
    histogram_labels = (method, "False", "Exception", "")
    count_before = _get_histogram_labels_value(*histogram_labels)

    def handler(request, context, session):
        raise Exception("Some error message")

    outgoing = auth_pb2.SignupAccount(password="should be removed", username="not removed")
    with interceptor_dummy_api(
        handler,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        with pytest.raises(Exception, match="Some error message"):
            call_rpc(outgoing)

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert not trace.status_code
        assert not trace.user_id
        assert "Some error message" in trace.traceback

        stored_request = auth_pb2.SignupAccount.FromString(trace.request)
        assert not stored_request.password
        assert stored_request.username == "not removed"
        assert not trace.response

    assert _get_histogram_labels_value(*histogram_labels) == count_before + 1
def test_tracing_interceptor_abort(db):
    """A ``context.abort`` is traced with its status code and counted under that code in the histogram."""
    method = "/org.couchers.auth.Auth/SignupFlow"
    histogram_labels = (method, "False", "Exception", "FAILED_PRECONDITION")
    count_before = _get_histogram_labels_value(*histogram_labels)

    def handler(request, context, session):
        context.abort(grpc.StatusCode.FAILED_PRECONDITION, "now a grpc abort")

    outgoing = auth_pb2.SignupAccount(password="should be removed", username="not removed")
    with interceptor_dummy_api(
        handler,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        with pytest.raises(Exception, match="now a grpc abort"):
            call_rpc(outgoing)

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert trace.status_code == "FAILED_PRECONDITION"
        assert not trace.user_id
        assert "now a grpc abort" in trace.traceback

        stored_request = auth_pb2.SignupAccount.FromString(trace.request)
        assert not stored_request.password
        assert stored_request.username == "not removed"
        assert not trace.response

    assert _get_histogram_labels_value(*histogram_labels) == count_before + 1
def test_auth_interceptor(db):
    """CouchersMiddlewareInterceptor authenticates secure RPCs by session cookie or Bearer API key.

    Covers: no credentials, valid/invalid cookie, valid/invalid API key, the
    gRPC access-token call-credentials helper, sending both a cookie and an
    API key, and a lowercase ``bearer`` scheme.
    """
    _, super_token = generate_user(is_superuser=True)
    user, token = generate_user()

    with real_admin_session(super_token) as api:
        api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        api_key = session.execute(select(UserSession).where(UserSession.is_api_key == True)).scalar_one().token

    account = Account()
    rpc_def = {
        "rpc": account.GetAccountInfo,
        "service_name": "org.couchers.api.account.Account",
        "method_name": "GetAccountInfo",
        "interceptors": [CouchersMiddlewareInterceptor()],
        "request_type": empty_pb2.Empty,
        "response_type": account_pb2.GetAccountInfoRes,
    }

    def call_account_info(metadata=None, creds=None):
        # every scenario gets its own fresh server, as in separate `with` blocks
        with interceptor_dummy_api(**rpc_def, creds=creds or grpc.local_channel_credentials()) as call_rpc:
            return call_rpc(empty_pb2.Empty(), metadata=metadata)

    def assert_rejected(metadata, expected_details):
        with pytest.raises(grpc.RpcError) as e:
            call_account_info(metadata)
        assert e.value.code() == grpc.StatusCode.UNAUTHENTICATED
        assert e.value.details() == expected_details

    cookie = ("cookie", f"couchers-sesh={token}")
    bearer = ("authorization", f"Bearer {api_key}")

    # no creds, no go for secure APIs
    assert_rejected(None, "Unauthorized")

    # can auth with cookie
    assert call_account_info((cookie,)).username == user.username

    # can't auth with wrong cookie
    assert_rejected((("cookie", f"couchers-sesh={random_hex(32)}"),), "Unauthorized")

    # can auth with api key
    assert call_account_info((bearer,)).username == user.username

    # can't auth with wrong api key
    assert_rejected((("authorization", f"Bearer {random_hex(32)}"),), "Unauthorized")

    # can auth with grpc helper (it sends the same header as above)
    comp_creds = grpc.composite_channel_credentials(
        grpc.local_channel_credentials(), grpc.access_token_call_credentials(api_key)
    )
    assert call_account_info(creds=comp_creds).username == user.username

    # can't auth with both
    assert_rejected((cookie, bearer), 'Both "cookie" and "authorization" in request')

    # malformed bearer (scheme is case-sensitive here)
    assert_rejected((("authorization", f"bearer {api_key}"),), "Unauthorized")
def test_tracing_interceptor_auth_cookies(db):
    """A cookie-authenticated call is traced against the user with ``is_api_key`` unset."""
    user, token = generate_user()
    account = Account()

    with interceptor_dummy_api(
        rpc=account.GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
        creds=grpc.local_channel_credentials(),
    ) as call_rpc:
        info = call_rpc(empty_pb2.Empty(), metadata=(("cookie", f"couchers-sesh={token}"),))
        assert info.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert not trace.is_api_key
        assert len(trace.request) == 0
        assert not trace.traceback
def test_tracing_interceptor_auth_api_key(db):
    """An API-key-authenticated call is traced against the key's owner with ``is_api_key`` set."""
    _, super_token = generate_user(is_superuser=True)
    user, _ = generate_user()

    with real_admin_session(super_token) as api:
        api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        api_key = session.execute(select(UserSession).where(UserSession.is_api_key == True)).scalar_one().token

    account = Account()
    traced_method = "/org.couchers.api.account.Account/GetAccountInfo"

    with interceptor_dummy_api(
        rpc=account.GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
        creds=grpc.local_channel_credentials(),
    ) as call_rpc:
        info = call_rpc(empty_pb2.Empty(), metadata=(("authorization", f"Bearer {api_key}"),))
        assert info.username == user.username

    with session_scope() as session:
        # filter by method: setting up the API key may have produced other traces
        trace = session.execute(select(APICall).where(APICall.method == traced_method)).scalar_one()
        assert trace.method == traced_method
        assert not trace.status_code
        assert trace.user_id == user.id
        assert trace.is_api_key
        assert len(trace.request) == 0
        assert not trace.traceback
def test_auth_levels(db):
    """Each auth level (open, jailed, secure, editor, admin) admits exactly the expected users.

    Also checks that an RPC unknown to the interceptor fails with UNIMPLEMENTED,
    and that an RPC without a service auth level fails with INTERNAL.
    """

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    def gen_args(service, method):
        return {
            "rpc": TestRpc,
            "service_name": service,
            "method_name": method,
            "interceptors": [CouchersMiddlewareInterceptor()],
            "request_type": empty_pb2.Empty,
            "response_type": empty_pb2.Empty,
        }

    # superuser (note: superusers are automatically editors due to DB constraint)
    _, super_token = generate_user(is_superuser=True)
    # editor user
    _, editor_token = generate_user(is_editor=True)
    # normal user
    _, normal_token = generate_user()
    # jailed user (accepted_tos=0)
    _, jailed_token = generate_user(accepted_tos=0)
    # open user: sends an empty session cookie
    open_token = ""

    # pick some rpcs here with the right auth levels
    open_args = gen_args("org.couchers.resources.Resources", "GetTermsOfService")
    jailed_args = gen_args("org.couchers.jail.Jail", "JailInfo")
    secure_args = gen_args("org.couchers.api.account.Account", "GetAccountInfo")
    editor_args = gen_args("org.couchers.editor.Editor", "CreateCommunity")
    admin_args = gen_args("org.couchers.admin.Admin", "GetUserDetails")

    unauthenticated = grpc.StatusCode.UNAUTHENTICATED
    permission_denied = grpc.StatusCode.PERMISSION_DENIED

    # columns: name, token, args, works?, expected code, expected details
    # (code/details are only checked for calls that should fail, so they are None otherwise)
    checks = [
        # open token only works on open servicers
        ("open x open", open_token, open_args, True, None, None),
        ("open x jailed", open_token, jailed_args, False, unauthenticated, "Unauthorized"),
        ("open x secure", open_token, secure_args, False, unauthenticated, "Unauthorized"),
        ("open x editor", open_token, editor_args, False, unauthenticated, "Unauthorized"),
        ("open x admin", open_token, admin_args, False, unauthenticated, "Unauthorized"),
        # jailed works on jailed and open
        ("jailed x open", jailed_token, open_args, True, None, None),
        ("jailed x jailed", jailed_token, jailed_args, True, None, None),
        ("jailed x secure", jailed_token, secure_args, False, unauthenticated, "Permission denied"),
        ("jailed x editor", jailed_token, editor_args, False, permission_denied, "Permission denied"),
        ("jailed x admin", jailed_token, admin_args, False, permission_denied, "Permission denied"),
        # normal works on all but editor and admin
        ("normal x open", normal_token, open_args, True, None, None),
        ("normal x jailed", normal_token, jailed_args, True, None, None),
        ("normal x secure", normal_token, secure_args, True, None, None),
        ("normal x editor", normal_token, editor_args, False, permission_denied, "Permission denied"),
        ("normal x admin", normal_token, admin_args, False, permission_denied, "Permission denied"),
        # editor works on all but admin
        ("editor x open", editor_token, open_args, True, None, None),
        ("editor x jailed", editor_token, jailed_args, True, None, None),
        ("editor x secure", editor_token, secure_args, True, None, None),
        ("editor x editor", editor_token, editor_args, True, None, None),
        ("editor x admin", editor_token, admin_args, False, permission_denied, "Permission denied"),
        # superuser works on all
        ("super x open", super_token, open_args, True, None, None),
        ("super x jailed", super_token, jailed_args, True, None, None),
        ("super x secure", super_token, secure_args, True, None, None),
        ("super x editor", super_token, editor_args, True, None, None),
        ("super x admin", super_token, admin_args, True, None, None),
    ]

    for name, token, args, should_work, code, message in checks:
        print(f"Testing (token x args) = ({name}), {should_work=}")
        metadata = (("cookie", f"couchers-sesh={token}"),)
        with interceptor_dummy_api(**args) as call_rpc:
            if should_work:
                call_rpc(empty_pb2.Empty(), metadata=metadata)
            else:
                with pytest.raises(grpc.RpcError) as e:
                    call_rpc(empty_pb2.Empty(), metadata=metadata)
                assert e.value.code() == code
                assert e.value.details() == message

    # a non-existent RPC
    nonexistent = gen_args("org.couchers.nonexistent.NA", "GetNothing")

    with interceptor_dummy_api(**nonexistent) as call_rpc:
        # narrowed from Exception: only a gRPC error has .code()/.details()
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.UNIMPLEMENTED
        assert e.value.details() == "API call does not exist. Please refresh and try again."

    # an RPC without a service level
    invalid_args = gen_args("org.couchers.media.Media", "UploadConfirmation")

    with interceptor_dummy_api(**invalid_args) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.INTERNAL
        assert e.value.details() == "Internal authentication error."