Coverage for app / backend / src / tests / test_interceptors.py: 100%
584 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-03 06:18 +0000
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-03 06:18 +0000
1from collections.abc import Callable, Generator
2from concurrent import futures
3from contextlib import contextmanager
4from datetime import timedelta
5from typing import Any
6from unittest.mock import Mock
8import grpc
9import pytest
10from google.protobuf import empty_pb2
11from google.protobuf.descriptor import ServiceDescriptor
12from google.protobuf.descriptor_pool import DescriptorPool
13from sqlalchemy import select, update
15from couchers.constants import (
16 MISSING_AUTH_LEVEL_ERROR_MESSAGE,
17 NONEXISTENT_API_CALL_ERROR_MESSAGE,
18)
19from couchers.crypto import b64encode, random_hex, simple_encrypt
20from couchers.db import session_scope
21from couchers.descriptor_pool import get_descriptor_pool
22from couchers.interceptors import (
23 AbortError,
24 BadHeaders,
25 CouchersMiddlewareInterceptor,
26 ErrorSanitizationInterceptor,
27 UserAuthInfo,
28 check_permissions,
29 find_auth_level,
30 parse_headers,
31 validate_auth_level,
32)
33from couchers.metrics import servicer_duration_histogram
34from couchers.models import APICall, User, UserActivity, UserSession
35from couchers.proto import account_pb2, admin_pb2, annotations_pb2, api_pb2, auth_pb2
36from couchers.servicers.account import Account
37from couchers.servicers.api import API
38from couchers.utils import generate_sofa_cookie, now, parse_sofa_cookie
39from tests.fixtures.db import generate_user
40from tests.fixtures.sessions import real_admin_session
@pytest.fixture(autouse=True)
def _(testconfig):
    """Request the ``testconfig`` fixture automatically for every test in this module."""
@contextmanager
def interceptor_dummy_api(
    rpc,
    interceptors,
    service_name="org.couchers.auth.Auth",
    method_name="SignupFlow",
    request_type=empty_pb2.Empty,
    response_type=empty_pb2.Empty,
    creds=None,
) -> Generator[Callable[..., Any]]:
    """Serve a single unary-unary RPC behind ``interceptors`` and yield a client callable for it.

    A gRPC server is started on an ephemeral localhost port with ``rpc`` registered as
    ``/{service_name}/{method_name}``. The yielded object is the client-side
    multi-callable for that method. The server is stopped when the context exits.

    Args:
        rpc: The handler function to serve.
        interceptors: Server interceptors to install.
        service_name: Fully qualified service name used for registration and routing.
        method_name: Method name used for registration and routing.
        request_type: Protobuf message class used to (de)serialize requests.
        response_type: Protobuf message class used to (de)serialize responses.
        creds: Channel credentials; local channel credentials are used when falsy.
    """
    with futures.ThreadPoolExecutor(1) as pool:
        server = grpc.server(pool, interceptors=interceptors)
        port = server.add_secure_port("localhost:0", grpc.local_server_credentials())

        # The handler is registered by hand rather than through generated servicer code.
        handler = grpc.unary_unary_rpc_method_handler(
            rpc,
            request_deserializer=request_type.FromString,
            response_serializer=response_type.SerializeToString,
        )
        server.add_generic_rpc_handlers(
            (grpc.method_handlers_generic_handler(service_name, {method_name: handler}),)
        )
        server.start()

        try:
            channel_creds = creds or grpc.local_channel_credentials()
            with grpc.secure_channel(f"localhost:{port}", channel_creds) as channel:
                yield channel.unary_unary(
                    f"/{service_name}/{method_name}",
                    request_serializer=request_type.SerializeToString,
                    response_deserializer=response_type.FromString,
                )
        finally:
            # Block until the server has fully shut down.
            server.stop(None).wait()
def _get_histogram_labels_value(method, logged_in, exception, code):
    """Return the ``_count`` sample value of the servicer duration histogram for one label set.

    Returns 0 when no sample carries the given ``method``/``logged_in``/``code``/``exception``
    labels.
    """

    def is_wanted(sample):
        return (
            sample.name == "couchers_servicer_duration_seconds_count"
            and sample.labels["method"] == method
            and sample.labels["logged_in"] == logged_in
            and sample.labels["code"] == code
            and sample.labels["exception"] == exception
        )

    families = servicer_duration_histogram.collect()
    histogram = [family for family in families if family.name == "couchers_servicer_duration_seconds"][0]
    matching = [sample for sample in histogram.samples if is_wanted(sample)]
    return matching[0].value if matching else 0
def test_logging_interceptor_ok():
    """A handler that returns normally passes through ErrorSanitizationInterceptor without error."""

    def TestRpc(request, context):
        return empty_pb2.Empty()

    sanitizer = ErrorSanitizationInterceptor()
    with interceptor_dummy_api(TestRpc, interceptors=[sanitizer]) as call_rpc:
        call_rpc(empty_pb2.Empty())
def test_logging_interceptor_all_ignored():
    """Explicit ``context.abort`` calls keep their status code and message through the interceptor.

    Every non-OK status code is aborted with a random message; the client must see
    exactly that code and message.
    """
    # Error codes that should not be touched by the interceptor.
    # grpc.StatusCode.OK is omitted because a handler cannot abort with OK.
    pass_through_status_codes = [
        grpc.StatusCode.CANCELLED,
        grpc.StatusCode.UNKNOWN,
        grpc.StatusCode.INVALID_ARGUMENT,
        grpc.StatusCode.DEADLINE_EXCEEDED,
        grpc.StatusCode.NOT_FOUND,
        grpc.StatusCode.ALREADY_EXISTS,
        grpc.StatusCode.PERMISSION_DENIED,
        grpc.StatusCode.UNAUTHENTICATED,
        grpc.StatusCode.RESOURCE_EXHAUSTED,
        grpc.StatusCode.FAILED_PRECONDITION,
        grpc.StatusCode.ABORTED,
        grpc.StatusCode.OUT_OF_RANGE,
        grpc.StatusCode.UNIMPLEMENTED,
        grpc.StatusCode.INTERNAL,
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.DATA_LOSS,
    ]

    for status_code in pass_through_status_codes:
        message = random_hex()

        # Bind the current iteration's values as keyword-only defaults so the handler
        # can never observe values from a later iteration (late-binding closure).
        def TestRpc(request, context, *, status_code=status_code, message=message):
            context.abort(status_code, message)

        with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
            with pytest.raises(grpc.RpcError) as e:
                call_rpc(empty_pb2.Empty())
            # The assertion message identifies which status code failed.
            assert e.value.code() == status_code, status_code
            assert e.value.details() == message, status_code
def test_logging_interceptor_assertion():
    """An AssertionError raised by the handler reaches the client as a generic INTERNAL error."""
    expected_details = "An unknown backend error occurred. Please consider filing a bug!"

    def TestRpc(request, context):
        raise AssertionError()

    with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        error = e.value
        assert error.code() == grpc.StatusCode.INTERNAL
        assert error.details() == expected_details
def test_logging_interceptor_div0():
    """A ZeroDivisionError in the handler reaches the client as a generic INTERNAL error."""
    expected_details = "An unknown backend error occurred. Please consider filing a bug!"

    def TestRpc(request, context):
        1 / 0  # noqa: B018

    with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        error = e.value
        assert error.code() == grpc.StatusCode.INTERNAL
        assert error.details() == expected_details
def test_logging_interceptor_raise():
    """A bare Exception raised by the handler reaches the client as a generic INTERNAL error."""
    expected_details = "An unknown backend error occurred. Please consider filing a bug!"

    def TestRpc(request, context):
        raise Exception()

    with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        error = e.value
        assert error.code() == grpc.StatusCode.INTERNAL
        assert error.details() == expected_details
def test_logging_interceptor_raise_custom():
    """A custom exception's message is replaced by the generic INTERNAL error details."""
    expected_details = "An unknown backend error occurred. Please consider filing a bug!"

    class _TestingException(Exception):
        pass

    def TestRpc(request, context):
        raise _TestingException("This is a custom exception")

    with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        error = e.value
        assert error.code() == grpc.StatusCode.INTERNAL
        assert error.details() == expected_details
def test_tracing_interceptor_ok_open(db):
    """A successful unauthenticated call produces one APICall row and one histogram observation."""
    method = "/org.couchers.auth.Auth/SignupFlow"
    count_before = _get_histogram_labels_value(method, "False", "", "")

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    with interceptor_dummy_api(TestRpc, interceptors=[CouchersMiddlewareInterceptor()]) as call_rpc:
        call_rpc(empty_pb2.Empty())

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert not trace.status_code
        assert not trace.user_id
        # Empty messages serialize to zero bytes, but must still be stored (not None).
        assert trace.request is not None
        assert len(trace.request) == 0
        assert trace.response is not None
        assert len(trace.response) == 0
        assert not trace.traceback

    count_after = _get_histogram_labels_value(method, "False", "", "")
    assert count_after == count_before + 1
def test_tracing_interceptor_sensitive(db):
    """Password fields are blank in the stored request/response while other fields are kept."""
    method = "/org.couchers.auth.Auth/SignupFlow"
    count_before = _get_histogram_labels_value(method, "False", "", "")

    def TestRpc(request, context, session):
        return auth_pb2.AuthReq(user="this is not secret", password="this is secret")

    signup_request = auth_pb2.SignupFlowReq(
        account=auth_pb2.SignupAccount(password="should be removed", username="not removed")
    )

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=auth_pb2.SignupFlowReq,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        call_rpc(signup_request)

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert not trace.status_code
        assert not trace.user_id
        assert not trace.traceback

        assert trace.request is not None
        stored_request = auth_pb2.SignupFlowReq.FromString(trace.request)
        assert not stored_request.account.password
        assert stored_request.account.username == "not removed"

        assert trace.response
        stored_response = auth_pb2.AuthReq.FromString(trace.response)
        assert stored_response.user == "this is not secret"
        assert not stored_response.password

    count_after = _get_histogram_labels_value(method, "False", "", "")
    assert count_after == count_before + 1
def test_tracing_interceptor_sensitive_ping(db):
    """Smoke test: an authenticated ``GetUser`` call goes through the middleware without raising."""
    user, token = generate_user()

    with interceptor_dummy_api(
        API().GetUser,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=api_pb2.GetUserReq,
        response_type=api_pb2.User,
        service_name="org.couchers.api.core.API",
        method_name="GetUser",
    ) as call_rpc:
        # Use the module's shared helper so the cookie format is defined in one place.
        call_rpc(api_pb2.GetUserReq(user=user.username), metadata=(cookie_auth(token),))
def test_tracing_interceptor_exception(db):
    """An unhandled handler exception is traced with its traceback and a sanitized request."""
    method = "/org.couchers.auth.Auth/SignupFlow"
    count_before = _get_histogram_labels_value(method, "False", "Exception", "")

    def TestRpc(request, context, session):
        raise Exception("Some error message")

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        with pytest.raises(Exception, match="Some error message"):
            call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert not trace.status_code
        assert not trace.user_id
        assert trace.traceback
        assert "Some error message" in trace.traceback

        assert trace.request is not None
        stored_request = auth_pb2.SignupAccount.FromString(trace.request)
        assert not stored_request.password
        assert stored_request.username == "not removed"
        assert not trace.response

    count_after = _get_histogram_labels_value(method, "False", "Exception", "")
    assert count_after == count_before + 1
def test_tracing_interceptor_abort(db):
    """A ``context.abort`` is traced with its status code, traceback and a sanitized request."""
    method = "/org.couchers.auth.Auth/SignupFlow"
    labels = (method, "False", "Exception", "FAILED_PRECONDITION")
    count_before = _get_histogram_labels_value(*labels)

    def TestRpc(request, context, session):
        context.abort(grpc.StatusCode.FAILED_PRECONDITION, "now a grpc abort")

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        with pytest.raises(Exception, match="now a grpc abort"):
            call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert trace.status_code == "FAILED_PRECONDITION"
        assert not trace.user_id
        assert trace.traceback
        assert "now a grpc abort" in trace.traceback

        assert trace.request is not None
        stored_request = auth_pb2.SignupAccount.FromString(trace.request)
        assert not stored_request.password
        assert stored_request.username == "not removed"
        assert not trace.response

    count_after = _get_histogram_labels_value(*labels)
    assert count_after == count_before + 1
def cookie_auth(token: str) -> tuple[str, str]:
    """Return the ``cookie`` metadata pair that carries *token* as the ``couchers-sesh`` cookie."""
    header_value = f"couchers-sesh={token}"
    return ("cookie", header_value)
def api_auth(token: str) -> tuple[str, str]:
    """Return the ``authorization`` metadata pair that carries *token* as a Bearer token."""
    header_value = f"Bearer {token}"
    return ("authorization", header_value)
def test_auth_interceptor(db):
    """End-to-end authentication checks for ``GetAccountInfo`` behind the middleware interceptor.

    Covers cookie and API-key auth (valid, wrong, combined, malformed), deleted users,
    expired sessions, a cookie presented for an API-key session, and the bookkeeping
    written per call (session ``api_calls``/``last_seen``, user ``last_active``, and
    ``UserActivity`` rows split by IP address and by time period).
    """
    _, super_token = generate_user(is_superuser=True)
    user, token = generate_user()
    _, deleted_token = generate_user(delete_user=True)

    with real_admin_session(super_token) as api:
        api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        api_key = session.execute(select(UserSession.token).where(UserSession.is_api_key)).scalar_one()

    rpc_def = {
        "rpc": Account().GetAccountInfo,
        "service_name": "org.couchers.api.account.Account",
        "method_name": "GetAccountInfo",
        "interceptors": [CouchersMiddlewareInterceptor()],
        "request_type": empty_pb2.Empty,
        "response_type": account_pb2.GetAccountInfoRes,
    }

    def call(metadata=None, creds=None):
        # One fresh server/channel per call, as every step of this test needs.
        channel_creds = grpc.local_channel_credentials() if creds is None else creds
        with interceptor_dummy_api(**rpc_def, creds=channel_creds) as call_rpc:
            return call_rpc(empty_pb2.Empty(), metadata=metadata)

    def expect_unauthenticated(metadata=None, details="Unauthorized"):
        with pytest.raises(grpc.RpcError) as e:
            call(metadata)
        assert e.value.code() == grpc.StatusCode.UNAUTHENTICATED
        assert e.value.details() == details

    def activity_calls_for_user():
        with session_scope() as session:
            return session.execute(
                select(UserActivity.api_calls).where(UserActivity.user_id == user.id)
            ).scalar_one()

    # no creds, no-go for secure APIs
    expect_unauthenticated()

    # can auth with cookie
    assert call((cookie_auth(token),)).username == user.username
    assert activity_calls_for_user() == 1

    # can't auth with a wrong cookie
    expect_unauthenticated((cookie_auth(random_hex(32)),))

    # can auth with an api key
    assert call((api_auth(api_key),)).username == user.username
    assert activity_calls_for_user() == 2

    # can't auth with a wrong api key
    expect_unauthenticated((api_auth(random_hex(32)),))

    # can auth with grpc helper (they do the same as above)
    comp_creds = grpc.composite_channel_credentials(
        grpc.local_channel_credentials(), grpc.access_token_call_credentials(api_key)
    )
    assert call(creds=comp_creds).username == user.username

    # can't auth with both
    expect_unauthenticated(
        (cookie_auth(token), api_auth(api_key)),
        details='Both "cookie" and "authorization" in request',
    )

    # malformed bearer (lower-case scheme)
    expect_unauthenticated((("authorization", f"bearer {api_key}"),))

    # Invisible (deleted) user
    expect_unauthenticated((cookie_auth(deleted_token),))

    # Invalid (expired) session
    long_ago = now() - timedelta(weeks=100)
    with session_scope() as session:
        session.execute(update(UserSession).values(last_seen=long_ago).where(UserSession.token == token))
    expect_unauthenticated((cookie_auth(token),))

    # API key token, but session is for session cookie (probably impossible, but...)
    with session_scope() as session:
        session.execute(update(UserSession).values(last_seen=now(), is_api_key=True).where(UserSession.token == token))
    expect_unauthenticated((cookie_auth(token),))

    # Check that metadata are updated
    six_minutes_ago = now() - timedelta(minutes=6)
    with session_scope() as session:
        # Return the session to normal
        user_session = session.execute(select(UserSession).where(UserSession.token == token)).scalar_one()
        user_session.is_api_key = False
        session_calls_before = user_session.api_calls

    assert call((cookie_auth(token),)).username == user.username

    with session_scope() as session:
        user_session = session.execute(select(UserSession).where(UserSession.token == token)).scalar_one()
        assert user_session.api_calls == session_calls_before + 1
        assert user_session.last_seen > now() - timedelta(seconds=1)

        # Simulate user inactivity, so last_active is updated on the next api call.
        session.execute(update(User).values(last_active=six_minutes_ago).where(User.id == user.id))

    # Check that last_active is updated if it wasn't updated in a while.
    call((cookie_auth(token),))

    with session_scope() as session:
        last_active = session.execute(select(User.last_active).where(User.id == user.id)).scalar_one()
        assert last_active > now() - timedelta(seconds=1)

    # Check that last_active is untouched (since it was already updated recently)
    call((cookie_auth(token),))

    with session_scope() as session:
        last_active_2 = session.execute(select(User.last_active).where(User.id == user.id)).scalar_one()
        assert last_active_2 == last_active

    # Check that activity is split by IP.
    call((cookie_auth(token), ("x-couchers-real-ip", "1.1.1.1")))

    with session_scope() as session:
        calls_from_ip = session.execute(
            select(UserActivity.api_calls).where(UserActivity.ip_address == "1.1.1.1")
        ).scalar_one()
        assert calls_from_ip == 1

    # Check that activity is split in time bins.
    # Update all UserActivity to be in the far past so that a new row is inserted on the next request.
    with session_scope() as session:
        session.execute(update(UserActivity).values(period=long_ago).where(UserActivity.user_id == user.id))

    call((cookie_auth(token),))

    with session_scope() as session:
        newest_row_calls = session.execute(
            select(UserActivity.api_calls)
            .where(UserActivity.user_id == user.id)
            .order_by(UserActivity.id.desc())
            .limit(1)
        ).scalar_one()
        assert newest_row_calls == 1
def test_tracing_interceptor_auth_cookies(db):
    """A cookie-authenticated call is traced with the user's id and ``is_api_key`` unset."""
    user, token = generate_user()

    account = Account()

    rpc_def = {
        "rpc": account.GetAccountInfo,
        "service_name": "org.couchers.api.account.Account",
        "method_name": "GetAccountInfo",
        "interceptors": [CouchersMiddlewareInterceptor()],
        "request_type": empty_pb2.Empty,
        "response_type": account_pb2.GetAccountInfoRes,
    }

    # with cookies (built by the shared helper so the cookie format lives in one place)
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        res1 = call_rpc(empty_pb2.Empty(), metadata=(cookie_auth(token),))
        assert res1.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert not trace.is_api_key
        assert trace.request is not None
        assert len(trace.request) == 0
        assert not trace.traceback
def test_tracing_interceptor_auth_api_key(db):
    """An API-key-authenticated call is traced with the user's id and ``is_api_key`` set."""
    # Only the superuser's token and the normal user's record are needed below.
    _, super_token = generate_user(is_superuser=True)
    user, _ = generate_user()

    with real_admin_session(super_token) as api:
        api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        api_key = session.execute(select(UserSession.token).where(UserSession.is_api_key)).scalar_one()

    account = Account()

    rpc_def = {
        "rpc": account.GetAccountInfo,
        "service_name": "org.couchers.api.account.Account",
        "method_name": "GetAccountInfo",
        "interceptors": [CouchersMiddlewareInterceptor()],
        "request_type": empty_pb2.Empty,
        "response_type": account_pb2.GetAccountInfoRes,
    }

    # with api key (built by the shared helper so the header format lives in one place)
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        res1 = call_rpc(empty_pb2.Empty(), metadata=(api_auth(api_key),))
        assert res1.username == user.username

    with session_scope() as session:
        # Filter by method: the admin CreateApiKey call above may also have been traced.
        trace = session.execute(
            select(APICall).where(APICall.method == "/org.couchers.api.account.Account/GetAccountInfo")
        ).scalar_one()
        assert trace.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert trace.is_api_key
        assert trace.request is not None
        assert len(trace.request) == 0
        assert not trace.traceback
def test_auth_levels(db):
    """Check every (user kind x service auth level) pair against the middleware interceptor.

    A dummy handler is registered under the name of a real RPC of each auth level, so
    only the interceptor's permission decision is exercised. Also covers a
    non-existent RPC and an RPC whose service has no auth level.
    """

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    def gen_args(service, method):
        return {
            "rpc": TestRpc,
            "service_name": service,
            "method_name": method,
            "interceptors": [CouchersMiddlewareInterceptor()],
            "request_type": empty_pb2.Empty,
            "response_type": empty_pb2.Empty,
        }

    # superuser (note: superusers are automatically editors due to DB constraint)
    _, super_token = generate_user(is_superuser=True)
    # editor user
    _, editor_token = generate_user(is_editor=True)
    # normal user
    _, normal_token = generate_user()
    # jailed user
    _, jailed_token = generate_user(accepted_tos=0)
    # open user
    open_token = ""

    # pick some rpcs here with the right auth levels
    open_args = gen_args("org.couchers.resources.Resources", "GetTermsOfService")
    jailed_args = gen_args("org.couchers.jail.Jail", "JailInfo")
    secure_args = gen_args("org.couchers.api.account.Account", "GetAccountInfo")
    editor_args = gen_args("org.couchers.editor.Editor", "CreateCommunity")
    admin_args = gen_args("org.couchers.admin.Admin", "GetUserDetails")

    # pairs to check
    checks = [
        # name, token, args, works?, code, message
        # (code and message are only checked for rows where works? is False)
        # open token only works on open servicers
        ("open x open", open_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x jailed", open_token, jailed_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x secure", open_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x editor", open_token, editor_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x admin", open_token, admin_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        # jailed works on jailed and open
        ("jailed x open", jailed_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("jailed x jailed", jailed_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("jailed x secure", jailed_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Permission denied"),
        ("jailed x editor", jailed_token, editor_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        ("jailed x admin", jailed_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # normal works on all but editor and admin
        ("normal x open", normal_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x jailed", normal_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x secure", normal_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x editor", normal_token, editor_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        ("normal x admin", normal_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # editor works on all but admin
        ("editor x open", editor_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("editor x jailed", editor_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("editor x secure", editor_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("editor x editor", editor_token, editor_args, True, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        ("editor x admin", editor_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # superuser works on all
        ("super x open", super_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x jailed", super_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x secure", super_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x editor", super_token, editor_args, True, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        ("super x admin", super_token, admin_args, True, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
    ]

    for name, token, args, should_work, code, message in checks:
        print(f"Testing (token x args) = ({name}), {should_work=}")
        # Built by the shared helper; an empty token yields an empty session cookie.
        metadata = (cookie_auth(token),)
        with interceptor_dummy_api(**args) as call_rpc:
            if should_work:
                call_rpc(empty_pb2.Empty(), metadata=metadata)
            else:
                with pytest.raises(grpc.RpcError) as err:
                    call_rpc(empty_pb2.Empty(), metadata=metadata)
                # The assertion message identifies the failing row of the table.
                assert err.value.code() == code, name
                assert err.value.details() == message, name

    # a non-existent RPC
    nonexistent = gen_args("org.couchers.nonexistent.NA", "GetNothing")

    with interceptor_dummy_api(**nonexistent) as call_rpc:
        with pytest.raises(grpc.RpcError) as err:
            call_rpc(empty_pb2.Empty())
        assert err.value.code() == grpc.StatusCode.UNIMPLEMENTED
        assert err.value.details() == "API call does not exist. Please refresh and try again."

    # an RPC without a service level
    invalid_args = gen_args("org.couchers.media.Media", "UploadConfirmation")

    with interceptor_dummy_api(**invalid_args) as call_rpc:
        with pytest.raises(grpc.RpcError) as err:
            call_rpc(empty_pb2.Empty())
        assert err.value.code() == grpc.StatusCode.INTERNAL
        assert err.value.details() == "Internal authentication error."
def test_parse_headers_with_session_cookie():
    """The ``couchers-sesh`` cookie is extracted as a non-API-key token."""
    parsed = parse_headers({"cookie": "couchers-sesh=abc123; other-cookie=value"})
    assert parsed.token == "abc123"
    assert parsed.is_api_key is False
def test_parse_headers_with_authorization_header():
    """A ``Bearer`` authorization header is extracted as an API-key token."""
    parsed = parse_headers({"authorization": "Bearer abc123"})
    assert parsed.token == "abc123"
    assert parsed.is_api_key is True
def test_parse_headers_with_both_cookie_and_authorization():
    """Supplying both a session cookie and an authorization header raises BadHeaders."""
    both = {"cookie": "couchers-sesh=abc123", "authorization": "Bearer xyz789"}
    expected_message = "Both cookies and authorization are present in headers"
    with pytest.raises(BadHeaders, match=expected_message):
        parse_headers(both)
def test_parse_headers_with_neither_cookie_nor_authorization():
    """Empty headers yield no token and ``is_api_key`` False."""
    parsed = parse_headers({})
    assert parsed.token is None
    assert parsed.is_api_key is False
def test_parse_headers_with_all_optional_headers():
    """Token, IP address, user agent, UI language and user id are all read from the headers."""
    parsed = parse_headers(
        {
            "cookie": "couchers-sesh=abc123; couchers-user-id=42; NEXT_LOCALE=en",
            "x-couchers-real-ip": "192.168.1.1",
            "user-agent": "TestAgent/1.0",
        }
    )
    assert parsed.token == "abc123"
    assert parsed.is_api_key is False
    assert parsed.ip_address == "192.168.1.1"
    assert parsed.user_agent == "TestAgent/1.0"
    assert parsed.ui_lang == "en"
    # The user id comes back as the raw cookie string.
    assert parsed.user_id == "42"
def test_parse_headers_with_bytes_ip_address():
    """A ``bytes`` value for ``x-couchers-real-ip`` results in ``ip_address`` being None."""
    raw_headers: dict[str, str | bytes] = {
        "cookie": "couchers-sesh=abc123",
        "x-couchers-real-ip": b"192.168.1.1",
    }
    parsed = parse_headers(raw_headers)
    assert parsed.ip_address is None
def test_parse_headers_with_bytes_user_agent():
    """A ``bytes`` value for ``user-agent`` results in ``user_agent`` being None."""
    raw_headers: dict[str, str | bytes] = {
        "cookie": "couchers-sesh=abc123",
        "user-agent": b"TestAgent/1.0",
    }
    parsed = parse_headers(raw_headers)
    assert parsed.user_agent is None
def test_parse_headers_malformed_authorization():
    """A lower-case ``bearer`` scheme yields no token but is still flagged as an API key."""
    parsed = parse_headers({"authorization": "bearer abc123"})
    assert parsed.token is None
    assert parsed.is_api_key is True
def test_find_auth_level_with_valid_service():
    """``GetUser`` on the core API service resolves to AUTH_LEVEL_SECURE."""
    level = find_auth_level(get_descriptor_pool(), "/org.couchers.api.core.API/GetUser")
    assert level == annotations_pb2.AUTH_LEVEL_SECURE
def test_find_auth_level_with_nonexistent_service():
    """An unknown service aborts with UNIMPLEMENTED and the nonexistent-call message."""
    descriptor_pool = get_descriptor_pool()

    with pytest.raises(AbortError) as exc:
        find_auth_level(descriptor_pool, "/org.couchers.nonexistent.Service/Method")

    abort = exc.value
    assert abort.msg == NONEXISTENT_API_CALL_ERROR_MESSAGE
    assert abort.code == grpc.StatusCode.UNIMPLEMENTED
def test_find_auth_level_with_unknown_auth_level():
    """A service whose auth-level option is AUTH_LEVEL_UNKNOWN aborts with INTERNAL."""
    # Mocked descriptor chain: pool -> service descriptor -> options -> Extensions.
    options = Mock()
    options.Extensions = {annotations_pb2.auth_level: annotations_pb2.AUTH_LEVEL_UNKNOWN}

    service = Mock(spec=ServiceDescriptor)
    service.GetOptions.return_value = options

    descriptor_pool = Mock(spec=DescriptorPool)
    descriptor_pool.FindServiceByName.return_value = service

    with pytest.raises(AbortError) as exc:
        find_auth_level(descriptor_pool, "/org.couchers.api.core.API/GetUser")

    abort = exc.value
    assert abort.msg == MISSING_AUTH_LEVEL_ERROR_MESSAGE
    assert abort.code == grpc.StatusCode.INTERNAL
def test_validate_auth_level_with_unknown():
    """AUTH_LEVEL_UNKNOWN is rejected with INTERNAL and the missing-auth-level message."""
    with pytest.raises(AbortError) as exc:
        validate_auth_level(annotations_pb2.AUTH_LEVEL_UNKNOWN)

    abort = exc.value
    assert abort.msg == MISSING_AUTH_LEVEL_ERROR_MESSAGE
    assert abort.code == grpc.StatusCode.INTERNAL
def test_validate_auth_level_with_open():
    """AUTH_LEVEL_OPEN is accepted without raising."""
    level = annotations_pb2.AUTH_LEVEL_OPEN
    validate_auth_level(level)
def test_validate_auth_level_with_jailed():
    """AUTH_LEVEL_JAILED is accepted without raising."""
    level = annotations_pb2.AUTH_LEVEL_JAILED
    validate_auth_level(level)
def test_validate_auth_level_with_secure():
    """AUTH_LEVEL_SECURE is accepted without raising."""
    level = annotations_pb2.AUTH_LEVEL_SECURE
    validate_auth_level(level)
def test_validate_auth_level_with_editor():
    """AUTH_LEVEL_EDITOR is accepted without raising."""
    level = annotations_pb2.AUTH_LEVEL_EDITOR
    validate_auth_level(level)
def test_validate_auth_level_with_admin():
    """AUTH_LEVEL_ADMIN is accepted without raising."""
    level = annotations_pb2.AUTH_LEVEL_ADMIN
    validate_auth_level(level)
def test_check_auth_open_service_without_auth():
    """An OPEN service is allowed when no auth info is supplied."""
    no_auth = None
    check_permissions(no_auth, annotations_pb2.AUTH_LEVEL_OPEN)
def test_check_auth_open_service_with_auth():
    """An OPEN-level service also accepts an ordinary authenticated caller."""
    caller = UserAuthInfo(
        # session details
        token="abc123",
        token_expiry=now(),
        is_api_key=False,
        # identity and preferences
        user_id=1,
        ui_language_preference="en",
        # privilege flags
        is_jailed=False,
        is_editor=False,
        is_superuser=False,
    )
    check_permissions(caller, annotations_pb2.AUTH_LEVEL_OPEN)
def test_check_auth_secure_service_without_auth():
    """A SECURE-level service rejects a caller with no auth info."""
    required_level = annotations_pb2.AUTH_LEVEL_SECURE
    with pytest.raises(AbortError):
        check_permissions(None, required_level)
def test_check_auth_secure_service_with_normal_auth():
    """A SECURE-level service accepts a non-jailed authenticated caller."""
    caller = UserAuthInfo(
        # session details
        token="abc123",
        token_expiry=now(),
        is_api_key=False,
        # identity and preferences
        user_id=1,
        ui_language_preference="en",
        # privilege flags
        is_jailed=False,
        is_editor=False,
        is_superuser=False,
    )
    check_permissions(caller, annotations_pb2.AUTH_LEVEL_SECURE)
def test_check_auth_secure_service_with_jailed_user():
    """A SECURE-level service rejects a jailed caller."""
    jailed_caller = UserAuthInfo(
        # session details
        token="abc123",
        token_expiry=now(),
        is_api_key=False,
        # identity and preferences
        user_id=1,
        ui_language_preference="en",
        # privilege flags: jailed is the property under test
        is_jailed=True,
        is_editor=False,
        is_superuser=False,
    )
    required_level = annotations_pb2.AUTH_LEVEL_SECURE
    with pytest.raises(AbortError):
        check_permissions(jailed_caller, required_level)
def test_check_auth_jailed_service_with_jailed_user():
    """A JAILED-level service accepts a jailed caller."""
    jailed_caller = UserAuthInfo(
        # session details
        token="abc123",
        token_expiry=now(),
        is_api_key=False,
        # identity and preferences
        user_id=1,
        ui_language_preference="en",
        # privilege flags: jailed is the property under test
        is_jailed=True,
        is_editor=False,
        is_superuser=False,
    )
    check_permissions(jailed_caller, annotations_pb2.AUTH_LEVEL_JAILED)
def test_check_auth_jailed_service_without_auth():
    """A JAILED-level service rejects a caller with no auth info."""
    required_level = annotations_pb2.AUTH_LEVEL_JAILED
    with pytest.raises(AbortError):
        check_permissions(None, required_level)
def test_check_auth_editor_service_without_editor():
    """An EDITOR-level service rejects an authenticated caller who is not an editor."""
    plain_caller = UserAuthInfo(
        # session details
        token="abc123",
        token_expiry=now(),
        is_api_key=False,
        # identity and preferences
        user_id=1,
        ui_language_preference="en",
        # privilege flags: no editor rights
        is_jailed=False,
        is_editor=False,
        is_superuser=False,
    )
    required_level = annotations_pb2.AUTH_LEVEL_EDITOR
    with pytest.raises(AbortError):
        check_permissions(plain_caller, required_level)
def test_check_auth_editor_service_with_editor():
    """An EDITOR-level service accepts a caller flagged as editor."""
    editor_caller = UserAuthInfo(
        # session details
        token="abc123",
        token_expiry=now(),
        is_api_key=False,
        # identity and preferences
        user_id=1,
        ui_language_preference="en",
        # privilege flags: editor but not superuser
        is_jailed=False,
        is_editor=True,
        is_superuser=False,
    )
    check_permissions(editor_caller, annotations_pb2.AUTH_LEVEL_EDITOR)
def test_check_auth_admin_service_without_superuser():
    """An ADMIN-level service rejects an editor who is not a superuser."""
    editor_caller = UserAuthInfo(
        # session details
        token="abc123",
        token_expiry=now(),
        is_api_key=False,
        # identity and preferences
        user_id=1,
        ui_language_preference="en",
        # privilege flags: editor rights alone
        is_jailed=False,
        is_editor=True,
        is_superuser=False,
    )
    required_level = annotations_pb2.AUTH_LEVEL_ADMIN
    with pytest.raises(AbortError):
        check_permissions(editor_caller, required_level)
def test_check_auth_admin_service_with_superuser():
    """An ADMIN-level service accepts a caller flagged as superuser."""
    superuser_caller = UserAuthInfo(
        # session details
        token="abc123",
        token_expiry=now(),
        is_api_key=False,
        # identity and preferences
        user_id=1,
        ui_language_preference="en",
        # privilege flags: both editor and superuser
        is_jailed=False,
        is_editor=True,
        is_superuser=True,
    )
    check_permissions(superuser_caller, annotations_pb2.AUTH_LEVEL_ADMIN)
def test_check_auth_admin_service_without_auth():
    """An ADMIN-level service rejects a caller with no auth info."""
    required_level = annotations_pb2.AUTH_LEVEL_ADMIN
    with pytest.raises(AbortError):
        check_permissions(None, required_level)
def test_parse_sofa_cookie_valid():
    """A freshly generated sofa cookie parses back to the value it was generated with."""
    expected_sofa, set_cookie = generate_sofa_cookie()

    # Strip the cookie name before the first "=" and any attributes after the first ";".
    after_name = set_cookie.split("=", 1)[1]
    raw_value = after_name.split(";")[0]

    parsed = parse_sofa_cookie({"cookie": f"sofa={raw_value}"})
    assert parsed == expected_sofa
def test_parse_sofa_cookie_missing():
    """A cookie header that carries only other cookies yields None."""
    request_headers = {"cookie": "other-cookie=value"}
    assert parse_sofa_cookie(request_headers) is None
def test_parse_sofa_cookie_no_cookies():
    """Headers with no cookie entry at all yield None."""
    request_headers = {}
    assert parse_sofa_cookie(request_headers) is None
def test_parse_sofa_cookie_invalid_base64():
    """A sofa cookie whose value is not valid base64 yields None."""
    request_headers = {"cookie": "sofa=not-valid-base64!!!"}
    assert parse_sofa_cookie(request_headers) is None
def test_parse_sofa_cookie_invalid_encryption():
    """Base64 that does not wrap a payload made by simple_encrypt yields None."""
    bogus_payload = b64encode(b"invalid encrypted data")
    request_headers = {"cookie": f"sofa={bogus_payload}"}
    assert parse_sofa_cookie(request_headers) is None
def test_parse_sofa_cookie_invalid_proto():
    """A correctly encrypted cookie with a non-proto payload must not crash parsing.

    The payload is encrypted under the "sofa_cookie" key name, so it gets past
    decryption and reaches the proto-decoding step with arbitrary bytes.
    """
    encrypted = simple_encrypt("sofa_cookie", b"not a valid proto")
    headers = {"cookie": f"sofa={b64encode(encrypted)}"}
    result = parse_sofa_cookie(headers)
    # The previous `result is not None or result is None` was a tautology and
    # could never fail. Whether these bytes decode as a proto is not pinned
    # here, so assert the return contract the other tests in this file show:
    # either no value or a string sofa value.
    # NOTE(review): tighten to `is None` if the parser is known to reject this payload.
    assert result is None or isinstance(result, str)
def test_generate_sofa_cookie():
    """generate_sofa_cookie returns a long string value and a cookie string that round-trips."""
    generated_value, set_cookie = generate_sofa_cookie()

    # The value itself: non-empty, a str, and longer than 20 characters.
    assert generated_value
    assert isinstance(generated_value, str)
    assert len(generated_value) > 20

    # The cookie string: names the sofa cookie and carries an expiry attribute.
    assert "sofa=" in set_cookie
    assert "expires=" in set_cookie.lower()

    # Round trip: the value extracted from the cookie string parses back to the original.
    after_name = set_cookie.split("=", 1)[1]
    raw_value = after_name.split(";")[0]
    round_tripped = parse_sofa_cookie({"cookie": f"sofa={raw_value}"})
    assert round_tripped == generated_value
def test_parse_headers_with_sofa_cookie():
    """parse_headers extracts both the session token and the sofa value from one cookie header."""
    expected_sofa, set_cookie = generate_sofa_cookie()
    after_name = set_cookie.split("=", 1)[1]
    raw_value = after_name.split(";")[0]

    request_headers = {"cookie": f"couchers-sesh=abc123; sofa={raw_value}"}
    parsed = parse_headers(request_headers)

    assert parsed.token == "abc123"
    assert parsed.sofa == expected_sofa
def test_parse_headers_without_sofa_cookie():
    """With only a session cookie present, parse_headers returns the token and sofa=None."""
    request_headers = {"cookie": "couchers-sesh=abc123"}
    parsed = parse_headers(request_headers)

    assert parsed.token == "abc123"
    assert parsed.sofa is None
def test_sofa_cookie_logged_new(db):
    """A call made without a sofa cookie still has a sofa value on its APICall row."""

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    middleware = [CouchersMiddlewareInterceptor()]
    with interceptor_dummy_api(TestRpc, interceptors=middleware) as call_rpc:
        call_rpc(empty_pb2.Empty())

    with session_scope() as session:
        # Exactly one APICall row is expected for the single call above.
        logged_call = session.execute(select(APICall)).scalar_one()
        assert logged_call.sofa is not None
        assert len(logged_call.sofa) > 20
def test_sofa_cookie_logged_existing(db):
    """A valid sofa cookie sent with the call is the value stored on the APICall row."""
    expected_sofa, set_cookie = generate_sofa_cookie()
    after_name = set_cookie.split("=", 1)[1]
    raw_value = after_name.split(";")[0]

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    request_metadata = (("cookie", f"sofa={raw_value}"),)
    middleware = [CouchersMiddlewareInterceptor()]
    with interceptor_dummy_api(TestRpc, interceptors=middleware) as call_rpc:
        call_rpc(empty_pb2.Empty(), metadata=request_metadata)

    with session_scope() as session:
        logged_call = session.execute(select(APICall)).scalar_one()
        assert logged_call.sofa == expected_sofa
def test_sofa_cookie_logged_invalid_generates_new(db):
    """An unparseable sofa cookie is not stored verbatim; a different, long value is logged."""

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    request_metadata = (("cookie", "sofa=invalid-cookie-value"),)
    middleware = [CouchersMiddlewareInterceptor()]
    with interceptor_dummy_api(TestRpc, interceptors=middleware) as call_rpc:
        call_rpc(empty_pb2.Empty(), metadata=request_metadata)

    with session_scope() as session:
        logged_call = session.execute(select(APICall)).scalar_one()
        assert logged_call.sofa is not None
        assert logged_call.sofa != "invalid-cookie-value"
        assert len(logged_call.sofa) > 20
def test_sofa_cookie_with_authenticated_user(db):
    """An authenticated call records both the user id and the supplied sofa value."""
    user, token = generate_user()
    expected_sofa, set_cookie = generate_sofa_cookie()
    after_name = set_cookie.split("=", 1)[1]
    raw_value = after_name.split(";")[0]

    servicer = Account()
    request_metadata = (("cookie", f"couchers-sesh={token}; sofa={raw_value}"),)

    # Serve the real Account.GetAccountInfo handler behind the middleware interceptor.
    with interceptor_dummy_api(
        rpc=servicer.GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
        creds=grpc.local_channel_credentials(),
    ) as call_rpc:
        response = call_rpc(empty_pb2.Empty(), metadata=request_metadata)
        assert response.username == user.username

    with session_scope() as session:
        logged_call = session.execute(select(APICall)).scalar_one()
        assert logged_call.user_id == user.id
        assert logged_call.sofa == expected_sofa
def test_sofa_cookie_persists_on_exception(db):
    """When the handler raises, the APICall row still holds the sofa value and a traceback."""
    expected_sofa, set_cookie = generate_sofa_cookie()
    after_name = set_cookie.split("=", 1)[1]
    raw_value = after_name.split(";")[0]

    def TestRpc(request, context, session):
        raise Exception("Test error")

    request_metadata = (("cookie", f"sofa={raw_value}"),)
    middleware = [CouchersMiddlewareInterceptor()]
    with interceptor_dummy_api(TestRpc, interceptors=middleware) as call_rpc:
        with pytest.raises(Exception, match="Test error"):
            call_rpc(empty_pb2.Empty(), metadata=request_metadata)

    with session_scope() as session:
        logged_call = session.execute(select(APICall)).scalar_one()
        assert logged_call.sofa == expected_sofa
        assert logged_call.traceback is not None
        assert "Test error" in logged_call.traceback