Coverage for app/backend/src/tests/test_interceptors.py: 99%
657 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-21 09:29 +0000
1from collections.abc import Callable, Generator
2from concurrent import futures
3from contextlib import contextmanager
4from datetime import timedelta
5from typing import Any
6from unittest.mock import Mock, patch
8import grpc
9import pytest
10from google.protobuf import empty_pb2
11from google.protobuf.descriptor import ServiceDescriptor
12from google.protobuf.descriptor_pool import DescriptorPool
13from sqlalchemy import select, text, update
15from couchers.constants import (
16 MISSING_AUTH_LEVEL_ERROR_MESSAGE,
17 NONEXISTENT_API_CALL_ERROR_MESSAGE,
18 UNKNOWN_ERROR_MESSAGE,
19)
20from couchers.crypto import b64encode, random_hex, simple_encrypt
21from couchers.db import session_scope
22from couchers.descriptor_pool import get_descriptor_pool
23from couchers.interceptors import (
24 AbortError,
25 BadHeaders,
26 CouchersMiddlewareInterceptor,
27 ErrorSanitizationInterceptor,
28 UserAuthInfo,
29 check_permissions,
30 find_auth_level,
31 parse_headers,
32 validate_auth_level,
33)
34from couchers.metrics import (
35 api_calls_counter,
36 servicer_db_query_count_histogram,
37 servicer_duration_histogram,
38 servicer_pool_wait_histogram,
39 servicer_serde_histogram,
40 servicer_setup_cpu_time_histogram,
41 servicer_setup_db_time_histogram,
42 servicer_setup_errors_counter,
43)
44from couchers.models import APICall, ClientPlatform, User, UserActivity, UserSession
45from couchers.proto import account_pb2, admin_pb2, annotations_pb2, api_pb2, auth_pb2
46from couchers.servicers.account import Account
47from couchers.servicers.api import API
48from couchers.utils import generate_sofa_cookie, now, parse_sofa_cookie
49from tests.fixtures.db import generate_user
50from tests.fixtures.sessions import real_admin_session
@pytest.fixture(autouse=True)
def _(testconfig):
    """Pull in the ``testconfig`` fixture for every test in this module (autouse); its value is not used here."""
@contextmanager
def interceptor_dummy_api(
    rpc,
    interceptors,
    service_name="org.couchers.auth.Auth",
    method_name="SignupFlow",
    request_type=empty_pb2.Empty,
    response_type=empty_pb2.Empty,
    creds=None,
) -> Generator[Callable[..., Any]]:
    """Serve ``rpc`` on a throwaway local gRPC server and yield a client callable for it.

    The handler is registered as the unary-unary method ``/{service_name}/{method_name}`` behind ``interceptors``,
    using ``request_type``/``response_type`` for (de)serialization. The yielded callable is a channel
    ``unary_unary`` multicallable: pass it a request message (and e.g. ``metadata=``) to get the decoded response.

    ``creds`` are the client channel credentials; when falsy, ``grpc.local_channel_credentials()`` is used.
    The server is stopped when the context exits, including when the body raises.
    """
    with futures.ThreadPoolExecutor(1) as pool:
        server = grpc.server(pool, interceptors=interceptors)
        bound_port = server.add_secure_port("localhost:0", grpc.local_server_credentials())

        # register the handler by hand instead of going through generated servicer code
        handler = grpc.unary_unary_rpc_method_handler(
            rpc,
            request_deserializer=request_type.FromString,
            response_serializer=response_type.SerializeToString,
        )
        server.add_generic_rpc_handlers(
            (grpc.method_handlers_generic_handler(service_name, {method_name: handler}),)
        )
        server.start()

        try:
            target = f"localhost:{bound_port}"
            with grpc.secure_channel(target, creds or grpc.local_channel_credentials()) as channel:
                yield channel.unary_unary(
                    f"/{service_name}/{method_name}",
                    request_serializer=request_type.SerializeToString,
                    response_deserializer=response_type.FromString,
                )
        finally:
            server.stop(None).wait()
def _get_histogram_labels_value(method, logged_in, exception, code):
    """Return how many observations ``couchers_servicer_duration_seconds`` holds for one label combination.

    Looks at the ``_count`` samples whose ``method``, ``logged_in``, ``code`` and ``exception`` labels all match
    and returns the first one's value, or 0 if no such sample exists yet.
    """
    family = [m for m in servicer_duration_histogram.collect() if m.name == "couchers_servicer_duration_seconds"][0]
    wanted = {"method": method, "logged_in": logged_in, "code": code, "exception": exception}
    counts = [
        sample.value
        for sample in family.samples
        if sample.name == "couchers_servicer_duration_seconds_count"
        and all(sample.labels[key] == value for key, value in wanted.items())
    ]
    return counts[0] if counts else 0
def _get_setup_errors_value(method, exception):
    """Return the ``couchers_servicer_setup_errors_total`` value for ``method``/``exception``, or 0 if unseen."""
    family = [m for m in servicer_setup_errors_counter.collect() if m.name == "couchers_servicer_setup_errors"][0]
    matching = [
        sample.value
        for sample in family.samples
        if sample.name == "couchers_servicer_setup_errors_total"
        and sample.labels["method"] == method
        and sample.labels["exception"] == exception
    ]
    return matching[0] if matching else 0
def test_logging_interceptor_ok():
    """A handler that returns normally passes through ErrorSanitizationInterceptor untouched."""

    def TestRpc(request, context):
        return empty_pb2.Empty()

    with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        # check that the response comes back intact, not just that no error is raised
        assert call_rpc(empty_pb2.Empty()) == empty_pb2.Empty()
def test_logging_interceptor_all_ignored():
    """Explicit aborts must reach the client with their status code and details untouched by the interceptor."""
    # StatusCode.OK is deliberately absent: context.abort cannot be called with OK
    pass_through_status_codes = [
        grpc.StatusCode.CANCELLED,
        grpc.StatusCode.UNKNOWN,
        grpc.StatusCode.INVALID_ARGUMENT,
        grpc.StatusCode.DEADLINE_EXCEEDED,
        grpc.StatusCode.NOT_FOUND,
        grpc.StatusCode.ALREADY_EXISTS,
        grpc.StatusCode.PERMISSION_DENIED,
        grpc.StatusCode.UNAUTHENTICATED,
        grpc.StatusCode.RESOURCE_EXHAUSTED,
        grpc.StatusCode.FAILED_PRECONDITION,
        grpc.StatusCode.ABORTED,
        grpc.StatusCode.OUT_OF_RANGE,
        grpc.StatusCode.UNIMPLEMENTED,
        grpc.StatusCode.INTERNAL,
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.DATA_LOSS,
    ]

    for expected_code in pass_through_status_codes:
        expected_details = random_hex()

        # bind this iteration's values as defaults so the handler doesn't close over the changing loop variables
        def TestRpc(request, context, code=expected_code, details=expected_details):
            context.abort(code, details)

        with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
            with pytest.raises(grpc.RpcError) as e:
                call_rpc(empty_pb2.Empty())
            assert e.value.code() == expected_code
            assert e.value.details() == expected_details
def test_logging_interceptor_assertion():
    """An AssertionError escaping the handler reaches the client as INTERNAL with the generic unknown-error text."""

    def TestRpc(request, context):
        raise AssertionError()

    with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.INTERNAL
        # compare against the shared constant rather than a pasted copy of its text
        assert e.value.details() == UNKNOWN_ERROR_MESSAGE
def test_logging_interceptor_div0():
    """A ZeroDivisionError in the handler reaches the client as INTERNAL with the generic unknown-error text."""

    def TestRpc(request, context):
        1 / 0  # noqa: B018

    with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.INTERNAL
        # compare against the shared constant rather than a pasted copy of its text
        assert e.value.details() == UNKNOWN_ERROR_MESSAGE
def test_logging_interceptor_raise():
    """A bare Exception from the handler reaches the client as INTERNAL with the generic unknown-error text."""

    def TestRpc(request, context):
        raise Exception()

    with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.INTERNAL
        # compare against the shared constant rather than a pasted copy of its text
        assert e.value.details() == UNKNOWN_ERROR_MESSAGE
def test_logging_interceptor_raise_custom():
    """A custom exception's message is not leaked: the client sees INTERNAL with the generic unknown-error text."""

    class _TestingException(Exception):
        pass

    def TestRpc(request, context):
        raise _TestingException("This is a custom exception")

    with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.INTERNAL
        # compare against the shared constant rather than a pasted copy of its text
        assert e.value.details() == UNKNOWN_ERROR_MESSAGE
def test_tracing_interceptor_ok_open(db):
    """A successful logged-out call produces one APICall row and one duration-histogram observation."""
    method = "/org.couchers.auth.Auth/SignupFlow"
    count_before = _get_histogram_labels_value(method, "False", "", "")

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    with interceptor_dummy_api(TestRpc, interceptors=[CouchersMiddlewareInterceptor()]) as call_rpc:
        call_rpc(empty_pb2.Empty())

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert not trace.status_code
        assert not trace.user_id
        assert not trace.traceback
        # an Empty message serializes to zero bytes, but both columns are still populated
        assert trace.request is not None
        assert len(trace.request) == 0
        assert trace.response is not None
        assert len(trace.response) == 0

    assert _get_histogram_labels_value(method, "False", "", "") == count_before + 1
def _get_db_query_count_histogram(method):
    """Return the number of observations in the per-request DB query-count histogram for ``method`` (0 if none)."""
    return _get_histogram_count(
        servicer_db_query_count_histogram, "couchers_servicer_db_query_count_count", method=method
    )
def _get_api_call_count(method, platform):
    """Return ``couchers_api_calls_total`` summed over samples labelled with ``method`` and ``platform`` (0 if none)."""
    return _get_histogram_count(api_calls_counter, "couchers_api_calls_total", method=method, platform=platform)
def test_tracing_interceptor_perf_accounting(db):
    """Per-request DB/CPU accounting is stored on the APICall row and observed into Prometheus metrics."""
    method = "/org.couchers.auth.Auth/SignupFlow"
    platform = "web_mobile"
    histogram_before = _get_db_query_count_histogram(method)
    calls_before = _get_api_call_count(method, platform)

    # the handler issues exactly four statements: three reads and one compiled write; the write targets id -1, so it
    # matches zero rows and has no side effects
    def TestRpc(request, context, session):
        session.execute(text("SELECT 1"))
        session.execute(text("SELECT 1"))
        session.execute(text("SELECT 1"))
        session.execute(update(APICall).where(APICall.id == -1).values(method="x"))
        return empty_pb2.Empty()

    with interceptor_dummy_api(TestRpc, interceptors=[CouchersMiddlewareInterceptor()]) as call_rpc:
        call_rpc(empty_pb2.Empty(), metadata=(("x-couchers-client-platform", platform),))

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.db_query_count == 4
        assert trace.db_write_query_count == 1
        assert trace.db_time_ms is not None
        assert trace.db_time_ms >= 0
        assert trace.cpu_ms is not None
        assert trace.cpu_ms >= 0
        # time spent in the DB during the handler cannot exceed the wall time of the whole request
        assert trace.db_time_ms <= trace.duration
        assert trace.client_platform == ClientPlatform.web_mobile

    # one observation in the query-count histogram and one tick of the per-platform call counter
    assert _get_db_query_count_histogram(method) == histogram_before + 1
    assert _get_api_call_count(method, platform) == calls_before + 1
292def _get_histogram_count(histogram, count_name, **labels):
293 return sum(
294 s.value
295 for m in histogram.collect()
296 for s in m.samples
297 if s.name == count_name and all(s.labels.get(k) == v for k, v in labels.items())
298 )
def test_tracing_interceptor_phase_histograms(db):
    """Setup DB/CPU time, pool wait, and de/serialization are each observed exactly once per call."""
    method = "/org.couchers.auth.Auth/SignupFlow"
    # (histogram, count-sample name, extra labels) for every per-phase histogram under test
    phases = [
        (servicer_setup_db_time_histogram, "couchers_servicer_setup_db_time_seconds_count", {}),
        (servicer_setup_cpu_time_histogram, "couchers_servicer_setup_cpu_seconds_count", {}),
        (servicer_pool_wait_histogram, "couchers_servicer_pool_wait_seconds_count", {}),
        (servicer_serde_histogram, "couchers_servicer_serde_seconds_count", {"direction": "deserialize"}),
        (servicer_serde_histogram, "couchers_servicer_serde_seconds_count", {"direction": "serialize"}),
    ]

    def phase_counts():
        return [_get_histogram_count(hist, name, method=method, **extra) for hist, name, extra in phases]

    counts_before = phase_counts()

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    with interceptor_dummy_api(TestRpc, interceptors=[CouchersMiddlewareInterceptor()]) as call_rpc:
        call_rpc(empty_pb2.Empty())

    assert phase_counts() == [count + 1 for count in counts_before]
def test_tracing_interceptor_perf_accounting_orm_write(db):
    """A write that only exists as a pending ORM object is still counted.

    The handler just ``session.add``s a row and returns, so its INSERT would only flush at commit, after
    ``read_perf()``; without the interceptor's explicit flush it would be missing from the query/write counts.
    """
    traced_method = "/org.couchers.auth.Auth/SignupFlow"

    def TestRpc(request, context, session):
        pending = APICall(method="handler-insert", duration=0.0, is_api_key=False, response_truncated=False)
        session.add(pending)
        return empty_pb2.Empty()

    with interceptor_dummy_api(TestRpc, interceptors=[CouchersMiddlewareInterceptor()]) as call_rpc:
        call_rpc(empty_pb2.Empty())

    with session_scope() as session:
        # select the middleware's trace row, not the row the handler inserted
        trace = session.execute(select(APICall).where(APICall.method == traced_method)).scalar_one()
        assert trace.db_query_count == 1
        assert trace.db_write_query_count == 1
def test_tracing_interceptor_sensitive(db):
    """Passwords are stripped from the request and response stored on the APICall row; other fields are kept."""
    method = "/org.couchers.auth.Auth/SignupFlow"
    count_before = _get_histogram_labels_value(method, "False", "", "")

    def TestRpc(request, context, session):
        return auth_pb2.AuthReq(user="this is not secret", password="this is secret")

    signup = auth_pb2.SignupFlowReq(
        account=auth_pb2.SignupAccount(password="should be removed", username="not removed")
    )

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=auth_pb2.SignupFlowReq,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        call_rpc(signup)

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert not trace.status_code
        assert not trace.user_id
        assert not trace.traceback

        assert trace.request is not None
        logged_req = auth_pb2.SignupFlowReq.FromString(trace.request)
        assert not logged_req.account.password
        assert logged_req.account.username == "not removed"

        assert trace.response
        logged_res = auth_pb2.AuthReq.FromString(trace.response)
        assert logged_res.user == "this is not secret"
        assert not logged_res.password

    assert _get_histogram_labels_value(method, "False", "", "") == count_before + 1
def test_tracing_interceptor_sensitive_ping(db):
    """A cookie-authenticated GetUser call succeeds through the middleware and returns the requested user."""
    user, token = generate_user()

    with interceptor_dummy_api(
        API().GetUser,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=api_pb2.GetUserReq,
        response_type=api_pb2.User,
        service_name="org.couchers.api.core.API",
        method_name="GetUser",
    ) as call_rpc:
        res = call_rpc(api_pb2.GetUserReq(user=user.username), metadata=(cookie_auth(token),))

    # beyond the call not raising, check that the right user came back
    assert res.username == user.username
def test_tracing_interceptor_exception(db):
    """An unhandled handler exception is traced with its traceback and a password-free copy of the request."""
    method = "/org.couchers.auth.Auth/SignupFlow"
    count_before = _get_histogram_labels_value(method, "False", "Exception", "")

    def TestRpc(request, context, session):
        raise Exception("Some error message")

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        with pytest.raises(Exception, match="Some error message"):
            call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        # not an abort, so no status code is recorded
        assert not trace.status_code
        assert not trace.user_id
        assert not trace.response
        assert trace.traceback
        assert "Some error message" in trace.traceback
        assert trace.request is not None
        logged_req = auth_pb2.SignupAccount.FromString(trace.request)
        assert not logged_req.password
        assert logged_req.username == "not removed"

    assert _get_histogram_labels_value(method, "False", "Exception", "") == count_before + 1
def test_setup_phase_exception_observed(db):
    """A failure in the middleware's setup phase is sanitized, sent to Sentry, and counted as a setup error.

    ``LocalizationContext`` is patched to raise ``ValueError`` to trigger the failure.
    """
    method = "/org.couchers.auth.Auth/SignupFlow"
    errors_before = _get_setup_errors_value(method, "ValueError")

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    failing_localization = patch(
        "couchers.interceptors.LocalizationContext", side_effect=ValueError("expected only letters")
    )
    with (
        failing_localization,
        patch("couchers.interceptors.sentry_sdk") as sentry,
        interceptor_dummy_api(TestRpc, interceptors=[CouchersMiddlewareInterceptor()]) as call_rpc,
    ):
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.INTERNAL
        assert e.value.details() == UNKNOWN_ERROR_MESSAGE
        sentry.capture_exception.assert_called_once()

    assert _get_setup_errors_value(method, "ValueError") == errors_before + 1
def test_tracing_interceptor_abort(db):
    """A gRPC abort is traced with its status code and message, and the stored request has no password."""
    method = "/org.couchers.auth.Auth/SignupFlow"
    histogram_labels = (method, "False", "Exception", "FAILED_PRECONDITION")
    count_before = _get_histogram_labels_value(*histogram_labels)

    def TestRpc(request, context, session):
        context.abort(grpc.StatusCode.FAILED_PRECONDITION, "now a grpc abort")

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        with pytest.raises(Exception, match="now a grpc abort"):
            call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert trace.status_code == "FAILED_PRECONDITION"
        assert not trace.user_id
        assert not trace.response
        # the abort surfaces as an exception in the middleware, so its message ends up in the traceback
        assert trace.traceback
        assert "now a grpc abort" in trace.traceback
        assert trace.request is not None
        logged_req = auth_pb2.SignupAccount.FromString(trace.request)
        assert not logged_req.password
        assert logged_req.username == "not removed"

    assert _get_histogram_labels_value(*histogram_labels) == count_before + 1
def cookie_auth(token: str) -> tuple[str, str]:
    """Build the gRPC metadata pair that authenticates through the ``couchers-sesh`` session cookie."""
    header_value = f"couchers-sesh={token}"
    return ("cookie", header_value)
def api_auth(token: str) -> tuple[str, str]:
    """Build the gRPC metadata pair that authenticates with an API key sent as a ``Bearer`` token."""
    header_value = f"Bearer {token}"
    return ("authorization", header_value)
def test_auth_interceptor(db):
    """End-to-end checks of how the middleware authenticates a secure RPC (Account.GetAccountInfo).

    Covers cookie and API-key auth, rejection of missing/wrong/duplicate/malformed credentials, deleted users,
    expired or mismatched sessions, and the bookkeeping done on successful calls: session ``api_calls`` and
    ``last_seen``, user ``last_active``, and ``UserActivity`` rows split by IP and by time period.
    """
    _, super_token = generate_user(is_superuser=True)
    user, token = generate_user()
    _, deleted_token = generate_user(delete_user=True)

    with real_admin_session(super_token) as api:
        api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        api_key = session.execute(select(UserSession.token).where(UserSession.is_api_key)).scalar_one()

    rpc_def = {
        "rpc": Account().GetAccountInfo,
        "service_name": "org.couchers.api.account.Account",
        "method_name": "GetAccountInfo",
        "interceptors": [CouchersMiddlewareInterceptor()],
        "request_type": empty_pb2.Empty,
        "response_type": account_pb2.GetAccountInfoRes,
    }

    def call(*metadata, creds=None):
        """Make one call on a fresh server with the given metadata pairs and return the response."""
        channel_creds = creds if creds is not None else grpc.local_channel_credentials()
        with interceptor_dummy_api(**rpc_def, creds=channel_creds) as call_rpc:
            return call_rpc(empty_pb2.Empty(), metadata=metadata or None)

    def expect_rejected(*metadata, details="Unauthorized"):
        """Assert that a call with the given metadata pairs fails as UNAUTHENTICATED with ``details``."""
        with pytest.raises(grpc.RpcError) as e:
            call(*metadata)
        assert e.value.code() == grpc.StatusCode.UNAUTHENTICATED
        assert e.value.details() == details

    def activity_api_calls(criterion):
        """Return ``UserActivity.api_calls`` of the single row matching ``criterion``."""
        with session_scope() as session:
            return session.execute(select(UserActivity.api_calls).where(criterion)).scalar_one()

    # no auth credentials at all: secure APIs refuse the call
    expect_rejected()

    # a valid session cookie works and is counted as user activity
    assert call(cookie_auth(token)).username == user.username
    assert activity_api_calls(UserActivity.user_id == user.id) == 1

    # a random cookie is refused
    expect_rejected(cookie_auth(random_hex(32)))

    # a valid API key works and is counted too
    assert call(api_auth(api_key)).username == user.username
    assert activity_api_calls(UserActivity.user_id == user.id) == 2

    # a random API key is refused
    expect_rejected(api_auth(random_hex(32)))

    # grpc's access-token call credentials do the same as sending the bearer header by hand
    token_creds = grpc.composite_channel_credentials(
        grpc.local_channel_credentials(), grpc.access_token_call_credentials(api_key)
    )
    assert call(creds=token_creds).username == user.username

    # sending both a cookie and an API key is refused
    expect_rejected(cookie_auth(token), api_auth(api_key), details='Both "cookie" and "authorization" in request')

    # a malformed bearer header (lowercase "bearer") is refused
    expect_rejected(("authorization", f"bearer {api_key}"))

    # sessions belonging to an invisible (deleted) user are refused
    expect_rejected(cookie_auth(deleted_token))

    # an expired session (last seen 100 weeks ago) is refused
    long_ago = now() - timedelta(weeks=100)
    with session_scope() as session:
        session.execute(update(UserSession).values(last_seen=long_ago).where(UserSession.token == token))
    expect_rejected(cookie_auth(token))

    # a cookie whose session is flagged as an API key is refused (probably impossible, but...)
    with session_scope() as session:
        session.execute(update(UserSession).values(last_seen=now(), is_api_key=True).where(UserSession.token == token))
    expect_rejected(cookie_auth(token))

    # put the session back to normal, then check that a successful call updates its metadata
    six_minutes_ago = now() - timedelta(minutes=6)
    with session_scope() as session:
        user_session = session.execute(select(UserSession).where(UserSession.token == token)).scalar_one()
        user_session.is_api_key = False
        api_calls_before = user_session.api_calls

    assert call(cookie_auth(token)).username == user.username

    with session_scope() as session:
        user_session = session.execute(select(UserSession).where(UserSession.token == token)).scalar_one()
        assert user_session.api_calls == api_calls_before + 1
        assert user_session.last_seen > now() - timedelta(seconds=1)

        # pretend the user has been inactive so that the next call refreshes last_active
        session.execute(update(User).values(last_active=six_minutes_ago).where(User.id == user.id))

    # last_active is refreshed when it is stale...
    call(cookie_auth(token))
    with session_scope() as session:
        last_active = session.execute(select(User.last_active).where(User.id == user.id)).scalar_one()
        assert last_active > now() - timedelta(seconds=1)

    # ...and left alone when it was updated recently
    call(cookie_auth(token))
    with session_scope() as session:
        refreshed = session.execute(select(User.last_active).where(User.id == user.id)).scalar_one()
        assert refreshed == last_active

    # activity is recorded separately per IP address
    call(cookie_auth(token), ("x-couchers-real-ip", "1.1.1.1"))
    assert activity_api_calls(UserActivity.ip_address == "1.1.1.1") == 1

    # activity is also split into time bins: move every existing row far into the past so the next call starts a
    # new row
    with session_scope() as session:
        session.execute(update(UserActivity).values(period=long_ago).where(UserActivity.user_id == user.id))

    call(cookie_auth(token))

    with session_scope() as session:
        latest_api_calls = session.execute(
            select(UserActivity.api_calls)
            .where(UserActivity.user_id == user.id)
            .order_by(UserActivity.id.desc())
            .limit(1)
        ).scalar_one()
        assert latest_api_calls == 1
def test_tracing_interceptor_auth_cookies(db):
    """A cookie-authenticated call is traced with the caller's user id and without the API-key flag."""
    user, token = generate_user()

    with interceptor_dummy_api(
        Account().GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
        creds=grpc.local_channel_credentials(),
    ) as call_rpc:
        res = call_rpc(empty_pb2.Empty(), metadata=(cookie_auth(token),))
        assert res.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert trace.user_id == user.id
        assert not trace.is_api_key
        assert not trace.status_code
        assert not trace.traceback
        assert trace.request is not None
        assert len(trace.request) == 0
def test_tracing_interceptor_auth_api_key(db):
    """An API-key-authenticated call is traced with the key owner's user id and the API-key flag set."""
    _, super_token = generate_user(is_superuser=True)
    user, _ = generate_user()

    with real_admin_session(super_token) as api:
        api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        api_key = session.execute(select(UserSession.token).where(UserSession.is_api_key)).scalar_one()

    traced_method = "/org.couchers.api.account.Account/GetAccountInfo"

    with interceptor_dummy_api(
        Account().GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
        creds=grpc.local_channel_credentials(),
    ) as call_rpc:
        res = call_rpc(empty_pb2.Empty(), metadata=(api_auth(api_key),))
        assert res.username == user.username

    with session_scope() as session:
        # filter by method so that trace rows from the admin call above are ignored
        trace = session.execute(select(APICall).where(APICall.method == traced_method)).scalar_one()
        assert trace.method == traced_method
        assert trace.user_id == user.id
        assert trace.is_api_key
        assert not trace.status_code
        assert not trace.traceback
        assert trace.request is not None
        assert len(trace.request) == 0
760def test_auth_levels(db):
761 def TestRpc(request, context, session):
762 return empty_pb2.Empty()
764 def gen_args(service, method):
765 return {
766 "rpc": TestRpc,
767 "service_name": service,
768 "method_name": method,
769 "interceptors": [CouchersMiddlewareInterceptor()],
770 "request_type": empty_pb2.Empty,
771 "response_type": empty_pb2.Empty,
772 }
774 # superuser (note: superusers are automatically editors due to DB constraint)
775 _, super_token = generate_user(is_superuser=True)
776 # editor user
777 _, editor_token = generate_user(is_editor=True)
778 # normal user
779 _, normal_token = generate_user()
780 # jailed user
781 _, jailed_token = generate_user(accepted_tos=0)
782 # open user
783 open_token = ""
785 # pick some rpcs here with the right auth levels
786 open_args = gen_args("org.couchers.resources.Resources", "GetTermsOfService")
787 jailed_args = gen_args("org.couchers.jail.Jail", "JailInfo")
788 secure_args = gen_args("org.couchers.api.account.Account", "GetAccountInfo")
789 editor_args = gen_args("org.couchers.editor.Editor", "CreateCommunity")
790 admin_args = gen_args("org.couchers.admin.Admin", "GetUserDetails")
792 # pairs to check
793 checks = [
794 # name, args, token, works?, code, message
795 # open token only works on open servicers
796 ("open x open", open_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
797 ("open x jailed", open_token, jailed_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
798 ("open x secure", open_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
799 ("open x editor", open_token, editor_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
800 ("open x admin", open_token, admin_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
801 # jailed works on jailed and open
802 ("jailed x open", jailed_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
803 ("jailed x jailed", jailed_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
804 ("jailed x secure", jailed_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Permission denied"),
805 ("jailed x editor", jailed_token, editor_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
806 ("jailed x admin", jailed_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
807 # normal works on all but editor and admin
808 ("normal x open", normal_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
809 ("normal x jailed", normal_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
810 ("normal x secure", normal_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
811 ("normal x editor", normal_token, editor_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
812 ("normal x admin", normal_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
813 # editor works on all but admin
814 ("editor x open", editor_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
815 ("editor x jailed", editor_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
816 ("editor x secure", editor_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
817 ("editor x editor", editor_token, editor_args, True, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
818 ("editor x admin", editor_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
819 # superuser works on all
820 ("super x open", super_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
821 ("super x jailed", super_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
822 ("super x secure", super_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
823 ("super x editor", super_token, editor_args, True, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
824 ("super x admin", super_token, admin_args, True, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
825 ]
827 for name, token, args, should_work, code, message in checks:
828 print(f"Testing (token x args) = ({name}), {should_work=}")
829 metadata = (("cookie", f"couchers-sesh={token}"),)
830 with interceptor_dummy_api(**args) as call_rpc:
831 if should_work:
832 call_rpc(empty_pb2.Empty(), metadata=metadata)
833 else:
834 with pytest.raises(grpc.RpcError) as err:
835 call_rpc(empty_pb2.Empty(), metadata=metadata)
836 assert err.value.code() == code
837 assert err.value.details() == message
839 # a non-existent RPC
840 nonexistent = gen_args("org.couchers.nonexistent.NA", "GetNothing")
842 with interceptor_dummy_api(**nonexistent) as call_rpc:
843 with pytest.raises(grpc.RpcError) as err:
844 call_rpc(empty_pb2.Empty())
845 assert err.value.code() == grpc.StatusCode.UNIMPLEMENTED
846 assert err.value.details() == "API call does not exist. Please refresh and try again."
848 # an RPC without a service level
849 invalid_args = gen_args("org.couchers.media.Media", "UploadConfirmation")
851 with interceptor_dummy_api(**invalid_args) as call_rpc:
852 with pytest.raises(grpc.RpcError) as err:
853 call_rpc(empty_pb2.Empty())
854 assert err.value.code() == grpc.StatusCode.INTERNAL
855 assert err.value.details() == "Internal authentication error."
def test_parse_headers_with_session_cookie():
    """The couchers-sesh cookie is extracted from a multi-cookie header as a session (non-API-key) token."""
    parsed = parse_headers({"cookie": "couchers-sesh=abc123; other-cookie=value"})

    assert parsed.token == "abc123"
    assert parsed.is_api_key is False
def test_parse_headers_with_authorization_header():
    """A "Bearer" authorization header yields the token and flags it as an API key."""
    parsed = parse_headers({"authorization": "Bearer abc123"})

    assert parsed.token == "abc123"
    assert parsed.is_api_key is True
def test_parse_headers_with_both_cookie_and_authorization():
    """Supplying a session cookie and an authorization header together is rejected with BadHeaders."""
    conflicting = {
        "cookie": "couchers-sesh=abc123",
        "authorization": "Bearer xyz789",
    }

    with pytest.raises(BadHeaders, match="Both cookies and authorization are present in headers"):
        parse_headers(conflicting)
def test_parse_headers_with_neither_cookie_nor_authorization():
    """With no credentials at all, no token is produced and it is not treated as an API key."""
    parsed = parse_headers({})

    assert parsed.token is None
    assert parsed.is_api_key is False
def test_parse_headers_with_all_optional_headers():
    """Every optional header/cookie (real IP, user agent, locale, user id) is surfaced on the result."""
    cookie_header = "couchers-sesh=abc123; couchers-user-id=42; NEXT_LOCALE=en"
    parsed = parse_headers(
        {
            "cookie": cookie_header,
            "x-couchers-real-ip": "192.168.1.1",
            "user-agent": "TestAgent/1.0",
        }
    )

    # credentials
    assert parsed.token == "abc123"
    assert parsed.is_api_key is False
    # request metadata
    assert parsed.ip_address == "192.168.1.1"
    assert parsed.user_agent == "TestAgent/1.0"
    # cookie-derived extras; note the user id stays a string here
    assert parsed.ui_lang == "en"
    assert parsed.user_id == "42"
def test_parse_headers_with_bytes_ip_address():
    """A bytes-valued x-couchers-real-ip header is ignored rather than decoded."""
    raw_headers: dict[str, str | bytes] = {
        "cookie": "couchers-sesh=abc123",
        "x-couchers-real-ip": b"192.168.1.1",
    }

    assert parse_headers(raw_headers).ip_address is None
def test_parse_headers_with_bytes_user_agent():
    """A bytes-valued user-agent header is ignored rather than decoded."""
    raw_headers: dict[str, str | bytes] = {
        "cookie": "couchers-sesh=abc123",
        "user-agent": b"TestAgent/1.0",
    }

    assert parse_headers(raw_headers).user_agent is None
def test_parse_headers_malformed_authorization():
    """A lowercase "bearer" scheme yields no token, yet the request is still flagged as API-key style."""
    parsed = parse_headers({"authorization": "bearer abc123"})

    assert parsed.token is None
    assert parsed.is_api_key is True
def test_find_auth_level_with_valid_service():
    """The real descriptor pool reports API.GetUser as a SECURE-level call."""
    level = find_auth_level(get_descriptor_pool(), "/org.couchers.api.core.API/GetUser")

    assert level == annotations_pb2.AUTH_LEVEL_SECURE
def test_find_auth_level_with_nonexistent_service():
    """An unknown service path aborts with UNIMPLEMENTED and the "nonexistent API call" message."""
    descriptor_pool = get_descriptor_pool()

    with pytest.raises(AbortError) as raised:
        find_auth_level(descriptor_pool, "/org.couchers.nonexistent.Service/Method")

    error = raised.value
    assert error.msg == NONEXISTENT_API_CALL_ERROR_MESSAGE
    assert error.code == grpc.StatusCode.UNIMPLEMENTED
def test_find_auth_level_with_unknown_auth_level():
    """A service whose auth_level option is AUTH_LEVEL_UNKNOWN aborts with INTERNAL."""
    # Build a fake pool -> service descriptor -> options chain whose auth_level extension is UNKNOWN.
    options = Mock()
    options.Extensions = {annotations_pb2.auth_level: annotations_pb2.AUTH_LEVEL_UNKNOWN}

    fake_service = Mock(spec=ServiceDescriptor)
    fake_service.GetOptions.return_value = options

    fake_pool = Mock(spec=DescriptorPool)
    fake_pool.FindServiceByName.return_value = fake_service

    with pytest.raises(AbortError) as raised:
        find_auth_level(fake_pool, "/org.couchers.api.core.API/GetUser")

    error = raised.value
    assert error.msg == MISSING_AUTH_LEVEL_ERROR_MESSAGE
    assert error.code == grpc.StatusCode.INTERNAL
def test_validate_auth_level_with_unknown():
    """AUTH_LEVEL_UNKNOWN is not a valid level and aborts with INTERNAL."""
    with pytest.raises(AbortError) as raised:
        validate_auth_level(annotations_pb2.AUTH_LEVEL_UNKNOWN)

    error = raised.value
    assert error.msg == MISSING_AUTH_LEVEL_ERROR_MESSAGE
    assert error.code == grpc.StatusCode.INTERNAL
def test_validate_auth_level_with_open():
    """AUTH_LEVEL_OPEN is accepted (no exception)."""
    level = annotations_pb2.AUTH_LEVEL_OPEN
    validate_auth_level(level)
def test_validate_auth_level_with_jailed():
    """AUTH_LEVEL_JAILED is accepted (no exception)."""
    level = annotations_pb2.AUTH_LEVEL_JAILED
    validate_auth_level(level)
def test_validate_auth_level_with_secure():
    """AUTH_LEVEL_SECURE is accepted (no exception)."""
    level = annotations_pb2.AUTH_LEVEL_SECURE
    validate_auth_level(level)
def test_validate_auth_level_with_editor():
    """AUTH_LEVEL_EDITOR is accepted (no exception)."""
    level = annotations_pb2.AUTH_LEVEL_EDITOR
    validate_auth_level(level)
def test_validate_auth_level_with_admin():
    """AUTH_LEVEL_ADMIN is accepted (no exception)."""
    level = annotations_pb2.AUTH_LEVEL_ADMIN
    validate_auth_level(level)
def test_check_auth_open_service_without_auth():
    """An anonymous caller (no auth info) may call an OPEN service."""
    anonymous = None
    check_permissions(anonymous, annotations_pb2.AUTH_LEVEL_OPEN)
def test_check_auth_open_service_with_auth():
    """A regular authenticated user may also call an OPEN service."""
    regular_user = UserAuthInfo(
        token="abc123",
        is_api_key=False,
        user_id=1,
        is_jailed=False,
        is_editor=False,
        is_superuser=False,
        token_expiry=now(),
        ui_language_preference="en",
        timezone="Etc/UTC",
    )

    check_permissions(regular_user, annotations_pb2.AUTH_LEVEL_OPEN)
def test_check_auth_secure_service_without_auth():
    """An anonymous caller is rejected by a SECURE service."""
    required_level = annotations_pb2.AUTH_LEVEL_SECURE
    with pytest.raises(AbortError):
        check_permissions(None, required_level)
def test_check_auth_secure_service_with_normal_auth():
    """A regular (non-jailed) user may call a SECURE service."""
    regular_user = UserAuthInfo(
        token="abc123",
        is_api_key=False,
        user_id=1,
        is_jailed=False,
        is_editor=False,
        is_superuser=False,
        token_expiry=now(),
        ui_language_preference="en",
        timezone="Etc/UTC",
    )

    check_permissions(regular_user, annotations_pb2.AUTH_LEVEL_SECURE)
def test_check_auth_secure_service_with_jailed_user():
    """A jailed user is rejected by a SECURE service."""
    jailed_user = UserAuthInfo(
        token="abc123",
        is_api_key=False,
        user_id=1,
        is_jailed=True,
        is_editor=False,
        is_superuser=False,
        token_expiry=now(),
        ui_language_preference="en",
        timezone="Etc/UTC",
    )

    with pytest.raises(AbortError):
        check_permissions(jailed_user, annotations_pb2.AUTH_LEVEL_SECURE)
def test_check_auth_jailed_service_with_jailed_user():
    """A jailed user may call a JAILED-level service."""
    jailed_user = UserAuthInfo(
        token="abc123",
        is_api_key=False,
        user_id=1,
        is_jailed=True,
        is_editor=False,
        is_superuser=False,
        token_expiry=now(),
        ui_language_preference="en",
        timezone="Etc/UTC",
    )

    check_permissions(jailed_user, annotations_pb2.AUTH_LEVEL_JAILED)
def test_check_auth_jailed_service_without_auth():
    """An anonymous caller is rejected by a JAILED-level service."""
    required_level = annotations_pb2.AUTH_LEVEL_JAILED
    with pytest.raises(AbortError):
        check_permissions(None, required_level)
def test_check_auth_editor_service_without_editor():
    """A regular user without the editor flag is rejected by an EDITOR service."""
    regular_user = UserAuthInfo(
        token="abc123",
        is_api_key=False,
        user_id=1,
        is_jailed=False,
        is_editor=False,
        is_superuser=False,
        token_expiry=now(),
        ui_language_preference="en",
        timezone="Etc/UTC",
    )

    with pytest.raises(AbortError):
        check_permissions(regular_user, annotations_pb2.AUTH_LEVEL_EDITOR)
def test_check_auth_editor_service_with_editor():
    """A user with the editor flag may call an EDITOR service."""
    editor = UserAuthInfo(
        token="abc123",
        is_api_key=False,
        user_id=1,
        is_jailed=False,
        is_editor=True,
        is_superuser=False,
        token_expiry=now(),
        ui_language_preference="en",
        timezone="Etc/UTC",
    )

    check_permissions(editor, annotations_pb2.AUTH_LEVEL_EDITOR)
def test_check_auth_admin_service_without_superuser():
    """Being an editor is not enough for an ADMIN service; the superuser flag is required."""
    editor = UserAuthInfo(
        token="abc123",
        is_api_key=False,
        user_id=1,
        is_jailed=False,
        is_editor=True,
        is_superuser=False,
        token_expiry=now(),
        ui_language_preference="en",
        timezone="Etc/UTC",
    )

    with pytest.raises(AbortError):
        check_permissions(editor, annotations_pb2.AUTH_LEVEL_ADMIN)
def test_check_auth_admin_service_with_superuser():
    """A superuser may call an ADMIN service."""
    superuser = UserAuthInfo(
        token="abc123",
        is_api_key=False,
        user_id=1,
        is_jailed=False,
        is_editor=True,
        is_superuser=True,
        token_expiry=now(),
        ui_language_preference="en",
        timezone="Etc/UTC",
    )

    check_permissions(superuser, annotations_pb2.AUTH_LEVEL_ADMIN)
def test_check_auth_admin_service_without_auth():
    """An anonymous caller is rejected by an ADMIN service."""
    required_level = annotations_pb2.AUTH_LEVEL_ADMIN
    with pytest.raises(AbortError):
        check_permissions(None, required_level)
def test_parse_sofa_cookie_valid():
    """A cookie produced by generate_sofa_cookie parses back to the same sofa value."""
    sofa_value, cookie_string = generate_sofa_cookie()
    # cookie_string is expected to look like "sofa=<value>; ..."; keep only <value>.
    after_name = cookie_string.split("=", 1)[1]
    cookie_value = after_name.split(";")[0]

    parsed = parse_sofa_cookie({"cookie": f"sofa={cookie_value}"})

    assert parsed == sofa_value
def test_parse_sofa_cookie_missing():
    """A cookie header without a sofa cookie yields None."""
    assert parse_sofa_cookie({"cookie": "other-cookie=value"}) is None
def test_parse_sofa_cookie_no_cookies():
    """Headers with no cookie header at all yield None."""
    no_headers: dict[str, str] = {}
    assert parse_sofa_cookie(no_headers) is None
def test_parse_sofa_cookie_invalid_base64():
    """A sofa cookie that is not valid base64 is ignored (None), not raised."""
    assert parse_sofa_cookie({"cookie": "sofa=not-valid-base64!!!"}) is None
def test_parse_sofa_cookie_invalid_encryption():
    """Well-formed base64 that is not a valid ciphertext yields None."""
    bogus_ciphertext = b64encode(b"invalid encrypted data")
    assert parse_sofa_cookie({"cookie": f"sofa={bogus_ciphertext}"}) is None
def test_parse_sofa_cookie_invalid_proto():
    """A correctly encrypted payload whose plaintext is not a parseable proto yields None.

    The previous assertion (`result is not None or result is None`) was a tautology
    and could never fail. The plaintext below starts with byte 0x6E ('n'), i.e. a tag
    with wire type 6, which protobuf treats as malformed, so parsing must be rejected
    just like the invalid-base64 and invalid-encryption cases above.
    """
    encrypted = simple_encrypt("sofa_cookie", b"not a valid proto")
    headers = {"cookie": f"sofa={b64encode(encrypted)}"}

    result = parse_sofa_cookie(headers)

    assert result is None
def test_generate_sofa_cookie():
    """generate_sofa_cookie returns a non-trivial value plus a cookie string that round-trips."""
    sofa_value, cookie_string = generate_sofa_cookie()

    # the raw value: a non-empty string longer than 20 characters
    assert sofa_value
    assert isinstance(sofa_value, str)
    assert len(sofa_value) > 20

    # the cookie string names the cookie and carries an expiry
    assert "sofa=" in cookie_string
    assert "expires=" in cookie_string.lower()

    # feeding the cookie back through the parser recovers the original value
    cookie_value = cookie_string.split("=", 1)[1].split(";")[0]
    assert parse_sofa_cookie({"cookie": f"sofa={cookie_value}"}) == sofa_value
def test_parse_headers_with_sofa_cookie():
    """parse_headers extracts both the session token and the decoded sofa value from one cookie header."""
    sofa_value, cookie_string = generate_sofa_cookie()
    cookie_value = cookie_string.split("=", 1)[1].split(";")[0]

    parsed = parse_headers({"cookie": f"couchers-sesh=abc123; sofa={cookie_value}"})

    assert parsed.token == "abc123"
    assert parsed.sofa == sofa_value
def test_parse_headers_without_sofa_cookie():
    """Without a sofa cookie, parse_headers leaves sofa as None but still reads the session token."""
    parsed = parse_headers({"cookie": "couchers-sesh=abc123"})

    assert parsed.token == "abc123"
    assert parsed.sofa is None
def test_sofa_cookie_logged_new(db):
    """A request with no sofa cookie still gets a (fresh) sofa value recorded on its APICall row."""

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    middleware = [CouchersMiddlewareInterceptor()]
    with interceptor_dummy_api(TestRpc, interceptors=middleware) as call_rpc:
        call_rpc(empty_pb2.Empty())

    with session_scope() as session:
        # exactly one APICall row is expected for the single call above
        logged = session.execute(select(APICall)).scalar_one()
        assert logged.sofa is not None
        assert len(logged.sofa) > 20
def test_sofa_cookie_logged_existing(db):
    """A valid incoming sofa cookie is recorded verbatim (decoded) on the APICall row."""
    sofa_value, cookie_string = generate_sofa_cookie()
    cookie_value = cookie_string.split("=", 1)[1].split(";")[0]

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    middleware = [CouchersMiddlewareInterceptor()]
    with interceptor_dummy_api(TestRpc, interceptors=middleware) as call_rpc:
        call_rpc(empty_pb2.Empty(), metadata=(("cookie", f"sofa={cookie_value}"),))

    with session_scope() as session:
        logged = session.execute(select(APICall)).scalar_one()
        assert logged.sofa == sofa_value
def test_sofa_cookie_logged_invalid_generates_new(db):
    """An unparseable sofa cookie is replaced by a freshly generated value rather than stored as-is."""

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    middleware = [CouchersMiddlewareInterceptor()]
    with interceptor_dummy_api(TestRpc, interceptors=middleware) as call_rpc:
        call_rpc(empty_pb2.Empty(), metadata=(("cookie", "sofa=invalid-cookie-value"),))

    with session_scope() as session:
        logged = session.execute(select(APICall)).scalar_one()
        assert logged.sofa is not None
        assert logged.sofa != "invalid-cookie-value"
        assert len(logged.sofa) > 20
def test_sofa_cookie_with_authenticated_user(db):
    """For an authenticated call, the APICall row records both the user id and the sofa value."""
    user, token = generate_user()
    sofa_value, cookie_string = generate_sofa_cookie()
    cookie_value = cookie_string.split("=", 1)[1].split(";")[0]

    account = Account()
    session_and_sofa = f"couchers-sesh={token}; sofa={cookie_value}"

    with interceptor_dummy_api(
        rpc=account.GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
        creds=grpc.local_channel_credentials(),
    ) as call_rpc:
        reply = call_rpc(empty_pb2.Empty(), metadata=(("cookie", session_and_sofa),))
        assert reply.username == user.username

    with session_scope() as session:
        logged = session.execute(select(APICall)).scalar_one()
        assert logged.user_id == user.id
        assert logged.sofa == sofa_value
def test_sofa_cookie_persists_on_exception(db):
    """When the RPC raises, the APICall row still records the sofa value and the traceback."""
    sofa_value, cookie_string = generate_sofa_cookie()
    cookie_value = cookie_string.split("=", 1)[1].split(";")[0]
    metadata = (("cookie", f"sofa={cookie_value}"),)

    def TestRpc(request, context, session):
        raise Exception("Test error")

    middleware = [CouchersMiddlewareInterceptor()]
    with interceptor_dummy_api(TestRpc, interceptors=middleware) as call_rpc:
        with pytest.raises(Exception, match="Test error"):
            call_rpc(empty_pb2.Empty(), metadata=metadata)

    with session_scope() as session:
        logged = session.execute(select(APICall)).scalar_one()
        assert logged.sofa == sofa_value
        assert logged.traceback is not None
        assert "Test error" in logged.traceback