Coverage for app / backend / src / tests / test_interceptors.py: 100%
584 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-19 14:14 +0000
1from collections.abc import Callable, Generator
2from concurrent import futures
3from contextlib import contextmanager
4from datetime import timedelta
5from typing import Any
6from unittest.mock import Mock
8import grpc
9import pytest
10from google.protobuf import empty_pb2
11from google.protobuf.descriptor import ServiceDescriptor
12from google.protobuf.descriptor_pool import DescriptorPool
13from sqlalchemy import select, update
15from couchers.constants import (
16 MISSING_AUTH_LEVEL_ERROR_MESSAGE,
17 NONEXISTENT_API_CALL_ERROR_MESSAGE,
18)
19from couchers.crypto import b64encode, random_hex, simple_encrypt
20from couchers.db import session_scope
21from couchers.descriptor_pool import get_descriptor_pool
22from couchers.interceptors import (
23 AbortError,
24 BadHeaders,
25 CouchersMiddlewareInterceptor,
26 ErrorSanitizationInterceptor,
27 UserAuthInfo,
28 check_permissions,
29 find_auth_level,
30 parse_headers,
31 validate_auth_level,
32)
33from couchers.metrics import servicer_duration_histogram
34from couchers.models import APICall, User, UserActivity, UserSession
35from couchers.proto import account_pb2, admin_pb2, annotations_pb2, api_pb2, auth_pb2
36from couchers.servicers.account import Account
37from couchers.servicers.api import API
38from couchers.utils import generate_sofa_cookie, now, parse_sofa_cookie
39from tests.fixtures.db import generate_user
40from tests.fixtures.sessions import real_admin_session
@pytest.fixture(autouse=True)
def _(testconfig):
    """Request the ``testconfig`` fixture automatically for every test in this module.

    The body is intentionally empty: depending on ``testconfig`` is the only purpose.
    """
@contextmanager
def interceptor_dummy_api(
    rpc,
    interceptors,
    service_name="org.couchers.auth.Auth",
    method_name="SignupFlow",
    request_type=empty_pb2.Empty,
    response_type=empty_pb2.Empty,
    creds=None,
) -> Generator[Callable[..., Any]]:
    """Serve one unary-unary ``rpc`` behind ``interceptors`` and yield a client callable for it.

    A throwaway gRPC server is started on an ephemeral localhost port using local
    server credentials. ``rpc`` is registered as ``/{service_name}/{method_name}``.
    The yielded callable invokes that method over a secure channel. The server is
    stopped when the context exits, even if the body raises.

    Args:
        rpc: server-side handler registered for the method.
        interceptors: server interceptors installed on the server.
        service_name: fully-qualified service name to register under.
        method_name: method name within ``service_name``.
        request_type: protobuf message class used to (de)serialize requests.
        response_type: protobuf message class used to (de)serialize responses.
        creds: channel credentials; ``grpc.local_channel_credentials()`` when falsy.
    """
    full_method = f"/{service_name}/{method_name}"
    handler = grpc.unary_unary_rpc_method_handler(
        rpc,
        request_deserializer=request_type.FromString,
        response_serializer=response_type.SerializeToString,
    )

    with futures.ThreadPoolExecutor(1) as pool:
        server = grpc.server(pool, interceptors=interceptors)
        bound_port = server.add_secure_port("localhost:0", grpc.local_server_credentials())
        # Register the handler by hand instead of through generated servicer code.
        server.add_generic_rpc_handlers(
            (grpc.method_handlers_generic_handler(service_name, {method_name: handler}),)
        )
        server.start()

        try:
            channel_creds = creds or grpc.local_channel_credentials()
            with grpc.secure_channel(f"localhost:{bound_port}", channel_creds) as channel:
                yield channel.unary_unary(
                    full_method,
                    request_serializer=request_type.SerializeToString,
                    response_deserializer=response_type.FromString,
                )
        finally:
            server.stop(None).wait()
def _get_histogram_labels_value(method, logged_in, exception, code):
    """Return the observation count of the servicer-duration histogram for one label set.

    Args:
        method: full gRPC method path, e.g. ``"/org.couchers.auth.Auth/SignupFlow"``.
        logged_in: ``logged_in`` label value (string; the tests here pass ``"False"``).
        exception: ``exception`` label value (``""`` in this module when nothing was raised).
        code: ``code`` label value (``""`` in this module when no status code was set).

    Returns:
        The value of the first ``couchers_servicer_duration_seconds_count`` sample whose
        labels match, or ``0`` if that label combination has not been observed yet.
    """
    wanted_labels = {
        "method": method,
        "logged_in": logged_in,
        "code": code,
        "exception": exception,
    }
    metrics = servicer_duration_histogram.collect()
    servicer_histogram = [m for m in metrics if m.name == "couchers_servicer_duration_seconds"][0]
    # Take the first matching sample lazily instead of materializing every match.
    return next(
        (
            sample.value
            for sample in servicer_histogram.samples
            if sample.name == "couchers_servicer_duration_seconds_count"
            and all(sample.labels[key] == value for key, value in wanted_labels.items())
        ),
        0,
    )
def test_logging_interceptor_ok():
    """A handler that returns normally completes without error through ErrorSanitizationInterceptor."""

    def TestRpc(request, context):
        return empty_pb2.Empty()

    sanitizer = ErrorSanitizationInterceptor()
    with interceptor_dummy_api(TestRpc, interceptors=[sanitizer]) as call_rpc:
        call_rpc(empty_pb2.Empty())
def test_logging_interceptor_all_ignored():
    """Deliberate ``context.abort`` errors must pass through ErrorSanitizationInterceptor unchanged.

    For every abortable status code, the client must see exactly the code and
    details that the handler aborted with.
    """
    # Error codes the interceptor must not touch. StatusCode.OK is excluded
    # because context.abort() cannot be called with OK.
    pass_through_status_codes = [
        grpc.StatusCode.CANCELLED,
        grpc.StatusCode.UNKNOWN,
        grpc.StatusCode.INVALID_ARGUMENT,
        grpc.StatusCode.DEADLINE_EXCEEDED,
        grpc.StatusCode.NOT_FOUND,
        grpc.StatusCode.ALREADY_EXISTS,
        grpc.StatusCode.PERMISSION_DENIED,
        grpc.StatusCode.UNAUTHENTICATED,
        grpc.StatusCode.RESOURCE_EXHAUSTED,
        grpc.StatusCode.FAILED_PRECONDITION,
        grpc.StatusCode.ABORTED,
        grpc.StatusCode.OUT_OF_RANGE,
        grpc.StatusCode.UNIMPLEMENTED,
        grpc.StatusCode.INTERNAL,
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.DATA_LOSS,
    ]

    for status_code in pass_through_status_codes:
        message = random_hex()

        # Bind this iteration's values as defaults so the handler does not rely on
        # late binding of the enclosing loop variables (flake8-bugbear B023).
        def TestRpc(request, context, status_code=status_code, message=message):
            context.abort(status_code, message)

        with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
            with pytest.raises(grpc.RpcError) as e:
                call_rpc(empty_pb2.Empty())
            assert e.value.code() == status_code, status_code
            assert e.value.details() == message, status_code
def test_logging_interceptor_assertion():
    """An AssertionError escaping the handler reaches the client as a generic INTERNAL error."""

    def TestRpc(request, context):
        raise AssertionError()

    with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as exc_info:
            call_rpc(empty_pb2.Empty())
        rpc_error = exc_info.value
        assert rpc_error.code() == grpc.StatusCode.INTERNAL
        assert rpc_error.details() == "An unknown backend error occurred. Please consider filing a bug!"
def test_logging_interceptor_div0():
    """A ZeroDivisionError in the handler reaches the client as a generic INTERNAL error."""

    def TestRpc(request, context):
        1 / 0  # noqa: B018

    with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as exc_info:
            call_rpc(empty_pb2.Empty())
        rpc_error = exc_info.value
        assert rpc_error.code() == grpc.StatusCode.INTERNAL
        assert rpc_error.details() == "An unknown backend error occurred. Please consider filing a bug!"
def test_logging_interceptor_raise():
    """A bare Exception from the handler reaches the client as a generic INTERNAL error."""

    def TestRpc(request, context):
        raise Exception()

    with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as exc_info:
            call_rpc(empty_pb2.Empty())
        rpc_error = exc_info.value
        assert rpc_error.code() == grpc.StatusCode.INTERNAL
        assert rpc_error.details() == "An unknown backend error occurred. Please consider filing a bug!"
def test_logging_interceptor_raise_custom():
    """A custom exception's message is hidden: the client only sees the generic INTERNAL error."""

    class _TestingException(Exception):
        pass

    def TestRpc(request, context):
        raise _TestingException("This is a custom exception")

    with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as exc_info:
            call_rpc(empty_pb2.Empty())
        rpc_error = exc_info.value
        assert rpc_error.code() == grpc.StatusCode.INTERNAL
        assert rpc_error.details() == "An unknown backend error occurred. Please consider filing a bug!"
def test_tracing_interceptor_ok_open(db):
    """A successful anonymous call is stored as an APICall and counted once in the histogram."""
    method = "/org.couchers.auth.Auth/SignupFlow"
    before = _get_histogram_labels_value(method, "False", "", "")

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    with interceptor_dummy_api(TestRpc, interceptors=[CouchersMiddlewareInterceptor()]) as call_rpc:
        call_rpc(empty_pb2.Empty())

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert not trace.status_code
        assert not trace.user_id
        # Empty messages serialize to zero bytes, but the columns are still populated.
        assert trace.request is not None
        assert len(trace.request) == 0
        assert trace.response is not None
        assert len(trace.response) == 0
        assert not trace.traceback

    assert _get_histogram_labels_value(method, "False", "", "") == before + 1
def test_tracing_interceptor_sensitive(db):
    """Password fields are scrubbed from the stored request/response while other fields are kept."""
    method = "/org.couchers.auth.Auth/SignupFlow"
    before = _get_histogram_labels_value(method, "False", "", "")

    def TestRpc(request, context, session):
        return auth_pb2.AuthReq(user="this is not secret", password="this is secret")

    signup_req = auth_pb2.SignupFlowReq(
        account=auth_pb2.SignupAccount(password="should be removed", username="not removed")
    )
    with interceptor_dummy_api(
        TestRpc,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=auth_pb2.SignupFlowReq,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        call_rpc(signup_req)

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert not trace.status_code
        assert not trace.user_id
        assert not trace.traceback

        assert trace.request is not None
        stored_req = auth_pb2.SignupFlowReq.FromString(trace.request)
        assert not stored_req.account.password
        assert stored_req.account.username == "not removed"

        assert trace.response
        stored_res = auth_pb2.AuthReq.FromString(trace.response)
        assert stored_res.user == "this is not secret"
        assert not stored_res.password

    assert _get_histogram_labels_value(method, "False", "", "") == before + 1
def test_tracing_interceptor_sensitive_ping(db):
    """A real authenticated API.GetUser call through the middleware succeeds and is traced.

    Previously this test only checked that the call did not raise. It now also
    checks the response and the APICall row the interceptor records.
    """
    user, token = generate_user()

    with interceptor_dummy_api(
        API().GetUser,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=api_pb2.GetUserReq,
        response_type=api_pb2.User,
        service_name="org.couchers.api.core.API",
        method_name="GetUser",
    ) as call_rpc:
        res = call_rpc(api_pb2.GetUserReq(user=user.username), metadata=(("cookie", f"couchers-sesh={token}"),))
        assert res.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.core.API/GetUser"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert not trace.traceback
def test_tracing_interceptor_exception(db):
    """An unhandled handler exception is traced with its traceback and counted under its exception label.

    The password is scrubbed from the stored request, and no response is stored.
    """
    val = _get_histogram_labels_value("/org.couchers.auth.Auth/SignupFlow", "False", "Exception", "")

    def TestRpc(request, context, session):
        raise Exception("Some error message")

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        # The client only ever observes server failures as grpc.RpcError. Matching on
        # that type (not bare Exception) keeps an unrelated client-side error from
        # satisfying the check.
        with pytest.raises(grpc.RpcError, match="Some error message"):
            call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.auth.Auth/SignupFlow"
        assert not trace.status_code
        assert not trace.user_id
        assert trace.traceback
        assert "Some error message" in trace.traceback
        assert trace.request is not None
        req = auth_pb2.SignupAccount.FromString(trace.request)
        assert not req.password
        assert req.username == "not removed"
        assert not trace.response

    assert _get_histogram_labels_value("/org.couchers.auth.Auth/SignupFlow", "False", "Exception", "") == val + 1
def test_tracing_interceptor_abort(db):
    """A ``context.abort`` is traced with its status code and counted under that code.

    The password is scrubbed from the stored request, and no response is stored.
    """
    val = _get_histogram_labels_value("/org.couchers.auth.Auth/SignupFlow", "False", "Exception", "FAILED_PRECONDITION")

    def TestRpc(request, context, session):
        context.abort(grpc.StatusCode.FAILED_PRECONDITION, "now a grpc abort")

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        # Narrowed from bare Exception: the client sees the abort as a grpc.RpcError.
        with pytest.raises(grpc.RpcError, match="now a grpc abort"):
            call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.auth.Auth/SignupFlow"
        assert trace.status_code == "FAILED_PRECONDITION"
        assert not trace.user_id
        assert trace.traceback
        assert "now a grpc abort" in trace.traceback
        assert trace.request is not None
        req = auth_pb2.SignupAccount.FromString(trace.request)
        assert not req.password
        assert req.username == "not removed"
        assert not trace.response

    assert (
        _get_histogram_labels_value("/org.couchers.auth.Auth/SignupFlow", "False", "Exception", "FAILED_PRECONDITION")
        == val + 1
    )
def cookie_auth(token: str) -> tuple[str, str]:
    """Return a gRPC metadata pair that authenticates via the ``couchers-sesh`` session cookie."""
    cookie_header = "couchers-sesh={}".format(token)
    return ("cookie", cookie_header)
def api_auth(token: str) -> tuple[str, str]:
    """Return a gRPC metadata pair that authenticates with an API key as a Bearer token."""
    bearer_value = "Bearer {}".format(token)
    return ("authorization", bearer_value)
def test_auth_interceptor(db):
    """End-to-end authentication checks for CouchersMiddlewareInterceptor on Account.GetAccountInfo.

    First checks which credentials are accepted or rejected: none, session cookie,
    API key, both, malformed, deleted user, expired session, and mismatched
    session type. Then checks the bookkeeping written on successful calls:
    UserSession.api_calls / last_seen, User.last_active, and UserActivity rows
    split by IP and by time period.

    The steps share and mutate database state, so they must run in this order.
    """
    super_user, super_token = generate_user(is_superuser=True)
    user, token = generate_user()
    deleted_user, deleted_token = generate_user(delete_user=True)

    # Create an API key for `user` through the admin API (called with the superuser's token).
    with real_admin_session(super_token) as api:
        api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        api_key = session.execute(select(UserSession.token).where(UserSession.is_api_key)).scalar_one()

    account = Account()

    # Shared interceptor_dummy_api arguments: serve Account.GetAccountInfo behind the middleware.
    rpc_def = {
        "rpc": account.GetAccountInfo,
        "service_name": "org.couchers.api.account.Account",
        "method_name": "GetAccountInfo",
        "interceptors": [CouchersMiddlewareInterceptor()],
        "request_type": empty_pb2.Empty,
        "response_type": account_pb2.GetAccountInfoRes,
    }

    # no creds, no-go for secure APIs
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.UNAUTHENTICATED
        assert e.value.details() == "Unauthorized"

    # can auth with cookie
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        res1 = call_rpc(empty_pb2.Empty(), metadata=(cookie_auth(token),))
        assert res1.username == user.username

    # The successful call above is counted in the user's UserActivity row.
    with session_scope() as session:
        api_calls = session.execute(select(UserActivity.api_calls).where(UserActivity.user_id == user.id)).scalar_one()
        assert api_calls == 1

    # can't auth with a wrong cookie
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty(), metadata=(cookie_auth(random_hex(32)),))
        assert e.value.code() == grpc.StatusCode.UNAUTHENTICATED
        assert e.value.details() == "Unauthorized"

    # can auth with an api key
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        res2 = call_rpc(empty_pb2.Empty(), metadata=(api_auth(api_key),))
        assert res2.username == user.username

    # API-key calls accumulate into the same UserActivity row as cookie calls (no IP header on either).
    with session_scope() as session:
        api_calls = session.execute(select(UserActivity.api_calls).where(UserActivity.user_id == user.id)).scalar_one()
        assert api_calls == 2

    # can't auth with a wrong api key
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty(), metadata=(api_auth(random_hex(32)),))
        assert e.value.code() == grpc.StatusCode.UNAUTHENTICATED
        assert e.value.details() == "Unauthorized"

    # can auth with the grpc call-credentials helper, which sends the same
    # "authorization: Bearer <token>" header as api_auth above
    comp_creds = grpc.composite_channel_credentials(
        grpc.local_channel_credentials(), grpc.access_token_call_credentials(api_key)
    )
    with interceptor_dummy_api(**rpc_def, creds=comp_creds) as call_rpc:
        res3 = call_rpc(empty_pb2.Empty())
        assert res3.username == user.username

    # can't auth with both a cookie and an authorization header at once
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty(), metadata=(cookie_auth(token), api_auth(api_key)))
        assert e.value.code() == grpc.StatusCode.UNAUTHENTICATED
        assert e.value.details() == 'Both "cookie" and "authorization" in request'

    # malformed bearer (lowercase "bearer" scheme is rejected)
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty(), metadata=(("authorization", f"bearer {api_key}"),))
        assert e.value.code() == grpc.StatusCode.UNAUTHENTICATED
        assert e.value.details() == "Unauthorized"

    # Invisible (deleted) user
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty(), metadata=(cookie_auth(deleted_token),))
        assert e.value.code() == grpc.StatusCode.UNAUTHENTICATED
        assert e.value.details() == "Unauthorized"

    # Invalid (expired) session: push the cookie session's last_seen 100 weeks into the past
    long_ago = now() - timedelta(weeks=100)
    with session_scope() as session:
        session.execute(update(UserSession).values(last_seen=long_ago).where(UserSession.token == token))

    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty(), metadata=(cookie_auth(token),))
        assert e.value.code() == grpc.StatusCode.UNAUTHENTICATED
        assert e.value.details() == "Unauthorized"

    # Cookie token whose session row is flagged as an API key (probably impossible, but...):
    # cookie auth must be rejected. last_seen is reset to now so that the
    # rejection comes from the session-type mismatch, not from expiry.
    with session_scope() as session:
        session.execute(update(UserSession).values(last_seen=now(), is_api_key=True).where(UserSession.token == token))

    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty(), metadata=(cookie_auth(token),))
        assert e.value.code() == grpc.StatusCode.UNAUTHENTICATED
        assert e.value.details() == "Unauthorized"

    # Check that session metadata (api_calls, last_seen) is updated on a successful call
    six_minutes_ago = now() - timedelta(minutes=6)
    with session_scope() as session:
        # Return the session to normal (a cookie session again)
        user_session = session.execute(select(UserSession).where(UserSession.token == token)).scalar_one()
        user_session.is_api_key = False
        api_calls = user_session.api_calls

    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        res4 = call_rpc(empty_pb2.Empty(), metadata=(cookie_auth(token),))
        assert res4.username == user.username

    with session_scope() as session:
        user_session = session.execute(select(UserSession).where(UserSession.token == token)).scalar_one()
        assert user_session.api_calls == api_calls + 1
        assert user_session.last_seen > now() - timedelta(seconds=1)

        # Simulate user inactivity, so last_active is updated on the next api call.
        session.execute(update(User).values(last_active=six_minutes_ago).where(User.id == user.id))

    # Check that last_active is updated if it wasn't updated in a while.
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        call_rpc(empty_pb2.Empty(), metadata=(cookie_auth(token),))

    with session_scope() as session:
        last_active = session.execute(select(User.last_active).where(User.id == user.id)).scalar_one()
        assert last_active > now() - timedelta(seconds=1)

    # Check that last_active is untouched (since it was already updated recently)
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        call_rpc(empty_pb2.Empty(), metadata=(cookie_auth(token),))

    with session_scope() as session:
        last_active_2 = session.execute(select(User.last_active).where(User.id == user.id)).scalar_one()
        assert last_active_2 == last_active

    # Check that activity is split by IP: a new client IP gets its own UserActivity row.
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        call_rpc(empty_pb2.Empty(), metadata=(cookie_auth(token), ("x-couchers-real-ip", "1.1.1.1")))

    with session_scope() as session:
        api_calls = session.execute(
            select(UserActivity.api_calls).where(UserActivity.ip_address == "1.1.1.1")
        ).scalar_one()
        assert api_calls == 1

    # Check that activity is split in time bins.
    # Update all UserActivity to be in the far past so that a new row is inserted on the next request.
    with session_scope() as session:
        session.execute(update(UserActivity).values(period=long_ago).where(UserActivity.user_id == user.id))

    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        call_rpc(empty_pb2.Empty(), metadata=(cookie_auth(token),))

    # The newest row (highest id) is the fresh bin and has counted exactly one call.
    with session_scope() as session:
        api_calls = session.execute(
            select(UserActivity.api_calls)
            .where(UserActivity.user_id == user.id)
            .order_by(UserActivity.id.desc())
            .limit(1)
        ).scalar_one()
        assert api_calls == 1
def test_tracing_interceptor_auth_cookies(db):
    """A cookie-authenticated call is traced with the caller's user id and is_api_key unset."""
    user, token = generate_user()

    rpc_def = {
        "rpc": Account().GetAccountInfo,
        "service_name": "org.couchers.api.account.Account",
        "method_name": "GetAccountInfo",
        "interceptors": [CouchersMiddlewareInterceptor()],
        "request_type": empty_pb2.Empty,
        "response_type": account_pb2.GetAccountInfoRes,
    }

    # authenticate with the session cookie
    cookie_metadata = (("cookie", f"couchers-sesh={token}"),)
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        account_info = call_rpc(empty_pb2.Empty(), metadata=cookie_metadata)
        assert account_info.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert not trace.is_api_key
        assert trace.request is not None
        assert len(trace.request) == 0
        assert not trace.traceback
def test_tracing_interceptor_auth_api_key(db):
    """An API-key-authenticated call is traced with the caller's user id and is_api_key set."""
    super_user, super_token = generate_user(is_superuser=True)
    user, token = generate_user()

    # Issue an API key for `user` via the admin API, then read it back from the DB.
    with real_admin_session(super_token) as api:
        api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        api_key = session.execute(select(UserSession.token).where(UserSession.is_api_key)).scalar_one()

    method = "/org.couchers.api.account.Account/GetAccountInfo"
    rpc_def = {
        "rpc": Account().GetAccountInfo,
        "service_name": "org.couchers.api.account.Account",
        "method_name": "GetAccountInfo",
        "interceptors": [CouchersMiddlewareInterceptor()],
        "request_type": empty_pb2.Empty,
        "response_type": account_pb2.GetAccountInfoRes,
    }

    # authenticate with the API key as a Bearer token
    bearer_metadata = (("authorization", f"Bearer {api_key}"),)
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        account_info = call_rpc(empty_pb2.Empty(), metadata=bearer_metadata)
        assert account_info.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall).where(APICall.method == method)).scalar_one()
        assert trace.method == method
        assert not trace.status_code
        assert trace.user_id == user.id
        assert trace.is_api_key
        assert trace.request is not None
        assert len(trace.request) == 0
        assert not trace.traceback
def test_auth_levels(db):
    """Check auth-level enforcement for every (kind of user x service auth level) pair.

    A service's auth level comes from its proto annotations, so real service and
    method names are served with a dummy handler. Also checks the errors for an
    RPC that does not exist and for one whose service has no auth level.
    """

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    def gen_args(service, method):
        # interceptor_dummy_api kwargs serving TestRpc under service/method.
        return {
            "rpc": TestRpc,
            "service_name": service,
            "method_name": method,
            "interceptors": [CouchersMiddlewareInterceptor()],
            "request_type": empty_pb2.Empty,
            "response_type": empty_pb2.Empty,
        }

    # superuser (note: superusers are automatically editors due to DB constraint)
    _, super_token = generate_user(is_superuser=True)
    # editor user
    _, editor_token = generate_user(is_editor=True)
    # normal user
    _, normal_token = generate_user()
    # jailed user
    _, jailed_token = generate_user(accepted_tos=0)
    # open user (sends an empty session cookie)
    open_token = ""

    # pick some rpcs here with the right auth levels
    open_args = gen_args("org.couchers.resources.Resources", "GetTermsOfService")
    jailed_args = gen_args("org.couchers.jail.Jail", "JailInfo")
    secure_args = gen_args("org.couchers.api.account.Account", "GetAccountInfo")
    editor_args = gen_args("org.couchers.editor.Editor", "CreateCommunity")
    admin_args = gen_args("org.couchers.admin.Admin", "GetUserDetails")

    # pairs to check
    checks = [
        # name, token, args, works?, code, message
        # (code and message are only checked when works? is False)
        # open token only works on open servicers
        ("open x open", open_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x jailed", open_token, jailed_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x secure", open_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x editor", open_token, editor_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x admin", open_token, admin_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        # jailed works on jailed and open
        ("jailed x open", jailed_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("jailed x jailed", jailed_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("jailed x secure", jailed_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Permission denied"),
        ("jailed x editor", jailed_token, editor_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        ("jailed x admin", jailed_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # normal works on all but editor and admin
        ("normal x open", normal_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x jailed", normal_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x secure", normal_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x editor", normal_token, editor_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        ("normal x admin", normal_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # editor works on all but admin
        ("editor x open", editor_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("editor x jailed", editor_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("editor x secure", editor_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("editor x editor", editor_token, editor_args, True, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        ("editor x admin", editor_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # superuser works on all
        ("super x open", super_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x jailed", super_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x secure", super_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x editor", super_token, editor_args, True, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        ("super x admin", super_token, admin_args, True, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
    ]

    for name, token, args, should_work, code, message in checks:
        print(f"Testing (token x args) = ({name}), {should_work=}")
        metadata = (("cookie", f"couchers-sesh={token}"),)
        with interceptor_dummy_api(**args) as call_rpc:
            if should_work:
                call_rpc(empty_pb2.Empty(), metadata=metadata)
            else:
                with pytest.raises(grpc.RpcError) as err:
                    call_rpc(empty_pb2.Empty(), metadata=metadata)
                # Include the case name so a failure identifies the failing pair directly.
                assert err.value.code() == code, name
                assert err.value.details() == message, name

    # a non-existent RPC
    nonexistent = gen_args("org.couchers.nonexistent.NA", "GetNothing")

    with interceptor_dummy_api(**nonexistent) as call_rpc:
        with pytest.raises(grpc.RpcError) as err:
            call_rpc(empty_pb2.Empty())
        assert err.value.code() == grpc.StatusCode.UNIMPLEMENTED
        assert err.value.details() == "API call does not exist. Please refresh and try again."

    # an RPC without a service level
    invalid_args = gen_args("org.couchers.media.Media", "UploadConfirmation")

    with interceptor_dummy_api(**invalid_args) as call_rpc:
        with pytest.raises(grpc.RpcError) as err:
            call_rpc(empty_pb2.Empty())
        assert err.value.code() == grpc.StatusCode.INTERNAL
        assert err.value.details() == "Internal authentication error."
def test_parse_headers_with_session_cookie():
    """The ``couchers-sesh`` cookie is parsed as a session token, not an API key."""
    parsed = parse_headers({"cookie": "couchers-sesh=abc123; other-cookie=value"})
    assert parsed.token == "abc123"
    assert parsed.is_api_key is False
def test_parse_headers_with_authorization_header():
    """A ``Bearer`` authorization header is parsed as an API-key token."""
    parsed = parse_headers({"authorization": "Bearer abc123"})
    assert parsed.token == "abc123"
    assert parsed.is_api_key is True
def test_parse_headers_with_both_cookie_and_authorization():
    """Supplying both a session cookie and an authorization header raises BadHeaders."""
    conflicting = {
        "cookie": "couchers-sesh=abc123",
        "authorization": "Bearer xyz789",
    }
    with pytest.raises(BadHeaders, match="Both cookies and authorization are present in headers"):
        parse_headers(conflicting)
def test_parse_headers_with_neither_cookie_nor_authorization():
    """With no credentials at all, there is no token and it is not an API key."""
    parsed = parse_headers({})
    assert parsed.token is None
    assert parsed.is_api_key is False
def test_parse_headers_with_all_optional_headers():
    """Optional cookies and headers (user id, locale, real IP, user agent) are all extracted."""
    parsed = parse_headers(
        {
            "cookie": "couchers-sesh=abc123; couchers-user-id=42; NEXT_LOCALE=en",
            "x-couchers-real-ip": "192.168.1.1",
            "user-agent": "TestAgent/1.0",
        }
    )
    assert parsed.token == "abc123"
    assert parsed.is_api_key is False
    assert parsed.ip_address == "192.168.1.1"
    assert parsed.user_agent == "TestAgent/1.0"
    assert parsed.ui_lang == "en"
    # The user id cookie is kept as a string.
    assert parsed.user_id == "42"
def test_parse_headers_with_bytes_ip_address():
    """A bytes-valued ``x-couchers-real-ip`` header is ignored (ip_address stays None)."""
    raw_headers: dict[str, str | bytes] = {"cookie": "couchers-sesh=abc123"}
    raw_headers["x-couchers-real-ip"] = b"192.168.1.1"
    assert parse_headers(raw_headers).ip_address is None
def test_parse_headers_with_bytes_user_agent():
    """A bytes-valued ``user-agent`` header is ignored (user_agent stays None)."""
    raw_headers: dict[str, str | bytes] = {"cookie": "couchers-sesh=abc123"}
    raw_headers["user-agent"] = b"TestAgent/1.0"
    assert parse_headers(raw_headers).user_agent is None
def test_parse_headers_malformed_authorization():
    """A lowercase ``bearer`` scheme yields no token, though the request is still flagged as API-key auth."""
    parsed = parse_headers({"authorization": "bearer abc123"})
    assert parsed.token is None
    assert parsed.is_api_key is True
def test_find_auth_level_with_valid_service():
    """API.GetUser resolves to the SECURE auth level from the real descriptor pool."""
    level = find_auth_level(get_descriptor_pool(), "/org.couchers.api.core.API/GetUser")
    assert level == annotations_pb2.AUTH_LEVEL_SECURE
def test_find_auth_level_with_nonexistent_service():
    """An unknown service aborts with UNIMPLEMENTED and the nonexistent-API message."""
    descriptor_pool = get_descriptor_pool()

    with pytest.raises(AbortError) as exc_info:
        find_auth_level(descriptor_pool, "/org.couchers.nonexistent.Service/Method")

    abort = exc_info.value
    assert abort.msg == NONEXISTENT_API_CALL_ERROR_MESSAGE
    assert abort.code == grpc.StatusCode.UNIMPLEMENTED
def test_find_auth_level_with_unknown_auth_level():
    """A service annotated AUTH_LEVEL_UNKNOWN aborts with INTERNAL and the missing-auth-level message."""
    # Fake descriptor pool whose service options report AUTH_LEVEL_UNKNOWN.
    fake_options = Mock()
    fake_options.Extensions = {annotations_pb2.auth_level: annotations_pb2.AUTH_LEVEL_UNKNOWN}

    fake_service = Mock(spec=ServiceDescriptor)
    fake_service.GetOptions.return_value = fake_options

    fake_pool = Mock(spec=DescriptorPool)
    fake_pool.FindServiceByName.return_value = fake_service

    with pytest.raises(AbortError) as exc_info:
        find_auth_level(fake_pool, "/org.couchers.api.core.API/GetUser")

    abort = exc_info.value
    assert abort.msg == MISSING_AUTH_LEVEL_ERROR_MESSAGE
    assert abort.code == grpc.StatusCode.INTERNAL
def test_validate_auth_level_with_unknown():
    """AUTH_LEVEL_UNKNOWN is rejected with INTERNAL and the missing-auth-level message."""
    with pytest.raises(AbortError) as exc_info:
        validate_auth_level(annotations_pb2.AUTH_LEVEL_UNKNOWN)

    abort = exc_info.value
    assert abort.msg == MISSING_AUTH_LEVEL_ERROR_MESSAGE
    assert abort.code == grpc.StatusCode.INTERNAL
def test_validate_auth_level_with_open():
    """AUTH_LEVEL_OPEN is accepted (no exception)."""
    level = annotations_pb2.AUTH_LEVEL_OPEN
    validate_auth_level(level)
def test_validate_auth_level_with_jailed():
    """AUTH_LEVEL_JAILED is accepted (no exception)."""
    level = annotations_pb2.AUTH_LEVEL_JAILED
    validate_auth_level(level)
def test_validate_auth_level_with_secure():
    """AUTH_LEVEL_SECURE is accepted (no exception)."""
    level = annotations_pb2.AUTH_LEVEL_SECURE
    validate_auth_level(level)
def test_validate_auth_level_with_editor():
    """AUTH_LEVEL_EDITOR is accepted (no exception)."""
    level = annotations_pb2.AUTH_LEVEL_EDITOR
    validate_auth_level(level)
def test_validate_auth_level_with_admin():
    """AUTH_LEVEL_ADMIN is accepted (no exception)."""
    level = annotations_pb2.AUTH_LEVEL_ADMIN
    validate_auth_level(level)
def test_check_auth_open_service_without_auth():
    """Anonymous callers (no auth info) are allowed on OPEN services."""
    anonymous = None
    check_permissions(anonymous, annotations_pb2.AUTH_LEVEL_OPEN)
def test_check_auth_open_service_with_auth():
    """An authenticated, unprivileged user may call OPEN services."""
    regular_user = UserAuthInfo(
        user_id=1,
        token="abc123",
        token_expiry=now(),
        is_api_key=False,
        is_jailed=False,
        is_editor=False,
        is_superuser=False,
        ui_language_preference="en",
        timezone="Etc/UTC",
    )
    check_permissions(regular_user, annotations_pb2.AUTH_LEVEL_OPEN)
def test_check_auth_secure_service_without_auth():
    """Anonymous callers are refused by SECURE services."""
    required_level = annotations_pb2.AUTH_LEVEL_SECURE
    with pytest.raises(AbortError):
        check_permissions(None, required_level)
def test_check_auth_secure_service_with_normal_auth():
    """A non-jailed, unprivileged user may call SECURE services."""
    fields = dict(
        user_id=1,
        is_jailed=False,
        is_editor=False,
        is_superuser=False,
        token_expiry=now(),
        ui_language_preference="en",
        timezone="Etc/UTC",
        token="abc123",
        is_api_key=False,
    )
    check_permissions(UserAuthInfo(**fields), annotations_pb2.AUTH_LEVEL_SECURE)
def test_check_auth_secure_service_with_jailed_user():
    """A jailed user is refused by SECURE services."""
    jailed_user = UserAuthInfo(
        user_id=1,
        token="abc123",
        token_expiry=now(),
        is_api_key=False,
        is_jailed=True,
        is_editor=False,
        is_superuser=False,
        ui_language_preference="en",
        timezone="Etc/UTC",
    )
    with pytest.raises(AbortError):
        check_permissions(jailed_user, annotations_pb2.AUTH_LEVEL_SECURE)
def test_check_auth_jailed_service_with_jailed_user():
    """A jailed user may still call JAILED-level services."""
    fields = dict(
        user_id=1,
        is_jailed=True,
        is_editor=False,
        is_superuser=False,
        token_expiry=now(),
        ui_language_preference="en",
        timezone="Etc/UTC",
        token="abc123",
        is_api_key=False,
    )
    check_permissions(UserAuthInfo(**fields), annotations_pb2.AUTH_LEVEL_JAILED)
def test_check_auth_jailed_service_without_auth():
    """Anonymous callers are refused by JAILED-level services."""
    required_level = annotations_pb2.AUTH_LEVEL_JAILED
    with pytest.raises(AbortError):
        check_permissions(None, required_level)
def test_check_auth_editor_service_without_editor():
    """A user without the editor flag is refused by EDITOR services."""
    non_editor = UserAuthInfo(
        user_id=1,
        token="abc123",
        token_expiry=now(),
        is_api_key=False,
        is_jailed=False,
        is_editor=False,
        is_superuser=False,
        ui_language_preference="en",
        timezone="Etc/UTC",
    )
    with pytest.raises(AbortError):
        check_permissions(non_editor, annotations_pb2.AUTH_LEVEL_EDITOR)
def test_check_auth_editor_service_with_editor():
    """A user with the editor flag may call EDITOR services."""
    fields = dict(
        user_id=1,
        is_jailed=False,
        is_editor=True,
        is_superuser=False,
        token_expiry=now(),
        ui_language_preference="en",
        timezone="Etc/UTC",
        token="abc123",
        is_api_key=False,
    )
    check_permissions(UserAuthInfo(**fields), annotations_pb2.AUTH_LEVEL_EDITOR)
def test_check_auth_admin_service_without_superuser():
    """An editor who is not a superuser is refused by ADMIN services."""
    editor_only = UserAuthInfo(
        user_id=1,
        token="abc123",
        token_expiry=now(),
        is_api_key=False,
        is_jailed=False,
        is_editor=True,
        is_superuser=False,
        ui_language_preference="en",
        timezone="Etc/UTC",
    )
    with pytest.raises(AbortError):
        check_permissions(editor_only, annotations_pb2.AUTH_LEVEL_ADMIN)
def test_check_auth_admin_service_with_superuser():
    """A superuser may call ADMIN services."""
    fields = dict(
        user_id=1,
        is_jailed=False,
        is_editor=True,
        is_superuser=True,
        token_expiry=now(),
        ui_language_preference="en",
        timezone="Etc/UTC",
        token="abc123",
        is_api_key=False,
    )
    check_permissions(UserAuthInfo(**fields), annotations_pb2.AUTH_LEVEL_ADMIN)
def test_check_auth_admin_service_without_auth():
    """Anonymous callers are refused by ADMIN services."""
    required_level = annotations_pb2.AUTH_LEVEL_ADMIN
    with pytest.raises(AbortError):
        check_permissions(None, required_level)
def test_parse_sofa_cookie_valid():
    """A cookie value produced by generate_sofa_cookie parses back to the same sofa value."""
    expected, set_cookie = generate_sofa_cookie()
    # The cookie value is the text after the first "=" and before the first ";".
    encoded = set_cookie.split("=", 1)[1].partition(";")[0]
    assert parse_sofa_cookie({"cookie": f"sofa={encoded}"}) == expected
def test_parse_sofa_cookie_missing():
    """A cookie header that has no sofa cookie yields None."""
    assert parse_sofa_cookie({"cookie": "other-cookie=value"}) is None
def test_parse_sofa_cookie_no_cookies():
    """Headers with no cookie entry at all yield None."""
    empty_headers: dict[str, str] = dict()
    assert parse_sofa_cookie(empty_headers) is None
def test_parse_sofa_cookie_invalid_base64():
    """A sofa cookie that is not valid base64 yields None rather than raising."""
    assert parse_sofa_cookie({"cookie": "sofa=not-valid-base64!!!"}) is None
def test_parse_sofa_cookie_invalid_encryption():
    """Valid base64 that does not decrypt correctly yields None rather than raising."""
    bogus_ciphertext = b64encode(b'invalid encrypted data')
    assert parse_sofa_cookie({"cookie": f"sofa={bogus_ciphertext}"}) is None
def test_parse_sofa_cookie_invalid_proto():
    """A correctly encrypted payload that is not a valid proto must not crash parsing.

    The payload is encrypted with the same "sofa_cookie" key the parser uses, so it passes the
    base64 and decryption stages and reaches proto parsing. The previous assertion
    (``result is not None or result is None``) was a tautology and verified nothing.
    """
    encrypted = simple_encrypt("sofa_cookie", b"not a valid proto")
    headers = {"cookie": f"sofa={b64encode(encrypted)}"}
    result = parse_sofa_cookie(headers)
    # Whether these arbitrary bytes decode as a proto depends on their byte layout, so this test
    # does not pin an exact value. It requires that the call returns cleanly with either "no sofa"
    # (None) or a sofa string, the same type the valid-cookie tests get back.
    assert result is None or isinstance(result, str)
def test_generate_sofa_cookie():
    """generate_sofa_cookie returns a non-trivial value and a cookie string that round-trips."""
    sofa_value, cookie_string = generate_sofa_cookie()

    # The raw value is a non-empty string longer than 20 characters.
    assert sofa_value
    assert isinstance(sofa_value, str)
    assert len(sofa_value) > 20

    # The cookie string names the "sofa" cookie and carries an expiry attribute (any case).
    assert "sofa=" in cookie_string
    assert "expires=" in cookie_string.lower()

    # Feeding the cookie's value back through the parser recovers the raw value.
    encoded = cookie_string.split("=", 1)[1].partition(";")[0]
    assert parse_sofa_cookie({"cookie": f"sofa={encoded}"}) == sofa_value
def test_parse_headers_with_sofa_cookie():
    """parse_headers extracts both the session token and the sofa value from one cookie header."""
    sofa_value, cookie_string = generate_sofa_cookie()
    encoded_sofa = cookie_string.split("=", 1)[1].partition(";")[0]

    parsed = parse_headers({"cookie": f"couchers-sesh=abc123; sofa={encoded_sofa}"})

    assert parsed.token == "abc123"
    assert parsed.sofa == sofa_value
def test_parse_headers_without_sofa_cookie():
    """With only a session cookie present, the token is parsed and sofa stays None."""
    parsed = parse_headers({"cookie": "couchers-sesh=abc123"})
    assert parsed.token == "abc123"
    assert parsed.sofa is None
def test_sofa_cookie_logged_new(db):
    """A request without a sofa cookie still gets a sofa value logged on its APICall row."""

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    middleware = [CouchersMiddlewareInterceptor()]
    with interceptor_dummy_api(TestRpc, interceptors=middleware) as call_rpc:
        call_rpc(empty_pb2.Empty())

    with session_scope() as session:
        logged_call = session.execute(select(APICall)).scalar_one()
        assert logged_call.sofa is not None
        assert len(logged_call.sofa) > 20
def test_sofa_cookie_logged_existing(db):
    """A valid sofa cookie sent by the client is logged verbatim on the APICall row."""
    sofa_value, cookie_string = generate_sofa_cookie()
    encoded_sofa = cookie_string.split("=", 1)[1].partition(";")[0]

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    middleware = [CouchersMiddlewareInterceptor()]
    with interceptor_dummy_api(TestRpc, interceptors=middleware) as call_rpc:
        request_metadata = (("cookie", f"sofa={encoded_sofa}"),)
        call_rpc(empty_pb2.Empty(), metadata=request_metadata)

    with session_scope() as session:
        logged_call = session.execute(select(APICall)).scalar_one()
        assert logged_call.sofa == sofa_value
def test_sofa_cookie_logged_invalid_generates_new(db):
    """An unparseable sofa cookie is not logged as-is; a different, non-trivial value is logged."""

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    middleware = [CouchersMiddlewareInterceptor()]
    with interceptor_dummy_api(TestRpc, interceptors=middleware) as call_rpc:
        request_metadata = (("cookie", "sofa=invalid-cookie-value"),)
        call_rpc(empty_pb2.Empty(), metadata=request_metadata)

    with session_scope() as session:
        logged_call = session.execute(select(APICall)).scalar_one()
        assert logged_call.sofa is not None
        assert logged_call.sofa != "invalid-cookie-value"
        assert len(logged_call.sofa) > 20
def test_sofa_cookie_with_authenticated_user(db):
    """An authenticated call logs both the caller's user id and the supplied sofa value."""
    user, token = generate_user()
    sofa_value, cookie_string = generate_sofa_cookie()
    encoded_sofa = cookie_string.split("=", 1)[1].partition(";")[0]

    with interceptor_dummy_api(
        rpc=Account().GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
        creds=grpc.local_channel_credentials(),
    ) as call_rpc:
        # Send the session token and the sofa value together in one cookie header.
        request_metadata = (("cookie", f"couchers-sesh={token}; sofa={encoded_sofa}"),)
        res = call_rpc(empty_pb2.Empty(), metadata=request_metadata)
        assert res.username == user.username

    with session_scope() as session:
        logged_call = session.execute(select(APICall)).scalar_one()
        assert logged_call.user_id == user.id
        assert logged_call.sofa == sofa_value
def test_sofa_cookie_persists_on_exception(db):
    """When the RPC raises, the APICall row still records the sofa value and the traceback."""
    sofa_value, cookie_string = generate_sofa_cookie()
    encoded_sofa = cookie_string.split("=", 1)[1].partition(";")[0]

    def TestRpc(request, context, session):
        raise Exception("Test error")

    middleware = [CouchersMiddlewareInterceptor()]
    with (
        interceptor_dummy_api(TestRpc, interceptors=middleware) as call_rpc,
        pytest.raises(Exception, match="Test error"),
    ):
        call_rpc(empty_pb2.Empty(), metadata=(("cookie", f"sofa={encoded_sofa}"),))

    with session_scope() as session:
        logged_call = session.execute(select(APICall)).scalar_one()
        assert logged_call.sofa == sofa_value
        assert logged_call.traceback is not None
        assert "Test error" in logged_call.traceback