Coverage for src/tests/test_interceptors.py: 100%

281 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-04-16 15:13 +0000

1from concurrent import futures 

2from contextlib import contextmanager 

3 

4import grpc 

5import pytest 

6from google.protobuf import empty_pb2 

7 

8from couchers import errors 

9from couchers.crypto import random_hex 

10from couchers.db import session_scope 

11from couchers.interceptors import ( 

12 AuthValidatorInterceptor, 

13 CookieInterceptor, 

14 ErrorSanitizationInterceptor, 

15 SessionInterceptor, 

16 TracingInterceptor, 

17) 

18from couchers.metrics import servicer_duration_histogram 

19from couchers.models import APICall, UserSession 

20from couchers.servicers.account import Account 

21from couchers.servicers.api import API 

22from couchers.sql import couchers_select as select 

23from proto import account_pb2, admin_pb2, api_pb2, auth_pb2 

24from tests.test_fixtures import db, generate_user, real_admin_session, testconfig # noqa 

25 

26 

@pytest.fixture(autouse=True)
def _(testconfig):
    """Apply the ``testconfig`` fixture to every test in this module.

    The body is intentionally empty: merely depending on ``testconfig`` is what
    activates it for each test.
    """

30 

31 

@contextmanager
def interceptor_dummy_api(
    rpc,
    interceptors,
    service_name="testing.Test",
    method_name="TestRpc",
    request_type=empty_pb2.Empty,
    response_type=empty_pb2.Empty,
    creds=None,
):
    """Serve one unary-unary handler behind ``interceptors`` and yield a client stub for it.

    A single-threaded gRPC server is started on an ephemeral localhost port using
    local server credentials, with ``rpc`` registered under
    ``/{service_name}/{method_name}``. The yielded callable invokes that method over
    a secure channel; the server is stopped when the ``with`` block exits, even on error.

    Args:
        rpc: handler called as ``rpc(request, context)``.
        interceptors: server interceptors installed in front of the handler.
        service_name: fully-qualified gRPC service name to register.
        method_name: method name within ``service_name``.
        request_type: message class used to (de)serialize requests.
        response_type: message class used to (de)serialize responses.
        creds: channel credentials for the client; ``None`` (or any falsy value)
            means ``grpc.local_channel_credentials()``.
    """
    with futures.ThreadPoolExecutor(1) as executor:
        server = grpc.server(executor, interceptors=interceptors)
        # Port 0 lets the OS choose a free port; the chosen port is returned.
        port = server.add_secure_port("localhost:0", grpc.local_server_credentials())

        # Register the handler by hand rather than through generated servicer helpers,
        # so arbitrary service/method names can be exercised.
        handler = grpc.unary_unary_rpc_method_handler(
            rpc,
            request_deserializer=request_type.FromString,
            response_serializer=response_type.SerializeToString,
        )
        server.add_generic_rpc_handlers(
            (grpc.method_handlers_generic_handler(service_name, {method_name: handler}),)
        )
        server.start()

        try:
            channel_creds = creds or grpc.local_channel_credentials()
            with grpc.secure_channel(f"localhost:{port}", channel_creds) as channel:
                yield channel.unary_unary(
                    f"/{service_name}/{method_name}",
                    request_serializer=request_type.SerializeToString,
                    response_deserializer=response_type.FromString,
                )
        finally:
            server.stop(None).wait()

68 

69 

def _get_histogram_labels_value(method, logged_in, exception, code):
    """Return the ``_count`` sample of the servicer duration histogram for one label set.

    Args:
        method: value of the ``method`` label, e.g. ``"/testing.Test/TestRpc"``.
        logged_in: value of the ``logged_in`` label (a string such as ``"False"``).
        exception: value of the ``exception`` label (``""`` when none was raised).
        code: value of the ``code`` label (``""`` when no status code was set).

    Returns:
        The observation count for that exact label combination, or 0 if no such
        sample has been recorded yet.
    """
    servicer_histogram = [
        metric
        for metric in servicer_duration_histogram.collect()
        if metric.name == "couchers_servicer_duration_seconds"
    ][0]

    def _is_wanted(sample):
        return (
            sample.name == "couchers_servicer_duration_seconds_count"
            and sample.labels["method"] == method
            and sample.labels["logged_in"] == logged_in
            and sample.labels["code"] == code
            and sample.labels["exception"] == exception
        )

    matching = [sample for sample in servicer_histogram.samples if _is_wanted(sample)]
    return matching[0].value if matching else 0

85 

86 

def test_logging_interceptor_ok():
    """A handler that returns normally is unaffected by ErrorSanitizationInterceptor."""

    def handler(request, context):
        return empty_pb2.Empty()

    sanitizer = ErrorSanitizationInterceptor()
    with interceptor_dummy_api(handler, interceptors=[sanitizer]) as call_rpc:
        call_rpc(empty_pb2.Empty())

93 

94 

def test_logging_interceptor_all_ignored():
    """ErrorSanitizationInterceptor passes explicit ``context.abort()`` errors through unchanged.

    For every non-OK status code, a handler aborts with that code and a random
    message; the client must receive exactly that code and message.
    """
    # error codes that should not be touched by the interceptor
    pass_through_status_codes = [
        # we can't abort with OK
        # grpc.StatusCode.OK,
        grpc.StatusCode.CANCELLED,
        grpc.StatusCode.UNKNOWN,
        grpc.StatusCode.INVALID_ARGUMENT,
        grpc.StatusCode.DEADLINE_EXCEEDED,
        grpc.StatusCode.NOT_FOUND,
        grpc.StatusCode.ALREADY_EXISTS,
        grpc.StatusCode.PERMISSION_DENIED,
        grpc.StatusCode.UNAUTHENTICATED,
        grpc.StatusCode.RESOURCE_EXHAUSTED,
        grpc.StatusCode.FAILED_PRECONDITION,
        grpc.StatusCode.ABORTED,
        grpc.StatusCode.OUT_OF_RANGE,
        grpc.StatusCode.UNIMPLEMENTED,
        grpc.StatusCode.INTERNAL,
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.DATA_LOSS,
    ]

    for status_code in pass_through_status_codes:
        message = random_hex()

        # Bind this iteration's values as defaults so the handler does not rely on
        # late-binding lookup of the loop variables (flake8-bugbear B023).
        def TestRpc(request, context, status_code=status_code, message=message):
            context.abort(status_code, message)

        with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
            with pytest.raises(grpc.RpcError) as e:
                call_rpc(empty_pb2.Empty())
            assert e.value.code() == status_code
            assert e.value.details() == message

129 

130 

def test_logging_interceptor_assertion():
    """An AssertionError raised by a handler reaches the client as a generic INTERNAL error."""

    def handler(request, context):
        raise AssertionError()

    with interceptor_dummy_api(handler, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as exc_info:
            call_rpc(empty_pb2.Empty())

        rpc_error = exc_info.value
        assert rpc_error.code() == grpc.StatusCode.INTERNAL
        assert rpc_error.details() == errors.UNKNOWN_ERROR

140 

141 

def test_logging_interceptor_div0():
    """A ZeroDivisionError raised by a handler reaches the client as a generic INTERNAL error."""

    def handler(request, context):
        # The division raises before anything is returned.
        return 1 / 0

    with interceptor_dummy_api(handler, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as exc_info:
            call_rpc(empty_pb2.Empty())

        rpc_error = exc_info.value
        assert rpc_error.code() == grpc.StatusCode.INTERNAL
        assert rpc_error.details() == errors.UNKNOWN_ERROR

151 

152 

def test_logging_interceptor_raise():
    """A bare Exception raised by a handler reaches the client as a generic INTERNAL error."""

    def handler(request, context):
        raise Exception()

    with interceptor_dummy_api(handler, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as exc_info:
            call_rpc(empty_pb2.Empty())

        rpc_error = exc_info.value
        assert rpc_error.code() == grpc.StatusCode.INTERNAL
        assert rpc_error.details() == errors.UNKNOWN_ERROR

162 

163 

def test_logging_interceptor_raise_custom():
    """A custom exception's message is not leaked; the client sees a generic INTERNAL error."""

    class _TestingException(Exception):
        pass

    def handler(request, context):
        raise _TestingException("This is a custom exception")

    with interceptor_dummy_api(handler, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as exc_info:
            call_rpc(empty_pb2.Empty())

        rpc_error = exc_info.value
        assert rpc_error.code() == grpc.StatusCode.INTERNAL
        assert rpc_error.details() == errors.UNKNOWN_ERROR

176 

177 

def test_tracing_interceptor_ok_open(db):
    """A successful, unauthenticated call is traced and counted in the duration histogram.

    Checks the single APICall row written by TracingInterceptor (no status code, no
    user, empty request/response, no traceback). It also checks that the histogram
    count for the logged-out / no-exception / no-code label set rose by exactly one.
    """
    histogram_labels = ("/testing.Test/TestRpc", "False", "", "")
    count_before = _get_histogram_labels_value(*histogram_labels)

    def handler(request, context):
        return empty_pb2.Empty()

    with interceptor_dummy_api(handler, interceptors=[TracingInterceptor()]) as call_rpc:
        call_rpc(empty_pb2.Empty())

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/testing.Test/TestRpc"
        assert not trace.status_code
        assert not trace.user_id
        # An Empty message serializes to zero bytes.
        assert len(trace.request) == 0
        assert len(trace.response) == 0
        assert not trace.traceback

    assert _get_histogram_labels_value(*histogram_labels) == count_before + 1

197 

198 

def test_tracing_interceptor_sensitive(db):
    """Password fields are stripped from the traced request and response.

    The traced ``SignupFlowReq`` must lose ``account.password`` but keep
    ``account.username``. The traced ``AuthReq`` response must lose ``password``
    but keep ``user``.
    """
    histogram_labels = ("/testing.Test/TestRpc", "False", "", "")
    count_before = _get_histogram_labels_value(*histogram_labels)

    def handler(request, context):
        return auth_pb2.AuthReq(user="this is not secret", password="this is secret")

    signup_request = auth_pb2.SignupFlowReq(
        account=auth_pb2.SignupAccount(password="should be removed", username="not removed")
    )

    with interceptor_dummy_api(
        handler,
        interceptors=[TracingInterceptor()],
        request_type=auth_pb2.SignupFlowReq,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        call_rpc(signup_request)

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/testing.Test/TestRpc"
        assert not trace.status_code
        assert not trace.user_id
        assert not trace.traceback

        traced_request = auth_pb2.SignupFlowReq.FromString(trace.request)
        assert not traced_request.account.password
        assert traced_request.account.username == "not removed"

        traced_response = auth_pb2.AuthReq.FromString(trace.response)
        assert traced_response.user == "this is not secret"
        assert not traced_response.password

    assert _get_histogram_labels_value(*histogram_labels) == count_before + 1

229 

230 

def test_tracing_interceptor_sensitive_ping(db):
    """Trace a real, cookie-authenticated ``API.GetUser`` call through the full interceptor stack.

    Checks that the call succeeds and returns the requested user. Also checks that
    TracingInterceptor records exactly one APICall row for it, attributed to the
    calling user, with no status code and no traceback.
    """
    user, token = generate_user()

    with interceptor_dummy_api(
        API().GetUser,
        interceptors=[TracingInterceptor(), AuthValidatorInterceptor(), SessionInterceptor()],
        request_type=api_pb2.GetUserReq,
        response_type=api_pb2.User,
        service_name="org.couchers.api.core.API",
        method_name="GetUser",
    ) as call_rpc:
        res = call_rpc(api_pb2.GetUserReq(user=user.username), metadata=(("cookie", f"couchers-sesh={token}"),))

    assert res.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.core.API/GetUser"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert not trace.traceback

243 

244 

def test_tracing_interceptor_exception(db):
    """An unhandled handler exception is traced with its traceback and counted by exception type.

    The client must see a gRPC error carrying the exception message. The APICall row
    must hold the traceback, a password-stripped request and no response. The
    histogram count for the ``exception="Exception"`` label set must rise by exactly one.
    """
    val = _get_histogram_labels_value("/testing.Test/TestRpc", "False", "Exception", "")

    def TestRpc(request, context):
        raise Exception("Some error message")

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[TracingInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        # Expect the specific client-side gRPC error rather than any Exception, so an
        # unrelated failure in the test harness cannot satisfy this check.
        with pytest.raises(grpc.RpcError, match="Some error message"):
            call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/testing.Test/TestRpc"
        # A plain exception (not a context.abort) leaves no status code on the trace.
        assert not trace.status_code
        assert not trace.user_id
        assert "Some error message" in trace.traceback
        req = auth_pb2.SignupAccount.FromString(trace.request)
        assert not req.password
        assert req.username == "not removed"
        assert not trace.response

    assert _get_histogram_labels_value("/testing.Test/TestRpc", "False", "Exception", "") == val + 1

272 

273 

def test_tracing_interceptor_abort(db):
    """A ``context.abort`` in the handler is traced with its status code and traceback.

    The client must receive the abort's status code and details unchanged. The
    APICall row must record ``FAILED_PRECONDITION``, hold a password-stripped request
    and no response. The matching histogram count must rise by exactly one.
    """
    val = _get_histogram_labels_value("/testing.Test/TestRpc", "False", "Exception", "FAILED_PRECONDITION")

    def TestRpc(request, context):
        context.abort(grpc.StatusCode.FAILED_PRECONDITION, "now a grpc abort")

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[TracingInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        # Expect the specific client-side gRPC error rather than any Exception.
        with pytest.raises(grpc.RpcError, match="now a grpc abort") as e:
            call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))
        assert e.value.code() == grpc.StatusCode.FAILED_PRECONDITION
        assert e.value.details() == "now a grpc abort"

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/testing.Test/TestRpc"
        assert trace.status_code == "FAILED_PRECONDITION"
        assert not trace.user_id
        assert "now a grpc abort" in trace.traceback
        req = auth_pb2.SignupAccount.FromString(trace.request)
        assert not req.password
        assert req.username == "not removed"
        assert not trace.response

    assert _get_histogram_labels_value("/testing.Test/TestRpc", "False", "Exception", "FAILED_PRECONDITION") == val + 1

301 

302 

def test_auth_interceptor(db):
    """Exercise how a secure RPC authenticates its caller via session cookie or API key.

    Covers: no credentials, a valid or random ``couchers-sesh`` cookie, and a valid or
    random ``Bearer`` API key. Also covers the same API key supplied through gRPC
    access-token call credentials, both cookie and authorization at once, and a
    lower-case ``bearer`` scheme.
    """
    _, super_token = generate_user(is_superuser=True)
    user, token = generate_user()

    with real_admin_session(super_token) as api:
        api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        # SQLAlchemy column comparison, so ``== True`` (not ``is True``) is required here.
        api_key = session.execute(select(UserSession).where(UserSession.is_api_key == True)).scalar_one().token

    rpc_def = {
        "rpc": Account().GetAccountInfo,
        "service_name": "org.couchers.api.account.Account",
        "method_name": "GetAccountInfo",
        "interceptors": [AuthValidatorInterceptor(), CookieInterceptor(), SessionInterceptor()],
        "request_type": empty_pb2.Empty,
        "response_type": account_pb2.GetAccountInfoRes,
    }

    def call(metadata=None, creds=None):
        """Start a fresh dummy server and make one GetAccountInfo call against it."""
        with interceptor_dummy_api(**rpc_def, creds=creds or grpc.local_channel_credentials()) as call_rpc:
            return call_rpc(empty_pb2.Empty(), metadata=metadata)

    def expect_unauthenticated(metadata, details="Unauthorized"):
        """Assert that calling with ``metadata`` fails as UNAUTHENTICATED with ``details``."""
        with pytest.raises(grpc.RpcError) as exc_info:
            call(metadata)
        assert exc_info.value.code() == grpc.StatusCode.UNAUTHENTICATED
        assert exc_info.value.details() == details

    # no creds, no go for secure APIs
    expect_unauthenticated(None)

    # can auth with cookie
    assert call((("cookie", f"couchers-sesh={token}"),)).username == user.username

    # can't auth with wrong cookie
    expect_unauthenticated((("cookie", f"couchers-sesh={random_hex(32)}"),))

    # can auth with api key
    assert call((("authorization", f"Bearer {api_key}"),)).username == user.username

    # can't auth with wrong api key
    expect_unauthenticated((("authorization", f"Bearer {random_hex(32)}"),))

    # can auth with grpc helper (they do the same as above)
    comp_creds = grpc.composite_channel_credentials(
        grpc.local_channel_credentials(), grpc.access_token_call_credentials(api_key)
    )
    assert call(creds=comp_creds).username == user.username

    # can't auth with both
    expect_unauthenticated(
        (
            ("cookie", f"couchers-sesh={token}"),
            ("authorization", f"Bearer {api_key}"),
        ),
        details='Both "cookie" and "authorization" in request',
    )

    # malformed bearer (the scheme is case-sensitive here)
    expect_unauthenticated((("authorization", f"bearer {api_key}"),))

383 

384 

def test_tracing_interceptor_auth_cookies(db):
    """A cookie-authenticated call is traced with the caller's user id and ``is_api_key`` unset."""
    user, token = generate_user()

    with interceptor_dummy_api(
        Account().GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[TracingInterceptor(), AuthValidatorInterceptor(), SessionInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
        creds=grpc.local_channel_credentials(),
    ) as call_rpc:
        account_info = call_rpc(empty_pb2.Empty(), metadata=(("cookie", f"couchers-sesh={token}"),))

    assert account_info.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert not trace.is_api_key
        assert len(trace.request) == 0
        assert not trace.traceback

412 

413 

def test_tracing_interceptor_auth_api_key(db):
    """An API-key-authenticated call is traced with the caller's user id and ``is_api_key`` set."""
    _, super_token = generate_user(is_superuser=True)
    user, _ = generate_user()

    with real_admin_session(super_token) as admin_api:
        admin_api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        # SQLAlchemy column comparison, so ``== True`` (not ``is True``) is required here.
        key_session = session.execute(select(UserSession).where(UserSession.is_api_key == True)).scalar_one()
        api_key = key_session.token

    with interceptor_dummy_api(
        Account().GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[TracingInterceptor(), AuthValidatorInterceptor(), SessionInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
        creds=grpc.local_channel_credentials(),
    ) as call_rpc:
        account_info = call_rpc(empty_pb2.Empty(), metadata=(("authorization", f"Bearer {api_key}"),))

    assert account_info.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert trace.is_api_key
        assert len(trace.request) == 0
        assert not trace.traceback

449 

450 

def test_auth_levels(db):
    """AuthValidatorInterceptor enforces each service's auth level for every kind of caller.

    Four callers are tried against one RPC from each of four auth levels: an empty
    session cookie, a jailed user (``accepted_tos=0``), a normal user and a superuser.
    The levels are open, jailed, secure and admin. The test also checks the errors
    returned for a non-existent RPC and for an RPC without a configured auth level.
    """

    def TestRpc(request, context):
        return empty_pb2.Empty()

    def gen_args(service, method):
        """Build interceptor_dummy_api kwargs for ``service``/``method`` behind AuthValidatorInterceptor."""
        return {
            "rpc": TestRpc,
            "service_name": service,
            "method_name": method,
            "interceptors": [AuthValidatorInterceptor()],
            "request_type": empty_pb2.Empty,
            "response_type": empty_pb2.Empty,
        }

    # superuser
    _, super_token = generate_user(is_superuser=True)
    # normal user
    _, normal_token = generate_user()
    # jailed user (has not accepted the terms of service)
    _, jailed_token = generate_user(accepted_tos=0)
    # open user: sends an empty session cookie
    open_token = ""

    # pick some rpcs here with the right auth levels
    open_args = gen_args("org.couchers.resources.Resources", "GetTermsOfService")
    jailed_args = gen_args("org.couchers.jail.Jail", "JailInfo")
    secure_args = gen_args("org.couchers.api.account.Account", "GetAccountInfo")
    admin_args = gen_args("org.couchers.admin.Admin", "GetUserDetails")

    # pairs to check
    checks = [
        # name, token, args, works?, code, message
        # (code and message are only checked when works? is False)
        # open token only works on open servicers
        ("open x open", open_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x jailed", open_token, jailed_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x secure", open_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x admin", open_token, admin_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        # jailed works on jailed and open
        ("jailed x open", jailed_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("jailed x jailed", jailed_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        # NOTE(review): jailed-on-secure expects UNAUTHENTICATED paired with "Permission denied",
        # unlike the admin rows (PERMISSION_DENIED) — confirm this pairing is intended.
        ("jailed x secure", jailed_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Permission denied"),
        ("jailed x admin", jailed_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # normal works on all but admin
        ("normal x open", normal_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x jailed", normal_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x secure", normal_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x admin", normal_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # superuser works on all
        ("super x open", super_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x jailed", super_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x secure", super_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x admin", super_token, admin_args, True, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
    ]

    for name, token, args, should_work, code, message in checks:
        # Printed so a failing iteration is identifiable in pytest's captured output.
        print(f"Testing (token x args) = ({name}), {should_work=}")
        metadata = (("cookie", f"couchers-sesh={token}"),)
        with interceptor_dummy_api(**args) as call_rpc:
            if should_work:
                call_rpc(empty_pb2.Empty(), metadata=metadata)
            else:
                with pytest.raises(grpc.RpcError) as e:
                    call_rpc(empty_pb2.Empty(), metadata=metadata)
                assert e.value.code() == code
                assert e.value.details() == message

    # a non-existent RPC
    nonexistent = gen_args("org.couchers.nonexistent.NA", "GetNothing")

    with interceptor_dummy_api(**nonexistent) as call_rpc:
        # Only grpc.RpcError provides the .code()/.details() checked below.
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.UNIMPLEMENTED
        assert e.value.details() == "API call does not exist. Please refresh and try again."

    # an RPC without a service level
    invalid_args = gen_args("org.couchers.media.Media", "UploadConfirmation")

    with interceptor_dummy_api(**invalid_args) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.INTERNAL
        assert e.value.details() == "Internal authentication error."