Coverage for app/backend/src/tests/test_interceptors.py: 99%

657 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-21 09:29 +0000

1from collections.abc import Callable, Generator 

2from concurrent import futures 

3from contextlib import contextmanager 

4from datetime import timedelta 

5from typing import Any 

6from unittest.mock import Mock, patch 

7 

8import grpc 

9import pytest 

10from google.protobuf import empty_pb2 

11from google.protobuf.descriptor import ServiceDescriptor 

12from google.protobuf.descriptor_pool import DescriptorPool 

13from sqlalchemy import select, text, update 

14 

15from couchers.constants import ( 

16 MISSING_AUTH_LEVEL_ERROR_MESSAGE, 

17 NONEXISTENT_API_CALL_ERROR_MESSAGE, 

18 UNKNOWN_ERROR_MESSAGE, 

19) 

20from couchers.crypto import b64encode, random_hex, simple_encrypt 

21from couchers.db import session_scope 

22from couchers.descriptor_pool import get_descriptor_pool 

23from couchers.interceptors import ( 

24 AbortError, 

25 BadHeaders, 

26 CouchersMiddlewareInterceptor, 

27 ErrorSanitizationInterceptor, 

28 UserAuthInfo, 

29 check_permissions, 

30 find_auth_level, 

31 parse_headers, 

32 validate_auth_level, 

33) 

34from couchers.metrics import ( 

35 api_calls_counter, 

36 servicer_db_query_count_histogram, 

37 servicer_duration_histogram, 

38 servicer_pool_wait_histogram, 

39 servicer_serde_histogram, 

40 servicer_setup_cpu_time_histogram, 

41 servicer_setup_db_time_histogram, 

42 servicer_setup_errors_counter, 

43) 

44from couchers.models import APICall, ClientPlatform, User, UserActivity, UserSession 

45from couchers.proto import account_pb2, admin_pb2, annotations_pb2, api_pb2, auth_pb2 

46from couchers.servicers.account import Account 

47from couchers.servicers.api import API 

48from couchers.utils import generate_sofa_cookie, now, parse_sofa_cookie 

49from tests.fixtures.db import generate_user 

50from tests.fixtures.sessions import real_admin_session 

51 

52 

@pytest.fixture(autouse=True)
def _(testconfig):
    """Autouse fixture: makes every test in this module request the ``testconfig`` fixture."""

56 

57 

@contextmanager
def interceptor_dummy_api(
    rpc,
    interceptors,
    service_name="org.couchers.auth.Auth",
    method_name="SignupFlow",
    request_type=empty_pb2.Empty,
    response_type=empty_pb2.Empty,
    creds=None,
) -> Generator[Callable[..., Any]]:
    """Serve ``rpc`` as a single unary-unary method on a throwaway local gRPC server.

    The server runs ``interceptors`` on a single worker thread and listens on an OS-assigned localhost port
    (``localhost:0``). The context manager yields a client callable for ``/{service_name}/{method_name}`` that
    (de)serializes with ``request_type``/``response_type``. ``creds`` are the client channel credentials and
    default to local channel credentials. The server is stopped when the context exits.
    """
    handler = grpc.unary_unary_rpc_method_handler(
        rpc,
        request_deserializer=request_type.FromString,
        response_serializer=response_type.SerializeToString,
    )
    full_method = f"/{service_name}/{method_name}"

    with futures.ThreadPoolExecutor(1) as executor:
        server = grpc.server(executor, interceptors=interceptors)
        port = server.add_secure_port("localhost:0", grpc.local_server_credentials())
        # register the handler by hand instead of via a generated servicer
        server.add_generic_rpc_handlers((grpc.method_handlers_generic_handler(service_name, {method_name: handler}),))
        server.start()

        try:
            channel_creds = creds or grpc.local_channel_credentials()
            with grpc.secure_channel(f"localhost:{port}", channel_creds) as channel:
                yield channel.unary_unary(
                    full_method,
                    request_serializer=request_type.SerializeToString,
                    response_deserializer=response_type.FromString,
                )
        finally:
            server.stop(None).wait()

93 

94 

def _get_histogram_labels_value(method, logged_in, exception, code):
    """Return the observation count of ``servicer_duration_histogram`` for one exact label combination.

    Label values are compared with ``==`` against the exported sample labels; callers in this file pass strings
    (e.g. ``logged_in="False"``, ``code=""``). Returns 0 when the combination has never been observed.
    """
    family = [m for m in servicer_duration_histogram.collect() if m.name == "couchers_servicer_duration_seconds"][0]
    # checked in this order, short-circuiting on the first mismatch
    wanted = {"method": method, "logged_in": logged_in, "code": code, "exception": exception}
    counts = [
        sample.value
        for sample in family.samples
        if sample.name == "couchers_servicer_duration_seconds_count"
        and all(sample.labels[key] == value for key, value in wanted.items())
    ]
    return counts[0] if counts else 0

110 

111 

def _get_setup_errors_value(method, exception):
    """Return the setup-errors counter value for ``method``/``exception``, or 0 if it was never incremented."""
    family = [m for m in servicer_setup_errors_counter.collect() if m.name == "couchers_servicer_setup_errors"][0]
    values = [
        sample.value
        for sample in family.samples
        if sample.name == "couchers_servicer_setup_errors_total"
        and sample.labels["method"] == method
        and sample.labels["exception"] == exception
    ]
    return values[0] if values else 0

125 

126 

def test_logging_interceptor_ok():
    """A handler that returns normally completes without error through ErrorSanitizationInterceptor."""
    with interceptor_dummy_api(
        lambda request, context: empty_pb2.Empty(),
        interceptors=[ErrorSanitizationInterceptor()],
    ) as call_rpc:
        call_rpc(empty_pb2.Empty())

133 

134 

def test_logging_interceptor_all_ignored():
    """Deliberate aborts pass through ErrorSanitizationInterceptor with their status code and details intact."""
    # error codes that should not be touched by the interceptor (OK is absent: you can't abort with OK)
    pass_through_status_codes = [
        grpc.StatusCode.CANCELLED,
        grpc.StatusCode.UNKNOWN,
        grpc.StatusCode.INVALID_ARGUMENT,
        grpc.StatusCode.DEADLINE_EXCEEDED,
        grpc.StatusCode.NOT_FOUND,
        grpc.StatusCode.ALREADY_EXISTS,
        grpc.StatusCode.PERMISSION_DENIED,
        grpc.StatusCode.UNAUTHENTICATED,
        grpc.StatusCode.RESOURCE_EXHAUSTED,
        grpc.StatusCode.FAILED_PRECONDITION,
        grpc.StatusCode.ABORTED,
        grpc.StatusCode.OUT_OF_RANGE,
        grpc.StatusCode.UNIMPLEMENTED,
        grpc.StatusCode.INTERNAL,
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.DATA_LOSS,
    ]

    def make_aborting_rpc(status_code, message):
        # bind the per-iteration values eagerly instead of closing over the loop variables (late binding, B023)
        def TestRpc(request, context):
            context.abort(status_code, message)

        return TestRpc

    for status_code in pass_through_status_codes:
        message = random_hex()

        with interceptor_dummy_api(
            make_aborting_rpc(status_code, message), interceptors=[ErrorSanitizationInterceptor()]
        ) as call_rpc:
            with pytest.raises(grpc.RpcError) as e:
                call_rpc(empty_pb2.Empty())
        assert e.value.code() == status_code, status_code
        assert e.value.details() == message, status_code

169 

170 

def test_logging_interceptor_assertion():
    """A handler AssertionError reaches the client as INTERNAL with the generic error message."""

    def TestRpc(request, context):
        raise AssertionError

    with (
        interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc,
        pytest.raises(grpc.RpcError) as exc_info,
    ):
        call_rpc(empty_pb2.Empty())

    err = exc_info.value
    assert err.code() == grpc.StatusCode.INTERNAL
    assert err.details() == "An unknown backend error occurred. Please consider filing a bug!"

180 

181 

def test_logging_interceptor_div0():
    """A ZeroDivisionError in a handler reaches the client as INTERNAL with the generic error message."""

    def TestRpc(request, context):
        return 1 / 0

    with (
        interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc,
        pytest.raises(grpc.RpcError) as exc_info,
    ):
        call_rpc(empty_pb2.Empty())

    err = exc_info.value
    assert err.code() == grpc.StatusCode.INTERNAL
    assert err.details() == "An unknown backend error occurred. Please consider filing a bug!"

191 

192 

def test_logging_interceptor_raise():
    """A bare Exception in a handler reaches the client as INTERNAL with the generic error message."""

    def TestRpc(request, context):
        raise Exception

    with (
        interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc,
        pytest.raises(grpc.RpcError) as exc_info,
    ):
        call_rpc(empty_pb2.Empty())

    err = exc_info.value
    assert err.code() == grpc.StatusCode.INTERNAL
    assert err.details() == "An unknown backend error occurred. Please consider filing a bug!"

202 

203 

def test_logging_interceptor_raise_custom():
    """A custom exception's message is not leaked; the client only sees INTERNAL with the generic message."""

    class _TestingException(Exception):
        """Exception type local to this test."""

    def TestRpc(request, context):
        raise _TestingException("This is a custom exception")

    with (
        interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc,
        pytest.raises(grpc.RpcError) as exc_info,
    ):
        call_rpc(empty_pb2.Empty())

    err = exc_info.value
    assert err.code() == grpc.StatusCode.INTERNAL
    assert err.details() == "An unknown backend error occurred. Please consider filing a bug!"

216 

217 

def test_tracing_interceptor_ok_open(db):
    """A successful unauthenticated SignupFlow call is written to APICall and counted in the duration histogram."""
    method = "/org.couchers.auth.Auth/SignupFlow"
    before = _get_histogram_labels_value(method, "False", "", "")

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    with interceptor_dummy_api(TestRpc, interceptors=[CouchersMiddlewareInterceptor()]) as call_rpc:
        call_rpc(empty_pb2.Empty())

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert not trace.status_code
        assert not trace.user_id
        assert not trace.traceback
        # empty request and response messages serialize to zero bytes, but are still recorded
        assert trace.request is not None
        assert len(trace.request) == 0
        assert trace.response is not None
        assert len(trace.response) == 0

    assert _get_histogram_labels_value(method, "False", "", "") == before + 1

239 

240 

def _get_db_query_count_histogram(method):
    """Return how many requests to ``method`` were observed in ``servicer_db_query_count_histogram``.

    Sums the histogram's ``_count`` samples for ``method`` across all other label values; 0 if none.
    Delegates to the module's generic ``_get_histogram_count`` rather than duplicating its filter/sum logic.
    """
    return _get_histogram_count(
        servicer_db_query_count_histogram, "couchers_servicer_db_query_count_count", method=method
    )

248 

249 

def _get_api_call_count(method, platform):
    """Return the ``couchers_api_calls_total`` value for ``method`` from client ``platform`` (0 if never seen)."""
    total = 0
    for family in api_calls_counter.collect():
        for sample in family.samples:
            if sample.name != "couchers_api_calls_total":
                continue
            if sample.labels.get("method") == method and sample.labels.get("platform") == platform:
                total += sample.value
    return total

259 

260 

def test_tracing_interceptor_perf_accounting(db):
    """Per-request DB/CPU accounting is stored on APICall and exported to Prometheus.

    The handler runs exactly four statements: three reads and one compiled UPDATE. The UPDATE matches zero rows,
    so it has no side effects. The call is tagged with the ``web_mobile`` client-platform header.
    """
    method = "/org.couchers.auth.Auth/SignupFlow"
    platform = "web_mobile"
    observed_before = _get_db_query_count_histogram(method)
    calls_before = _get_api_call_count(method, platform)

    def TestRpc(request, context, session):
        session.execute(text("SELECT 1"))
        session.execute(text("SELECT 1"))
        session.execute(text("SELECT 1"))
        session.execute(update(APICall).where(APICall.id == -1).values(method="x"))
        return empty_pb2.Empty()

    with interceptor_dummy_api(TestRpc, interceptors=[CouchersMiddlewareInterceptor()]) as call_rpc:
        call_rpc(empty_pb2.Empty(), metadata=(("x-couchers-client-platform", platform),))

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.db_query_count == 4
        assert trace.db_write_query_count == 1
        assert trace.db_time_ms is not None
        assert trace.db_time_ms >= 0
        assert trace.cpu_ms is not None
        assert trace.cpu_ms >= 0
        # the handler's DB work is part of the request, so it can't exceed the whole-request wall time
        assert trace.db_time_ms <= trace.duration
        assert trace.client_platform == ClientPlatform.web_mobile

    # also observed into the per-request Prometheus histogram and the per-platform call counter
    assert _get_db_query_count_histogram(method) == observed_before + 1
    assert _get_api_call_count(method, platform) == calls_before + 1

290 

291 

def _get_histogram_count(histogram, count_name, **labels):
    """Sum the values of all samples named ``count_name`` whose labels match every item in ``labels``.

    ``histogram`` is anything with a ``collect()`` method yielding families that have ``samples``. A label missing
    from a sample compares as ``None``. Returns 0 when nothing matches.
    """
    total = 0
    for family in histogram.collect():
        for sample in family.samples:
            if sample.name == count_name and all(sample.labels.get(key) == want for key, want in labels.items()):
                total += sample.value
    return total

299 

300 

def test_tracing_interceptor_phase_histograms(db):
    """Each per-phase histogram gets exactly one observation per call.

    Phases covered: setup DB time, setup CPU time, pool wait, and request deserialization / response serialization.
    """
    method = "/org.couchers.auth.Auth/SignupFlow"
    # (histogram, `_count` sample name, labels beyond `method`) for each phase under test
    phases = [
        (servicer_setup_db_time_histogram, "couchers_servicer_setup_db_time_seconds_count", {}),
        (servicer_setup_cpu_time_histogram, "couchers_servicer_setup_cpu_seconds_count", {}),
        (servicer_pool_wait_histogram, "couchers_servicer_pool_wait_seconds_count", {}),
        (servicer_serde_histogram, "couchers_servicer_serde_seconds_count", {"direction": "deserialize"}),
        (servicer_serde_histogram, "couchers_servicer_serde_seconds_count", {"direction": "serialize"}),
    ]

    def observed(histogram, count_name, extra_labels):
        return _get_histogram_count(histogram, count_name, method=method, **extra_labels)

    counts_before = [observed(*phase) for phase in phases]

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    with interceptor_dummy_api(TestRpc, interceptors=[CouchersMiddlewareInterceptor()]) as call_rpc:
        call_rpc(empty_pb2.Empty())

    for (histogram, count_name, extra_labels), before in zip(phases, counts_before, strict=True):
        assert observed(histogram, count_name, extra_labels) == before + 1, (count_name, extra_labels)

354 

355 

def test_tracing_interceptor_perf_accounting_orm_write(db):
    """An ORM write that is only pending when the handler returns is still counted.

    The handler only ``session.add``s a row. Its INSERT would otherwise be flushed at commit, after the interceptor
    reads its perf counters. This checks the interceptor's explicit flush makes the INSERT count as one query and
    one write.
    """
    method = "/org.couchers.auth.Auth/SignupFlow"

    def TestRpc(request, context, session):
        pending_row = APICall(method="handler-insert", duration=0.0, is_api_key=False, response_truncated=False)
        session.add(pending_row)
        return empty_pb2.Empty()

    with interceptor_dummy_api(TestRpc, interceptors=[CouchersMiddlewareInterceptor()]) as call_rpc:
        call_rpc(empty_pb2.Empty())

    with session_scope() as session:
        # filter by method: the handler's own "handler-insert" row is also in the table
        log = session.execute(select(APICall).where(APICall.method == method)).scalar_one()
        assert (log.db_query_count, log.db_write_query_count) == (1, 1)

372 

373 

def test_tracing_interceptor_sensitive(db):
    """Password fields are stripped from the request and response bytes stored in APICall; other fields are kept."""
    method = "/org.couchers.auth.Auth/SignupFlow"
    before = _get_histogram_labels_value(method, "False", "", "")

    def TestRpc(request, context, session):
        return auth_pb2.AuthReq(user="this is not secret", password="this is secret")

    signup_req = auth_pb2.SignupFlowReq(
        account=auth_pb2.SignupAccount(password="should be removed", username="not removed")
    )
    with interceptor_dummy_api(
        TestRpc,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=auth_pb2.SignupFlowReq,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        call_rpc(signup_req)

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert not trace.status_code
        assert not trace.user_id
        assert not trace.traceback

        assert trace.request is not None
        logged_req = auth_pb2.SignupFlowReq.FromString(trace.request)
        assert not logged_req.account.password
        assert logged_req.account.username == "not removed"

        assert trace.response
        logged_res = auth_pb2.AuthReq.FromString(trace.response)
        assert logged_res.user == "this is not secret"
        assert not logged_res.password

    assert _get_histogram_labels_value(method, "False", "", "") == before + 1

406 

407 

def test_tracing_interceptor_sensitive_ping(db):
    """A cookie-authenticated API.GetUser call succeeds through the middleware and is traced against the caller.

    Asserts on both the response and the APICall row, so a broken call or trace can't pass silently.
    """
    user, token = generate_user()

    with interceptor_dummy_api(
        API().GetUser,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=api_pb2.GetUserReq,
        response_type=api_pb2.User,
        service_name="org.couchers.api.core.API",
        method_name="GetUser",
    ) as call_rpc:
        res = call_rpc(api_pb2.GetUserReq(user=user.username), metadata=(cookie_auth(token),))

    assert res.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.core.API/GetUser"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert not trace.traceback

420 

421 

def test_tracing_interceptor_exception(db):
    """An unhandled handler exception is traced with its traceback and counted under ``exception="Exception"``.

    The logged request still has its password stripped, and no response is stored.
    """
    method = "/org.couchers.auth.Auth/SignupFlow"
    before = _get_histogram_labels_value(method, "False", "Exception", "")

    def TestRpc(request, context, session):
        raise Exception("Some error message")

    account_req = auth_pb2.SignupAccount(password="should be removed", username="not removed")
    with interceptor_dummy_api(
        TestRpc,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        with pytest.raises(Exception, match="Some error message"):
            call_rpc(account_req)

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert not trace.status_code
        assert not trace.user_id
        assert trace.traceback
        assert "Some error message" in trace.traceback

        assert trace.request is not None
        logged_req = auth_pb2.SignupAccount.FromString(trace.request)
        assert not logged_req.password
        assert logged_req.username == "not removed"
        assert not trace.response

    assert _get_histogram_labels_value(method, "False", "Exception", "") == before + 1

451 

452 

def test_setup_phase_exception_observed(db):
    """An exception raised while the interceptor sets up a request is sanitized, reported to Sentry, and counted.

    ``couchers.interceptors.LocalizationContext`` is patched to raise ValueError. The client must see INTERNAL with
    the generic message, Sentry must capture exactly one exception, and the setup-errors counter must go up by one.
    """
    method = "/org.couchers.auth.Auth/SignupFlow"
    errors_before = _get_setup_errors_value(method, "ValueError")

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    failing_localization = patch(
        "couchers.interceptors.LocalizationContext", side_effect=ValueError("expected only letters")
    )
    with failing_localization, patch("couchers.interceptors.sentry_sdk") as mock_sentry:
        with interceptor_dummy_api(TestRpc, interceptors=[CouchersMiddlewareInterceptor()]) as call_rpc:
            with pytest.raises(grpc.RpcError) as exc_info:
                call_rpc(empty_pb2.Empty())
        assert exc_info.value.code() == grpc.StatusCode.INTERNAL
        assert exc_info.value.details() == UNKNOWN_ERROR_MESSAGE
        mock_sentry.capture_exception.assert_called_once()

    assert _get_setup_errors_value(method, "ValueError") == errors_before + 1

472 

473 

def test_tracing_interceptor_abort(db):
    """A ``context.abort`` in a handler is traced with its status code and message.

    The logged request has its password stripped, no response is stored, and the duration histogram is bumped under
    ``exception="Exception", code="FAILED_PRECONDITION"``.
    """
    method = "/org.couchers.auth.Auth/SignupFlow"
    labels = (method, "False", "Exception", "FAILED_PRECONDITION")
    before = _get_histogram_labels_value(*labels)

    def TestRpc(request, context, session):
        context.abort(grpc.StatusCode.FAILED_PRECONDITION, "now a grpc abort")

    account_req = auth_pb2.SignupAccount(password="should be removed", username="not removed")
    with interceptor_dummy_api(
        TestRpc,
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        with pytest.raises(Exception, match="now a grpc abort"):
            call_rpc(account_req)

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == method
        assert trace.status_code == "FAILED_PRECONDITION"
        assert not trace.user_id
        assert trace.traceback
        assert "now a grpc abort" in trace.traceback

        assert trace.request is not None
        logged_req = auth_pb2.SignupAccount.FromString(trace.request)
        assert not logged_req.password
        assert logged_req.username == "not removed"
        assert not trace.response

    assert _get_histogram_labels_value(*labels) == before + 1

506 

507 

def cookie_auth(token: str) -> tuple[str, str]:
    """Build the gRPC metadata pair that presents ``token`` as a ``couchers-sesh`` session cookie."""
    cookie_value = f"couchers-sesh={token}"
    return ("cookie", cookie_value)

510 

511 

def api_auth(token: str) -> tuple[str, str]:
    """Build the gRPC metadata pair that presents ``token`` as an ``authorization: Bearer`` API key."""
    bearer_value = f"Bearer {token}"
    return ("authorization", bearer_value)

514 

515 

def test_auth_interceptor(db):
    """End-to-end authentication through CouchersMiddlewareInterceptor on Account.GetAccountInfo.

    Covers cookie and bearer API-key auth and rejection of missing, wrong, ambiguous and malformed credentials,
    deleted users, expired sessions, and session-type mismatches. It also covers the bookkeeping done on success:
    UserSession.api_calls/last_seen, throttled User.last_active updates, and UserActivity rows split by IP and by
    time period.
    """
    _, super_token = generate_user(is_superuser=True)
    user, token = generate_user()
    _, deleted_token = generate_user(delete_user=True)

    with real_admin_session(super_token) as admin_api:
        admin_api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        api_key = session.execute(select(UserSession.token).where(UserSession.is_api_key)).scalar_one()

    rpc_def = dict(
        rpc=Account().GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
    )

    def call(metadata=None, creds=None):
        # every call gets a fresh server, as each check did originally
        with interceptor_dummy_api(**rpc_def, creds=creds or grpc.local_channel_credentials()) as call_rpc:
            return call_rpc(empty_pb2.Empty(), metadata=metadata)

    def expect_rejected(metadata=None, details="Unauthorized"):
        with pytest.raises(grpc.RpcError) as exc_info:
            call(metadata)
        assert exc_info.value.code() == grpc.StatusCode.UNAUTHENTICATED
        assert exc_info.value.details() == details

    # no creds: no-go for secure APIs
    expect_rejected()

    # a valid session cookie works and is recorded in UserActivity
    assert call((cookie_auth(token),)).username == user.username
    with session_scope() as session:
        assert session.execute(select(UserActivity.api_calls).where(UserActivity.user_id == user.id)).scalar_one() == 1

    # a wrong cookie is rejected
    expect_rejected((cookie_auth(random_hex(32)),))

    # a valid API key works and is counted in the same UserActivity row
    assert call((api_auth(api_key),)).username == user.username
    with session_scope() as session:
        assert session.execute(select(UserActivity.api_calls).where(UserActivity.user_id == user.id)).scalar_one() == 2

    # a wrong API key is rejected
    expect_rejected((api_auth(random_hex(32)),))

    # grpc's access-token call credentials work too (same as the bearer header above)
    comp_creds = grpc.composite_channel_credentials(
        grpc.local_channel_credentials(), grpc.access_token_call_credentials(api_key)
    )
    assert call(creds=comp_creds).username == user.username

    # supplying both a cookie and an API key is rejected
    expect_rejected((cookie_auth(token), api_auth(api_key)), 'Both "cookie" and "authorization" in request')

    # a lowercase "bearer" prefix is rejected
    expect_rejected((("authorization", f"bearer {api_key}"),))

    # invisible (deleted) user
    expect_rejected((cookie_auth(deleted_token),))

    # invalid (expired) session
    long_ago = now() - timedelta(weeks=100)
    with session_scope() as session:
        session.execute(update(UserSession).values(last_seen=long_ago).where(UserSession.token == token))
    expect_rejected((cookie_auth(token),))

    # API-key session presented as a cookie (probably impossible, but...)
    with session_scope() as session:
        session.execute(update(UserSession).values(last_seen=now(), is_api_key=True).where(UserSession.token == token))
    expect_rejected((cookie_auth(token),))

    # successful calls update session metadata
    six_minutes_ago = now() - timedelta(minutes=6)
    with session_scope() as session:
        # return the session to normal
        user_session = session.execute(select(UserSession).where(UserSession.token == token)).scalar_one()
        user_session.is_api_key = False
        api_calls_before = user_session.api_calls

    assert call((cookie_auth(token),)).username == user.username

    with session_scope() as session:
        user_session = session.execute(select(UserSession).where(UserSession.token == token)).scalar_one()
        assert user_session.api_calls == api_calls_before + 1
        assert user_session.last_seen > now() - timedelta(seconds=1)

        # simulate user inactivity, so last_active is updated on the next api call
        session.execute(update(User).values(last_active=six_minutes_ago).where(User.id == user.id))

    # last_active is refreshed when it is stale...
    call((cookie_auth(token),))
    with session_scope() as session:
        last_active = session.execute(select(User.last_active).where(User.id == user.id)).scalar_one()
        assert last_active > now() - timedelta(seconds=1)

    # ...and left untouched when it was updated recently
    call((cookie_auth(token),))
    with session_scope() as session:
        assert session.execute(select(User.last_active).where(User.id == user.id)).scalar_one() == last_active

    # activity is split by IP
    call((cookie_auth(token), ("x-couchers-real-ip", "1.1.1.1")))
    with session_scope() as session:
        assert (
            session.execute(select(UserActivity.api_calls).where(UserActivity.ip_address == "1.1.1.1")).scalar_one()
            == 1
        )

    # activity is split into time bins: push existing rows into the far past so the next request inserts a new one
    with session_scope() as session:
        session.execute(update(UserActivity).values(period=long_ago).where(UserActivity.user_id == user.id))

    call((cookie_auth(token),))
    with session_scope() as session:
        latest_api_calls = session.execute(
            select(UserActivity.api_calls)
            .where(UserActivity.user_id == user.id)
            .order_by(UserActivity.id.desc())
            .limit(1)
        ).scalar_one()
        assert latest_api_calls == 1

689 

690 

def test_tracing_interceptor_auth_cookies(db):
    """A cookie-authenticated call is traced with the caller's user id and ``is_api_key`` unset."""
    user, token = generate_user()

    rpc_def = {
        "rpc": Account().GetAccountInfo,
        "service_name": "org.couchers.api.account.Account",
        "method_name": "GetAccountInfo",
        "interceptors": [CouchersMiddlewareInterceptor()],
        "request_type": empty_pb2.Empty,
        "response_type": account_pb2.GetAccountInfoRes,
    }

    # with cookies (built with the shared helper, like test_auth_interceptor)
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        res = call_rpc(empty_pb2.Empty(), metadata=(cookie_auth(token),))
    assert res.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert not trace.is_api_key
        assert trace.request is not None
        assert len(trace.request) == 0
        assert not trace.traceback

719 

720 

def test_tracing_interceptor_auth_api_key(db):
    """An API-key-authenticated call is traced with the key owner's user id and ``is_api_key`` set."""
    _, super_token = generate_user(is_superuser=True)
    user, _ = generate_user()

    with real_admin_session(super_token) as admin_api:
        admin_api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        api_key = session.execute(select(UserSession.token).where(UserSession.is_api_key)).scalar_one()

    rpc_def = {
        "rpc": Account().GetAccountInfo,
        "service_name": "org.couchers.api.account.Account",
        "method_name": "GetAccountInfo",
        "interceptors": [CouchersMiddlewareInterceptor()],
        "request_type": empty_pb2.Empty,
        "response_type": account_pb2.GetAccountInfoRes,
    }

    # with api key (built with the shared helper, like test_auth_interceptor)
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        res = call_rpc(empty_pb2.Empty(), metadata=(api_auth(api_key),))
    assert res.username == user.username

    with session_scope() as session:
        # filter by method: other API calls may also have been traced (e.g. the admin CreateApiKey above)
        trace = session.execute(
            select(APICall).where(APICall.method == "/org.couchers.api.account.Account/GetAccountInfo")
        ).scalar_one()
        assert trace.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert trace.is_api_key
        assert trace.request is not None
        assert len(trace.request) == 0
        assert not trace.traceback

758 

759 

def test_auth_levels(db):
    """Exercise every (caller kind x RPC auth level) pair through CouchersMiddlewareInterceptor.

    Every RPC is served by the same no-op handler, so only the middleware's allow/deny decision
    (and the resulting gRPC status code and details) is under test. Also covers an RPC that does
    not exist and an RPC whose service has no auth level annotation.
    """

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    def gen_args(service, method):
        return {
            "rpc": TestRpc,
            "service_name": service,
            "method_name": method,
            "interceptors": [CouchersMiddlewareInterceptor()],
            "request_type": empty_pb2.Empty,
            "response_type": empty_pb2.Empty,
        }

    # superuser (note: superusers are automatically editors due to DB constraint)
    _, super_token = generate_user(is_superuser=True)
    # editor user
    _, editor_token = generate_user(is_editor=True)
    # normal user
    _, normal_token = generate_user()
    # jailed user
    _, jailed_token = generate_user(accepted_tos=0)
    # open user
    open_token = ""

    # pick some rpcs here with the right auth levels
    open_args = gen_args("org.couchers.resources.Resources", "GetTermsOfService")
    jailed_args = gen_args("org.couchers.jail.Jail", "JailInfo")
    secure_args = gen_args("org.couchers.api.account.Account", "GetAccountInfo")
    editor_args = gen_args("org.couchers.editor.Editor", "CreateCommunity")
    admin_args = gen_args("org.couchers.admin.Admin", "GetUserDetails")

    # pairs to check
    checks = [
        # name, token, args, works?, code, message
        # (code and message are only asserted for rows where works? is False)
        # open token only works on open servicers
        ("open x open", open_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x jailed", open_token, jailed_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x secure", open_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x editor", open_token, editor_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x admin", open_token, admin_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        # jailed works on jailed and open
        ("jailed x open", jailed_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("jailed x jailed", jailed_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("jailed x secure", jailed_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Permission denied"),
        ("jailed x editor", jailed_token, editor_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        ("jailed x admin", jailed_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # normal works on all but editor and admin
        ("normal x open", normal_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x jailed", normal_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x secure", normal_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x editor", normal_token, editor_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        ("normal x admin", normal_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # editor works on all but admin
        ("editor x open", editor_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("editor x jailed", editor_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("editor x secure", editor_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("editor x editor", editor_token, editor_args, True, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        ("editor x admin", editor_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # superuser works on all
        ("super x open", super_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x jailed", super_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x secure", super_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x editor", super_token, editor_args, True, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        ("super x admin", super_token, admin_args, True, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
    ]

    for name, token, args, should_work, code, message in checks:
        # printed so a failing iteration can be identified from pytest's captured output
        print(f"Testing (token x args) = ({name}), {should_work=}")
        metadata = (("cookie", f"couchers-sesh={token}"),)
        with interceptor_dummy_api(**args) as call_rpc:
            if should_work:
                call_rpc(empty_pb2.Empty(), metadata=metadata)
            else:
                with pytest.raises(grpc.RpcError) as err:
                    call_rpc(empty_pb2.Empty(), metadata=metadata)
                assert err.value.code() == code
                assert err.value.details() == message

    # a non-existent RPC
    nonexistent = gen_args("org.couchers.nonexistent.NA", "GetNothing")

    with interceptor_dummy_api(**nonexistent) as call_rpc:
        with pytest.raises(grpc.RpcError) as err:
            call_rpc(empty_pb2.Empty())
        assert err.value.code() == grpc.StatusCode.UNIMPLEMENTED
        assert err.value.details() == "API call does not exist. Please refresh and try again."

    # an RPC without a service level
    invalid_args = gen_args("org.couchers.media.Media", "UploadConfirmation")

    with interceptor_dummy_api(**invalid_args) as call_rpc:
        with pytest.raises(grpc.RpcError) as err:
            call_rpc(empty_pb2.Empty())
        assert err.value.code() == grpc.StatusCode.INTERNAL
        assert err.value.details() == "Internal authentication error."

856 

857 

def test_parse_headers_with_session_cookie():
    """The couchers-sesh cookie supplies the token; such a session is not an API key."""
    parsed = parse_headers({"cookie": "couchers-sesh=abc123; other-cookie=value"})
    assert parsed.token == "abc123"
    assert parsed.is_api_key is False

863 

864 

def test_parse_headers_with_authorization_header():
    """A Bearer authorization header supplies the token and marks the request as an API-key call."""
    parsed = parse_headers({"authorization": "Bearer abc123"})
    assert parsed.token == "abc123"
    assert parsed.is_api_key is True

870 

871 

def test_parse_headers_with_both_cookie_and_authorization():
    """Sending a session cookie and a Bearer header together is rejected with BadHeaders."""
    conflicting = {
        "cookie": "couchers-sesh=abc123",
        "authorization": "Bearer xyz789",
    }
    expected_error = "Both cookies and authorization are present in headers"
    with pytest.raises(BadHeaders, match=expected_error):
        parse_headers(conflicting)

876 

877 

def test_parse_headers_with_neither_cookie_nor_authorization():
    """Without any credentials there is no token, and the request is not flagged as an API key."""
    no_headers: dict[str, str] = {}
    parsed = parse_headers(no_headers)
    assert parsed.token is None
    assert parsed.is_api_key is False

882 

883 

def test_parse_headers_with_all_optional_headers():
    """Real IP, user agent, UI locale and user id are all extracted alongside the session token."""
    cookie_header = "; ".join(["couchers-sesh=abc123", "couchers-user-id=42", "NEXT_LOCALE=en"])
    parsed = parse_headers(
        {
            "cookie": cookie_header,
            "x-couchers-real-ip": "192.168.1.1",
            "user-agent": "TestAgent/1.0",
        }
    )

    # credentials
    assert parsed.token == "abc123"
    assert parsed.is_api_key is False
    # optional request metadata
    assert parsed.ip_address == "192.168.1.1"
    assert parsed.user_agent == "TestAgent/1.0"
    # optional cookies (user id stays a string)
    assert parsed.ui_lang == "en"
    assert parsed.user_id == "42"

897 

898 

def test_parse_headers_with_bytes_ip_address():
    """A real-IP header delivered as bytes is not used: ip_address comes back as None."""
    binary_ip = b"192.168.1.1"
    headers: dict[str, str | bytes] = {"cookie": "couchers-sesh=abc123", "x-couchers-real-ip": binary_ip}
    assert parse_headers(headers).ip_address is None

906 

907 

def test_parse_headers_with_bytes_user_agent():
    """A user-agent header delivered as bytes is not used: user_agent comes back as None."""
    binary_agent = b"TestAgent/1.0"
    headers: dict[str, str | bytes] = {"cookie": "couchers-sesh=abc123", "user-agent": binary_agent}
    assert parse_headers(headers).user_agent is None

915 

916 

def test_parse_headers_malformed_authorization():
    """A lowercase "bearer" scheme yields no token, yet the request is still flagged as an API-key call."""
    parsed = parse_headers({"authorization": "bearer abc123"})
    assert parsed.token is None
    assert parsed.is_api_key is True

922 

923 

def test_find_auth_level_with_valid_service():
    """The core API's GetUser method resolves to AUTH_LEVEL_SECURE in the real descriptor pool."""
    level = find_auth_level(get_descriptor_pool(), "/org.couchers.api.core.API/GetUser")
    assert level == annotations_pb2.AUTH_LEVEL_SECURE

929 

930 

def test_find_auth_level_with_nonexistent_service():
    """An unknown service aborts with UNIMPLEMENTED and the nonexistent-API-call message."""
    descriptor_pool = get_descriptor_pool()
    missing_method = "/org.couchers.nonexistent.Service/Method"

    with pytest.raises(AbortError) as excinfo:
        find_auth_level(descriptor_pool, missing_method)

    error = excinfo.value
    assert error.msg == NONEXISTENT_API_CALL_ERROR_MESSAGE
    assert error.code == grpc.StatusCode.UNIMPLEMENTED

938 

939 

def test_find_auth_level_with_unknown_auth_level():
    """A service whose options carry AUTH_LEVEL_UNKNOWN aborts with INTERNAL and the missing-level message."""
    # options object exposing the auth_level extension as AUTH_LEVEL_UNKNOWN
    fake_options = Mock()
    fake_options.Extensions = {annotations_pb2.auth_level: annotations_pb2.AUTH_LEVEL_UNKNOWN}

    # service descriptor returning those options
    fake_service = Mock(spec=ServiceDescriptor)
    fake_service.GetOptions.return_value = fake_options

    # pool that resolves any service name to the fake descriptor
    fake_pool = Mock(spec=DescriptorPool)
    fake_pool.FindServiceByName.return_value = fake_service

    with pytest.raises(AbortError) as excinfo:
        find_auth_level(fake_pool, "/org.couchers.api.core.API/GetUser")

    error = excinfo.value
    assert error.msg == MISSING_AUTH_LEVEL_ERROR_MESSAGE
    assert error.code == grpc.StatusCode.INTERNAL

952 

953 

def test_validate_auth_level_with_unknown():
    """AUTH_LEVEL_UNKNOWN is rejected with INTERNAL and the missing-auth-level message."""
    with pytest.raises(AbortError) as excinfo:
        validate_auth_level(annotations_pb2.AUTH_LEVEL_UNKNOWN)

    error = excinfo.value
    assert error.msg == MISSING_AUTH_LEVEL_ERROR_MESSAGE
    assert error.code == grpc.StatusCode.INTERNAL

959 

960 

def test_validate_auth_level_with_open():
    """AUTH_LEVEL_OPEN is a valid level: validation must not raise."""
    level = annotations_pb2.AUTH_LEVEL_OPEN
    validate_auth_level(level)

963 

964 

def test_validate_auth_level_with_jailed():
    """AUTH_LEVEL_JAILED is a valid level: validation must not raise."""
    level = annotations_pb2.AUTH_LEVEL_JAILED
    validate_auth_level(level)

967 

968 

def test_validate_auth_level_with_secure():
    """AUTH_LEVEL_SECURE is a valid level: validation must not raise."""
    level = annotations_pb2.AUTH_LEVEL_SECURE
    validate_auth_level(level)

971 

972 

def test_validate_auth_level_with_editor():
    """AUTH_LEVEL_EDITOR is a valid level: validation must not raise."""
    level = annotations_pb2.AUTH_LEVEL_EDITOR
    validate_auth_level(level)

975 

976 

def test_validate_auth_level_with_admin():
    """AUTH_LEVEL_ADMIN is a valid level: validation must not raise."""
    level = annotations_pb2.AUTH_LEVEL_ADMIN
    validate_auth_level(level)

979 

980 

def test_check_auth_open_service_without_auth():
    """Anonymous callers (no auth info) are allowed on open services."""
    anonymous = None
    check_permissions(anonymous, annotations_pb2.AUTH_LEVEL_OPEN)

983 

984 

def test_check_auth_open_service_with_auth():
    """A regular signed-in user is also allowed on open services."""
    regular_user = UserAuthInfo(
        user_id=1,
        token="abc123",
        is_api_key=False,
        token_expiry=now(),
        is_jailed=False,
        is_editor=False,
        is_superuser=False,
        ui_language_preference="en",
        timezone="Etc/UTC",
    )
    check_permissions(regular_user, annotations_pb2.AUTH_LEVEL_OPEN)

998 

999 

def test_check_auth_secure_service_without_auth():
    """Anonymous callers are refused on secure services with UNAUTHENTICATED."""
    with pytest.raises(AbortError) as exc:
        check_permissions(None, annotations_pb2.AUTH_LEVEL_SECURE)
    # Pin the status code (as seen end-to-end for "open x secure" in test_auth_levels) so that an
    # AbortError carrying the wrong code does not pass unnoticed.
    assert exc.value.code == grpc.StatusCode.UNAUTHENTICATED

1003 

1004 

def test_check_auth_secure_service_with_normal_auth():
    """A regular, non-jailed signed-in user is allowed on secure services."""
    regular_user = UserAuthInfo(
        user_id=1,
        token="abc123",
        is_api_key=False,
        token_expiry=now(),
        is_jailed=False,
        is_editor=False,
        is_superuser=False,
        ui_language_preference="en",
        timezone="Etc/UTC",
    )
    check_permissions(regular_user, annotations_pb2.AUTH_LEVEL_SECURE)

1018 

1019 

def test_check_auth_secure_service_with_jailed_user():
    """A jailed user is refused on secure services with UNAUTHENTICATED."""
    auth_info = UserAuthInfo(
        user_id=1,
        is_jailed=True,
        is_editor=False,
        is_superuser=False,
        token_expiry=now(),
        ui_language_preference="en",
        timezone="Etc/UTC",
        token="abc123",
        is_api_key=False,
    )
    with pytest.raises(AbortError) as exc:
        check_permissions(auth_info, annotations_pb2.AUTH_LEVEL_SECURE)
    # Pin the status code (as seen end-to-end for "jailed x secure" in test_auth_levels).
    assert exc.value.code == grpc.StatusCode.UNAUTHENTICATED

1034 

1035 

def test_check_auth_jailed_service_with_jailed_user():
    """A jailed user is still allowed on services marked AUTH_LEVEL_JAILED."""
    jailed_user = UserAuthInfo(
        user_id=1,
        token="abc123",
        is_api_key=False,
        token_expiry=now(),
        is_jailed=True,
        is_editor=False,
        is_superuser=False,
        ui_language_preference="en",
        timezone="Etc/UTC",
    )
    check_permissions(jailed_user, annotations_pb2.AUTH_LEVEL_JAILED)

1049 

1050 

def test_check_auth_jailed_service_without_auth():
    """Anonymous callers are refused on jailed-level services with UNAUTHENTICATED."""
    with pytest.raises(AbortError) as exc:
        check_permissions(None, annotations_pb2.AUTH_LEVEL_JAILED)
    # Pin the status code (as seen end-to-end for "open x jailed" in test_auth_levels).
    assert exc.value.code == grpc.StatusCode.UNAUTHENTICATED

1054 

1055 

def test_check_auth_editor_service_without_editor():
    """A signed-in non-editor is refused on editor services with PERMISSION_DENIED."""
    auth_info = UserAuthInfo(
        user_id=1,
        is_jailed=False,
        is_editor=False,
        is_superuser=False,
        token_expiry=now(),
        ui_language_preference="en",
        timezone="Etc/UTC",
        token="abc123",
        is_api_key=False,
    )
    with pytest.raises(AbortError) as exc:
        check_permissions(auth_info, annotations_pb2.AUTH_LEVEL_EDITOR)
    # Pin the status code (as seen end-to-end for "normal x editor" in test_auth_levels).
    assert exc.value.code == grpc.StatusCode.PERMISSION_DENIED

1070 

1071 

def test_check_auth_editor_service_with_editor():
    """An editor is allowed on editor-level services."""
    editor = UserAuthInfo(
        user_id=1,
        token="abc123",
        is_api_key=False,
        token_expiry=now(),
        is_jailed=False,
        is_editor=True,
        is_superuser=False,
        ui_language_preference="en",
        timezone="Etc/UTC",
    )
    check_permissions(editor, annotations_pb2.AUTH_LEVEL_EDITOR)

1085 

1086 

def test_check_auth_admin_service_without_superuser():
    """An editor who is not a superuser is refused on admin services with PERMISSION_DENIED."""
    auth_info = UserAuthInfo(
        user_id=1,
        is_jailed=False,
        is_editor=True,
        is_superuser=False,
        token_expiry=now(),
        ui_language_preference="en",
        timezone="Etc/UTC",
        token="abc123",
        is_api_key=False,
    )
    with pytest.raises(AbortError) as exc:
        check_permissions(auth_info, annotations_pb2.AUTH_LEVEL_ADMIN)
    # Pin the status code (as seen end-to-end for "editor x admin" in test_auth_levels).
    assert exc.value.code == grpc.StatusCode.PERMISSION_DENIED

1101 

1102 

def test_check_auth_admin_service_with_superuser():
    """A superuser is allowed on admin-level services."""
    superuser = UserAuthInfo(
        user_id=1,
        token="abc123",
        is_api_key=False,
        token_expiry=now(),
        is_jailed=False,
        is_editor=True,
        is_superuser=True,
        ui_language_preference="en",
        timezone="Etc/UTC",
    )
    check_permissions(superuser, annotations_pb2.AUTH_LEVEL_ADMIN)

1116 

1117 

def test_check_auth_admin_service_without_auth():
    """Anonymous callers are refused on admin services with UNAUTHENTICATED."""
    with pytest.raises(AbortError) as exc:
        check_permissions(None, annotations_pb2.AUTH_LEVEL_ADMIN)
    # Pin the status code (as seen end-to-end for "open x admin" in test_auth_levels).
    assert exc.value.code == grpc.StatusCode.UNAUTHENTICATED

1121 

1122 

def test_parse_sofa_cookie_valid():
    """A freshly generated sofa cookie parses back to the value it was generated for."""
    sofa_value, cookie_string = generate_sofa_cookie()
    # cookie_string looks like "sofa=<value>; ..."; keep only <value>
    after_name = cookie_string.split("=", 1)[1]
    encoded_value = after_name.split(";")[0]

    assert parse_sofa_cookie({"cookie": f"sofa={encoded_value}"}) == sofa_value

1130 

1131 

def test_parse_sofa_cookie_missing():
    """Cookies that do not include "sofa" yield None."""
    assert parse_sofa_cookie({"cookie": "other-cookie=value"}) is None

1136 

1137 

def test_parse_sofa_cookie_no_cookies():
    """With no cookie header at all, there is no sofa value."""
    no_headers: dict[str, str] = {}
    assert parse_sofa_cookie(no_headers) is None

1142 

1143 

def test_parse_sofa_cookie_invalid_base64():
    """A sofa cookie that is not valid base64 yields None."""
    assert parse_sofa_cookie({"cookie": "sofa=not-valid-base64!!!"}) is None

1148 

1149 

def test_parse_sofa_cookie_invalid_encryption():
    """Well-formed base64 that was never encrypted with the sofa key yields None."""
    unencrypted = b64encode(b"invalid encrypted data")
    assert parse_sofa_cookie({"cookie": f"sofa={unencrypted}"}) is None

1154 

1155 

def test_parse_sofa_cookie_invalid_proto():
    """A correctly encrypted sofa cookie whose payload is not a valid protobuf yields None.

    The payload's first byte, 0x6E ("n"), encodes field 13 with wire type 6. The protobuf wire
    format has no wire type 6, so decoding the payload must fail. The cookie should therefore be
    treated as absent, like the base64 and decryption failures tested above.
    """
    encrypted = simple_encrypt("sofa_cookie", b"not a valid proto")
    headers = {"cookie": f"sofa={b64encode(encrypted)}"}
    result = parse_sofa_cookie(headers)
    # NOTE(review): assumes parse_sofa_cookie maps a protobuf decode failure to None, as it does
    # for the other invalid inputs — confirm against couchers.utils.parse_sofa_cookie.
    assert result is None

1161 

1162 

def test_generate_sofa_cookie():
    """generate_sofa_cookie returns a non-trivial string value plus a cookie string that round-trips."""
    sofa_value, cookie_string = generate_sofa_cookie()

    # the raw value is a reasonably long, non-empty string
    assert sofa_value
    assert isinstance(sofa_value, str)
    assert len(sofa_value) > 20

    # the cookie string names the cookie and carries an expiry
    assert "sofa=" in cookie_string
    assert "expires=" in cookie_string.lower()

    # parsing the cookie's value back yields the original sofa value
    encoded_value = cookie_string.split("=", 1)[1].split(";")[0]
    round_tripped = parse_sofa_cookie({"cookie": f"sofa={encoded_value}"})
    assert round_tripped == sofa_value

1177 

1178 

def test_parse_headers_with_sofa_cookie():
    """parse_headers decodes a sofa cookie sent alongside the session cookie."""
    sofa_value, cookie_string = generate_sofa_cookie()
    encoded_value = cookie_string.split("=", 1)[1].split(";")[0]

    parsed = parse_headers({"cookie": f"couchers-sesh=abc123; sofa={encoded_value}"})
    assert parsed.token == "abc123"
    assert parsed.sofa == sofa_value

1189 

1190 

def test_parse_headers_without_sofa_cookie():
    """Without a sofa cookie, parse_headers reports sofa as None but still finds the session token."""
    parsed = parse_headers({"cookie": "couchers-sesh=abc123"})
    assert parsed.token == "abc123"
    assert parsed.sofa is None

1197 assert result.sofa is None 

1198 

1199 

def test_sofa_cookie_logged_new(db):
    """A request that sends no sofa cookie still gets a (newly generated) sofa value on its APICall row."""

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    middleware = [CouchersMiddlewareInterceptor()]
    with interceptor_dummy_api(TestRpc, interceptors=middleware) as call_rpc:
        call_rpc(empty_pb2.Empty())

    with session_scope() as session:
        logged_call = session.execute(select(APICall)).scalar_one()
        assert logged_call.sofa is not None
        assert len(logged_call.sofa) > 20

1211 

1212 

def test_sofa_cookie_logged_existing(db):
    """The decoded value of a valid incoming sofa cookie is what gets recorded on the APICall row."""
    sofa_value, cookie_string = generate_sofa_cookie()
    encoded_value = cookie_string.split("=", 1)[1].split(";")[0]
    request_metadata = (("cookie", f"sofa={encoded_value}"),)

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    middleware = [CouchersMiddlewareInterceptor()]
    with interceptor_dummy_api(TestRpc, interceptors=middleware) as call_rpc:
        call_rpc(empty_pb2.Empty(), metadata=request_metadata)

    with session_scope() as session:
        logged_call = session.execute(select(APICall)).scalar_one()
        assert logged_call.sofa == sofa_value

1226 

1227 

def test_sofa_cookie_logged_invalid_generates_new(db):
    """An undecodable sofa cookie is not logged as-is; a fresh sofa value is recorded instead."""
    bogus_value = "invalid-cookie-value"

    def TestRpc(request, context, session):
        return empty_pb2.Empty()

    middleware = [CouchersMiddlewareInterceptor()]
    with interceptor_dummy_api(TestRpc, interceptors=middleware) as call_rpc:
        call_rpc(empty_pb2.Empty(), metadata=(("cookie", f"sofa={bogus_value}"),))

    with session_scope() as session:
        logged_call = session.execute(select(APICall)).scalar_one()
        assert logged_call.sofa is not None
        assert logged_call.sofa != bogus_value
        assert len(logged_call.sofa) > 20

1240 

1241 

def test_sofa_cookie_with_authenticated_user(db):
    """An authenticated call sending a sofa cookie logs both the user id and the decoded sofa value."""
    user, token = generate_user()
    sofa_value, cookie_string = generate_sofa_cookie()
    encoded_value = cookie_string.split("=", 1)[1].split(";")[0]

    account = Account()
    cookie_header = f"couchers-sesh={token}; sofa={encoded_value}"

    with interceptor_dummy_api(
        rpc=account.GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[CouchersMiddlewareInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
        creds=grpc.local_channel_credentials(),
    ) as call_rpc:
        response = call_rpc(empty_pb2.Empty(), metadata=(("cookie", cookie_header),))
        assert response.username == user.username

    with session_scope() as session:
        logged_call = session.execute(select(APICall)).scalar_one()
        assert logged_call.user_id == user.id
        assert logged_call.sofa == sofa_value

1266 

1267 

def test_sofa_cookie_persists_on_exception(db):
    """When the handler raises, the APICall row still records the sofa value and the traceback."""
    sofa_value, cookie_string = generate_sofa_cookie()
    encoded_value = cookie_string.split("=", 1)[1].split(";")[0]

    def TestRpc(request, context, session):
        raise Exception("Test error")

    middleware = [CouchersMiddlewareInterceptor()]
    with (
        interceptor_dummy_api(TestRpc, interceptors=middleware) as call_rpc,
        pytest.raises(Exception, match="Test error"),
    ):
        call_rpc(empty_pb2.Empty(), metadata=(("cookie", f"sofa={encoded_value}"),))

    with session_scope() as session:
        logged_call = session.execute(select(APICall)).scalar_one()
        assert logged_call.sofa == sofa_value
        assert logged_call.traceback is not None
        assert "Test error" in logged_call.traceback