Coverage for src/tests/test_interceptors.py: 100%
276 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-22 06:42 +0000
1from concurrent import futures
2from contextlib import contextmanager
4import grpc
5import pytest
6from google.protobuf import empty_pb2
8from couchers import errors
9from couchers.crypto import random_hex
10from couchers.db import session_scope
11from couchers.interceptors import (
12 AuthValidatorInterceptor,
13 CookieInterceptor,
14 ErrorSanitizationInterceptor,
15 SessionInterceptor,
16 TracingInterceptor,
17)
18from couchers.metrics import servicer_duration_histogram
19from couchers.models import APICall, UserSession
20from couchers.servicers.account import Account
21from couchers.servicers.api import API
22from couchers.sql import couchers_select as select
23from proto import account_pb2, admin_pb2, api_pb2, auth_pb2
24from tests.test_fixtures import db, generate_user, real_admin_session, testconfig # noqa
@pytest.fixture(autouse=True)
def _(testconfig):
    """Request the ``testconfig`` fixture automatically for every test in this module."""
@contextmanager
def interceptor_dummy_api(
    rpc,
    interceptors,
    service_name="testing.Test",
    method_name="TestRpc",
    request_type=empty_pb2.Empty,
    response_type=empty_pb2.Empty,
    creds=None,
):
    """Serve a single unary-unary ``rpc`` behind ``interceptors`` and yield a client stub for it.

    A gRPC server backed by a one-thread executor is bound to ``localhost:0`` (an OS-chosen
    port) with local server credentials. ``rpc`` is registered as
    ``/{service_name}/{method_name}``, and the caller receives a callable that serializes a
    ``request_type`` message and decodes a ``response_type`` message.

    Args:
        rpc: handler invoked by gRPC as ``rpc(request, context)``.
        interceptors: server interceptors installed on the test server.
        service_name: fully-qualified service name the handler is registered under.
        method_name: method name the handler is registered under.
        request_type: protobuf message class used to (de)serialize requests.
        response_type: protobuf message class used to (de)serialize responses.
        creds: client channel credentials; local channel credentials are used when falsy.

    Yields:
        The unary-unary multi-callable bound to the registered method.
    """
    full_method = f"/{service_name}/{method_name}"
    # the handler is registered manually rather than through generated servicer helpers
    handler = grpc.unary_unary_rpc_method_handler(
        rpc,
        request_deserializer=request_type.FromString,
        response_serializer=response_type.SerializeToString,
    )

    with futures.ThreadPoolExecutor(1) as executor:
        server = grpc.server(executor, interceptors=interceptors)
        port = server.add_secure_port("localhost:0", grpc.local_server_credentials())
        server.add_generic_rpc_handlers(
            (grpc.method_handlers_generic_handler(service_name, {method_name: handler}),)
        )
        server.start()

        try:
            channel_creds = creds or grpc.local_channel_credentials()
            with grpc.secure_channel(f"localhost:{port}", channel_creds) as channel:
                yield channel.unary_unary(
                    full_method,
                    request_serializer=request_type.SerializeToString,
                    response_deserializer=response_type.FromString,
                )
        finally:
            # stop with no grace period and block until shutdown finishes, before the
            # executor is shut down by the enclosing ``with``
            server.stop(None).wait()
def _check_histogram_labels(method, logged_in, exception, code, count):
    """Assert the servicer-duration histogram recorded ``count`` calls with the given labels.

    Finds the ``couchers_servicer_duration_seconds_count`` sample whose ``method``,
    ``logged_in``, ``code`` and ``exception`` labels all match, then checks its value. The
    histogram is cleared afterwards, even when the check fails, so samples recorded by one
    test cannot leak into the next.

    Args:
        method: full RPC path label, e.g. ``"/testing.Test/TestRpc"``.
        logged_in: ``logged_in`` label value as a string, e.g. ``"False"``.
        exception: ``exception`` label value, e.g. ``"Exception"``, or ``""`` if none was raised.
        code: ``code`` label value, e.g. ``"FAILED_PRECONDITION"``, or ``""`` if none was set.
        count: expected number of observations for that label set.
    """
    wanted = {"method": method, "logged_in": logged_in, "code": code, "exception": exception}
    try:
        metrics = servicer_duration_histogram.collect()
        servicer_histogram = next((m for m in metrics if m.name == "couchers_servicer_duration_seconds"), None)
        assert servicer_histogram is not None, "couchers_servicer_duration_seconds metric not collected"
        histogram_count = next(
            (
                s
                for s in servicer_histogram.samples
                if s.name == "couchers_servicer_duration_seconds_count"
                and all(s.labels[key] == value for key, value in wanted.items())
            ),
            None,
        )
        assert histogram_count is not None, f"no histogram count sample with labels {wanted}"
        assert histogram_count.value == count
    finally:
        # always reset, otherwise a failed check would leave samples behind for later tests
        servicer_duration_histogram.clear()
def test_logging_interceptor_ok():
    """A handler that succeeds passes through ErrorSanitizationInterceptor without raising."""

    def _ok_rpc(request, context):
        return empty_pb2.Empty()

    sanitizer = ErrorSanitizationInterceptor()
    with interceptor_dummy_api(_ok_rpc, interceptors=[sanitizer]) as call_rpc:
        call_rpc(empty_pb2.Empty())
def test_logging_interceptor_all_ignored():
    """Codes set through ``context.abort`` reach the client unchanged through ErrorSanitizationInterceptor.

    Both the status code and the details string must arrive exactly as the handler set them.
    """
    # error codes that should not be touched by the interceptor
    # (OK is not listed: a handler cannot abort with OK)
    pass_through_status_codes = [
        grpc.StatusCode.CANCELLED,
        grpc.StatusCode.UNKNOWN,
        grpc.StatusCode.INVALID_ARGUMENT,
        grpc.StatusCode.DEADLINE_EXCEEDED,
        grpc.StatusCode.NOT_FOUND,
        grpc.StatusCode.ALREADY_EXISTS,
        grpc.StatusCode.PERMISSION_DENIED,
        grpc.StatusCode.UNAUTHENTICATED,
        grpc.StatusCode.RESOURCE_EXHAUSTED,
        grpc.StatusCode.FAILED_PRECONDITION,
        grpc.StatusCode.ABORTED,
        grpc.StatusCode.OUT_OF_RANGE,
        grpc.StatusCode.UNIMPLEMENTED,
        grpc.StatusCode.INTERNAL,
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.DATA_LOSS,
    ]

    def make_aborting_rpc(status_code, message):
        # bind the code and message when the handler is built, instead of closing over the
        # loop variables (late binding, flake8-bugbear B023)
        def TestRpc(request, context):
            context.abort(status_code, message)

        return TestRpc

    for status_code in pass_through_status_codes:
        message = random_hex()
        rpc = make_aborting_rpc(status_code, message)

        with interceptor_dummy_api(rpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
            with pytest.raises(grpc.RpcError) as e:
                call_rpc(empty_pb2.Empty())
            assert e.value.code() == status_code
            assert e.value.details() == message
def test_logging_interceptor_assertion():
    """An AssertionError in the handler reaches the client as INTERNAL with the generic error text."""

    def _failing_rpc(request, context):
        raise AssertionError()

    sanitizer = ErrorSanitizationInterceptor()
    with interceptor_dummy_api(_failing_rpc, interceptors=[sanitizer]) as call_rpc:
        with pytest.raises(grpc.RpcError) as exc_info:
            call_rpc(empty_pb2.Empty())
        rpc_error = exc_info.value
        assert rpc_error.code() == grpc.StatusCode.INTERNAL
        assert rpc_error.details() == errors.UNKNOWN_ERROR
def test_logging_interceptor_div0():
    """A ZeroDivisionError in the handler reaches the client as INTERNAL with the generic error text."""

    def _dividing_rpc(request, context):
        # the division raises before anything is returned
        return 1 / 0

    sanitizer = ErrorSanitizationInterceptor()
    with interceptor_dummy_api(_dividing_rpc, interceptors=[sanitizer]) as call_rpc:
        with pytest.raises(grpc.RpcError) as exc_info:
            call_rpc(empty_pb2.Empty())
        rpc_error = exc_info.value
        assert rpc_error.code() == grpc.StatusCode.INTERNAL
        assert rpc_error.details() == errors.UNKNOWN_ERROR
def test_logging_interceptor_raise():
    """A bare Exception in the handler reaches the client as INTERNAL with the generic error text."""

    def _raising_rpc(request, context):
        raise Exception()

    sanitizer = ErrorSanitizationInterceptor()
    with interceptor_dummy_api(_raising_rpc, interceptors=[sanitizer]) as call_rpc:
        with pytest.raises(grpc.RpcError) as exc_info:
            call_rpc(empty_pb2.Empty())
        rpc_error = exc_info.value
        assert rpc_error.code() == grpc.StatusCode.INTERNAL
        assert rpc_error.details() == errors.UNKNOWN_ERROR
def test_logging_interceptor_raise_custom():
    """A handler-defined exception's message is hidden behind INTERNAL and the generic error text."""

    class _TestingException(Exception):
        """Locally defined exception type that the interceptor cannot special-case."""

    def _raising_rpc(request, context):
        raise _TestingException("This is a custom exception")

    sanitizer = ErrorSanitizationInterceptor()
    with interceptor_dummy_api(_raising_rpc, interceptors=[sanitizer]) as call_rpc:
        with pytest.raises(grpc.RpcError) as exc_info:
            call_rpc(empty_pb2.Empty())
        rpc_error = exc_info.value
        assert rpc_error.code() == grpc.StatusCode.INTERNAL
        assert rpc_error.details() == errors.UNKNOWN_ERROR
def test_tracing_interceptor_ok_open(db):
    """A successful call with no credentials is traced with empty payloads and no user or error."""
    method = "/testing.Test/TestRpc"

    def _ok_rpc(request, context):
        return empty_pb2.Empty()

    with interceptor_dummy_api(_ok_rpc, interceptors=[TracingInterceptor()]) as call_rpc:
        call_rpc(empty_pb2.Empty())

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert not trace.status_code
        assert not trace.user_id
        # an Empty message serializes to zero bytes in both directions
        assert len(trace.request) == 0
        assert len(trace.response) == 0
        assert not trace.traceback

    _check_histogram_labels(method, "False", "", "", 1)
def test_tracing_interceptor_sensitive(db):
    """Password fields are stripped from the traced request and response; other fields are kept."""
    method = "/testing.Test/TestRpc"

    def _rpc_with_secret(request, context):
        return auth_pb2.AuthReq(user="this is not secret", password="this is secret")

    signup_request = auth_pb2.SignupFlowReq(
        account=auth_pb2.SignupAccount(password="should be removed", username="not removed")
    )

    with interceptor_dummy_api(
        _rpc_with_secret,
        interceptors=[TracingInterceptor()],
        request_type=auth_pb2.SignupFlowReq,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        call_rpc(signup_request)

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert not trace.status_code
        assert not trace.user_id
        assert not trace.traceback

        traced_req = auth_pb2.SignupFlowReq.FromString(trace.request)
        assert not traced_req.account.password
        assert traced_req.account.username == "not removed"

        traced_res = auth_pb2.AuthReq.FromString(trace.response)
        assert traced_res.user == "this is not secret"
        assert not traced_res.password

    _check_histogram_labels(method, "False", "", "", 1)
def test_tracing_interceptor_sensitive_ping(db):
    """A cookie-authenticated GetUser call through the tracing/auth/session stack succeeds and is traced.

    The trace must be attributed to the calling user, with no status code or traceback.
    """
    user, token = generate_user()

    with interceptor_dummy_api(
        API().GetUser,
        interceptors=[TracingInterceptor(), AuthValidatorInterceptor(), SessionInterceptor()],
        request_type=api_pb2.GetUserReq,
        response_type=api_pb2.User,
        service_name="org.couchers.api.core.API",
        method_name="GetUser",
    ) as call_rpc:
        res = call_rpc(api_pb2.GetUserReq(user=user.username), metadata=(("cookie", f"couchers-sesh={token}"),))

    # without these checks, only a crash could fail this test
    assert res.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.core.API/GetUser"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert not trace.traceback
def test_tracing_interceptor_exception(db):
    """An uncaught handler exception is traced with its traceback, with the password stripped from the request."""

    def TestRpc(request, context):
        raise Exception("Some error message")

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[TracingInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        # the client stub surfaces server failures as grpc.RpcError; match against that, not any Exception
        with pytest.raises(grpc.RpcError, match="Some error message"):
            call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/testing.Test/TestRpc"
        assert not trace.status_code
        assert not trace.user_id
        assert "Some error message" in trace.traceback
        req = auth_pb2.SignupAccount.FromString(trace.request)
        assert not req.password
        assert req.username == "not removed"
        assert not trace.response

    _check_histogram_labels("/testing.Test/TestRpc", "False", "Exception", "", 1)
def test_tracing_interceptor_abort(db):
    """A ``context.abort`` is traced with its status code and message, and the code reaches the client."""

    def TestRpc(request, context):
        context.abort(grpc.StatusCode.FAILED_PRECONDITION, "now a grpc abort")

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[TracingInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        with pytest.raises(grpc.RpcError, match="now a grpc abort") as e:
            call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))
        # tracing must not alter the status the handler aborted with
        assert e.value.code() == grpc.StatusCode.FAILED_PRECONDITION

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/testing.Test/TestRpc"
        assert trace.status_code == "FAILED_PRECONDITION"
        assert not trace.user_id
        assert "now a grpc abort" in trace.traceback
        req = auth_pb2.SignupAccount.FromString(trace.request)
        assert not req.password
        assert req.username == "not removed"
        assert not trace.response

    _check_histogram_labels("/testing.Test/TestRpc", "False", "Exception", "FAILED_PRECONDITION", 1)
def test_auth_interceptor(db):
    """Exercise cookie and API-key authentication through the auth/cookie/session interceptors.

    The cases run in order:
    - no credentials;
    - a valid and an invalid session cookie;
    - a valid and an invalid API key sent as a ``Bearer`` authorization header;
    - the key supplied via gRPC access-token call credentials;
    - cookie and header sent together;
    - a lowercase ``bearer`` scheme.
    """
    _, super_token = generate_user(is_superuser=True)
    user, token = generate_user()

    # a superuser creates an API key for the normal user; read the key back from its session row
    with real_admin_session(super_token) as api:
        api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        api_key = (
            session.execute(select(UserSession).where(UserSession.is_api_key == True))
            .scalar_one()
            .token
        )

    rpc_def = dict(
        rpc=Account().GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[AuthValidatorInterceptor(), CookieInterceptor(), SessionInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
    )

    def fetch_username(metadata=None, creds=None):
        # a fresh server per call; local channel credentials unless others are supplied
        with interceptor_dummy_api(**rpc_def, creds=creds or grpc.local_channel_credentials()) as call_rpc:
            return call_rpc(empty_pb2.Empty(), metadata=metadata).username

    def assert_unauthenticated(metadata, details):
        with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
            with pytest.raises(grpc.RpcError) as exc_info:
                call_rpc(empty_pb2.Empty(), metadata=metadata)
            assert exc_info.value.code() == grpc.StatusCode.UNAUTHENTICATED
            assert exc_info.value.details() == details

    cookie_md = (("cookie", f"couchers-sesh={token}"),)
    bearer_md = (("authorization", f"Bearer {api_key}"),)

    # no creds, no go for secure APIs
    assert_unauthenticated(None, "Unauthorized")

    # can auth with cookie
    assert fetch_username(cookie_md) == user.username

    # can't auth with wrong cookie
    assert_unauthenticated((("cookie", f"couchers-sesh={random_hex(32)}"),), "Unauthorized")

    # can auth with api key
    assert fetch_username(bearer_md) == user.username

    # can't auth with wrong api key
    assert_unauthenticated((("authorization", f"Bearer {random_hex(32)}"),), "Unauthorized")

    # can auth with grpc helper (they do the same as above)
    comp_creds = grpc.composite_channel_credentials(
        grpc.local_channel_credentials(), grpc.access_token_call_credentials(api_key)
    )
    assert fetch_username(creds=comp_creds) == user.username

    # can't auth with both
    assert_unauthenticated(cookie_md + bearer_md, 'Both "cookie" and "authorization" in request')

    # malformed bearer
    assert_unauthenticated((("authorization", f"bearer {api_key}"),), "Unauthorized")
def test_tracing_interceptor_auth_cookies(db):
    """A cookie-authenticated call is traced against that user and not flagged as an API-key call."""
    user, token = generate_user()

    rpc_def = dict(
        rpc=Account().GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[TracingInterceptor(), AuthValidatorInterceptor(), SessionInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
    )
    cookie_md = (("cookie", f"couchers-sesh={token}"),)

    # with cookies
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        account_info = call_rpc(empty_pb2.Empty(), metadata=cookie_md)
    assert account_info.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert not trace.is_api_key
        assert len(trace.request) == 0
        assert not trace.traceback
def test_tracing_interceptor_auth_api_key(db):
    """An API-key-authenticated call is traced against the key's owner and flagged as an API-key call."""
    _, super_token = generate_user(is_superuser=True)
    user, token = generate_user()

    # a superuser creates an API key for the normal user; read the key back from its session row
    with real_admin_session(super_token) as api:
        api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        api_key = (
            session.execute(select(UserSession).where(UserSession.is_api_key == True))
            .scalar_one()
            .token
        )

    rpc_def = dict(
        rpc=Account().GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[TracingInterceptor(), AuthValidatorInterceptor(), SessionInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
    )
    bearer_md = (("authorization", f"Bearer {api_key}"),)

    # with api key
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        account_info = call_rpc(empty_pb2.Empty(), metadata=bearer_md)
    assert account_info.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert trace.is_api_key
        assert len(trace.request) == 0
        assert not trace.traceback
def test_auth_levels(db):
    """Check which session kinds AuthValidatorInterceptor admits to RPCs of each auth level.

    Also checks the rejection of an RPC that does not exist, and of an RPC with no service
    auth level.
    """

    def TestRpc(request, context):
        return empty_pb2.Empty()

    def gen_args(service, method):
        return {
            "rpc": TestRpc,
            "service_name": service,
            "method_name": method,
            "interceptors": [AuthValidatorInterceptor()],
            "request_type": empty_pb2.Empty,
            "response_type": empty_pb2.Empty,
        }

    # superuser
    _, super_token = generate_user(is_superuser=True)
    # normal user
    _, normal_token = generate_user()
    # jailed user
    _, jailed_token = generate_user(accepted_tos=0)
    # open user: an empty session token stands in for no session
    open_token = ""

    # pick some rpcs here with the right auth levels
    open_args = gen_args("org.couchers.resources.Resources", "GetTermsOfService")
    jailed_args = gen_args("org.couchers.jail.Jail", "JailInfo")
    secure_args = gen_args("org.couchers.api.account.Account", "GetAccountInfo")
    admin_args = gen_args("org.couchers.admin.Admin", "GetUserDetails")

    # pairs to check
    checks = [
        # name, token, args, works?, code, message
        # (code and message are only checked for rows where works? is False)
        # open token only works on open servicers
        ("open x open", open_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x jailed", open_token, jailed_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x secure", open_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x admin", open_token, admin_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        # jailed works on jailed and open
        ("jailed x open", jailed_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("jailed x jailed", jailed_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("jailed x secure", jailed_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Permission denied"),
        ("jailed x admin", jailed_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # normal works on all but admin
        ("normal x open", normal_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x jailed", normal_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x secure", normal_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x admin", normal_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # superuser works on all
        ("super x open", super_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x jailed", super_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x secure", super_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x admin", super_token, admin_args, True, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
    ]

    for name, token, args, should_work, code, message in checks:
        # printed so pytest's captured output identifies the failing combination
        print(f"Testing (token x args) = ({name}), {should_work=}")
        metadata = (("cookie", f"couchers-sesh={token}"),)
        with interceptor_dummy_api(**args) as call_rpc:
            if should_work:
                call_rpc(empty_pb2.Empty(), metadata=metadata)
            else:
                with pytest.raises(grpc.RpcError) as e:
                    call_rpc(empty_pb2.Empty(), metadata=metadata)
                assert e.value.code() == code
                assert e.value.details() == message

    # a non-existent RPC
    nonexistent = gen_args("org.couchers.nonexistent.NA", "GetNothing")

    with interceptor_dummy_api(**nonexistent) as call_rpc:
        # grpc.RpcError, not a bare Exception: .code()/.details() are only defined on RpcError
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.UNIMPLEMENTED
        assert e.value.details() == "API call does not exist. Please refresh and try again."

    # an RPC without a service level
    invalid_args = gen_args("org.couchers.media.Media", "UploadConfirmation")

    with interceptor_dummy_api(**invalid_args) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.INTERNAL
        assert e.value.details() == "Internal authentication error."