Coverage for src/tests/test_interceptors.py: 100%

276 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-10-15 13:03 +0000

1from concurrent import futures 

2from contextlib import contextmanager 

3 

4import grpc 

5import pytest 

6from google.protobuf import empty_pb2 

7 

8from couchers import errors 

9from couchers.crypto import random_hex 

10from couchers.db import session_scope 

11from couchers.interceptors import ( 

12 AuthValidatorInterceptor, 

13 CookieInterceptor, 

14 ErrorSanitizationInterceptor, 

15 SessionInterceptor, 

16 TracingInterceptor, 

17) 

18from couchers.metrics import servicer_duration_histogram 

19from couchers.models import APICall, UserSession 

20from couchers.servicers.account import Account 

21from couchers.servicers.api import API 

22from couchers.sql import couchers_select as select 

23from proto import account_pb2, admin_pb2, api_pb2, auth_pb2 

24from tests.test_fixtures import db, generate_user, real_admin_session, testconfig # noqa 

25 

26 

@pytest.fixture(autouse=True)
def _(testconfig):
    """Autouse fixture: pull the shared ``testconfig`` fixture into every test in this module.

    It exists only so that ``testconfig`` is requested; there is no body of its own.
    """

30 

31 

@contextmanager
def interceptor_dummy_api(
    rpc,
    interceptors,
    service_name="testing.Test",
    method_name="TestRpc",
    request_type=empty_pb2.Empty,
    response_type=empty_pb2.Empty,
    creds=None,
):
    """Serve one unary-unary ``rpc`` behind ``interceptors`` and yield a client callable for it.

    A gRPC server backed by a single-thread executor is bound to an ephemeral
    ``localhost`` port with local server credentials. ``rpc`` is registered by hand
    under ``/{service_name}/{method_name}``, so no generated servicer is required.

    Args:
        rpc: handler invoked as ``rpc(request, context)``.
        interceptors: server interceptors passed to ``grpc.server``.
        service_name: fully qualified service name to register the handler under.
        method_name: method name to register the handler under.
        request_type: protobuf message class used to (de)serialize requests.
        response_type: protobuf message class used to (de)serialize responses.
        creds: client channel credentials; ``grpc.local_channel_credentials()`` when falsy.

    Yields:
        The channel's unary-unary callable, used as ``call_rpc(request, metadata=...)``.

    The server is stopped when the ``with`` block exits, including on error.
    """
    full_method = f"/{service_name}/{method_name}"
    # build the handler up front; it is a plain description object and does not touch the server
    handler = grpc.unary_unary_rpc_method_handler(
        rpc,
        request_deserializer=request_type.FromString,
        response_serializer=response_type.SerializeToString,
    )

    with futures.ThreadPoolExecutor(1) as executor:
        server = grpc.server(executor, interceptors=interceptors)
        port = server.add_secure_port("localhost:0", grpc.local_server_credentials())
        server.add_generic_rpc_handlers(
            (grpc.method_handlers_generic_handler(service_name, {method_name: handler}),)
        )
        server.start()

        try:
            channel_creds = creds or grpc.local_channel_credentials()
            with grpc.secure_channel(f"localhost:{port}", channel_creds) as channel:
                yield channel.unary_unary(
                    full_method,
                    request_serializer=request_type.SerializeToString,
                    response_deserializer=response_type.FromString,
                )
        finally:
            server.stop(None).wait()

68 

69 

def _check_histogram_labels(method, logged_in, exception, code, count):
    """Assert the servicer duration histogram counted ``count`` calls with the given labels, then reset it.

    Args:
        method: full gRPC method path, e.g. ``"/testing.Test/TestRpc"``.
        logged_in: expected ``logged_in`` label value (callers pass strings such as ``"False"``).
        exception: expected ``exception`` label value (``""`` when no exception was recorded).
        code: expected ``code`` label value (``""`` when no status code was recorded).
        count: expected value of the ``couchers_servicer_duration_seconds_count`` sample.

    Raises:
        AssertionError: if the histogram or a sample with these labels is missing,
            or if the recorded count differs from ``count``.
    """
    # clear in ``finally`` so a failed check does not leave samples behind for later tests
    try:
        metrics = servicer_duration_histogram.collect()
        servicer_histogram = next((m for m in metrics if m.name == "couchers_servicer_duration_seconds"), None)
        assert servicer_histogram is not None, "couchers_servicer_duration_seconds was not collected"

        wanted = {"method": method, "logged_in": logged_in, "code": code, "exception": exception}
        matches = [
            s
            for s in servicer_histogram.samples
            if s.name == "couchers_servicer_duration_seconds_count"
            and all(s.labels[label] == value for label, value in wanted.items())
        ]
        assert matches, f"no couchers_servicer_duration_seconds_count sample with labels {wanted}"
        assert matches[0].value == count
    finally:
        servicer_duration_histogram.clear()

84 

85 

def test_logging_interceptor_ok():
    """A handler that returns normally completes without error behind ErrorSanitizationInterceptor."""

    def succeeding_rpc(request, context):
        return empty_pb2.Empty()

    sanitizer = ErrorSanitizationInterceptor()
    with interceptor_dummy_api(succeeding_rpc, interceptors=[sanitizer]) as call_rpc:
        call_rpc(empty_pb2.Empty())

92 

93 

def test_logging_interceptor_all_ignored():
    """Deliberate ``context.abort`` calls pass through ErrorSanitizationInterceptor unchanged.

    For every non-OK status code, the client must receive exactly the code and the
    details string the handler aborted with.
    """
    # error codes that should not be touched by the interceptor
    # (OK is excluded: we can't abort with OK)
    pass_through_status_codes = [
        grpc.StatusCode.CANCELLED,
        grpc.StatusCode.UNKNOWN,
        grpc.StatusCode.INVALID_ARGUMENT,
        grpc.StatusCode.DEADLINE_EXCEEDED,
        grpc.StatusCode.NOT_FOUND,
        grpc.StatusCode.ALREADY_EXISTS,
        grpc.StatusCode.PERMISSION_DENIED,
        grpc.StatusCode.UNAUTHENTICATED,
        grpc.StatusCode.RESOURCE_EXHAUSTED,
        grpc.StatusCode.FAILED_PRECONDITION,
        grpc.StatusCode.ABORTED,
        grpc.StatusCode.OUT_OF_RANGE,
        grpc.StatusCode.UNIMPLEMENTED,
        grpc.StatusCode.INTERNAL,
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.DATA_LOSS,
    ]

    for status_code in pass_through_status_codes:
        message = random_hex()

        # bind this iteration's values as defaults instead of relying on late-binding closure lookup
        def TestRpc(request, context, status_code=status_code, message=message):
            context.abort(status_code, message)

        with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
            with pytest.raises(grpc.RpcError) as e:
                call_rpc(empty_pb2.Empty())
            assert e.value.code() == status_code
            assert e.value.details() == message

128 

129 

def test_logging_interceptor_assertion():
    """An AssertionError in a handler reaches the client only as INTERNAL with ``errors.UNKNOWN_ERROR``."""

    def failing_rpc(request, context):
        raise AssertionError()

    sanitizer = ErrorSanitizationInterceptor()
    with interceptor_dummy_api(failing_rpc, interceptors=[sanitizer]) as call_rpc:
        with pytest.raises(grpc.RpcError) as excinfo:
            call_rpc(empty_pb2.Empty())
        err = excinfo.value
        assert err.code() == grpc.StatusCode.INTERNAL
        assert err.details() == errors.UNKNOWN_ERROR

139 

140 

def test_logging_interceptor_div0():
    """A ZeroDivisionError in a handler reaches the client only as INTERNAL with ``errors.UNKNOWN_ERROR``."""

    def dividing_rpc(request, context):
        _ = 1 / 0

    sanitizer = ErrorSanitizationInterceptor()
    with interceptor_dummy_api(dividing_rpc, interceptors=[sanitizer]) as call_rpc:
        with pytest.raises(grpc.RpcError) as excinfo:
            call_rpc(empty_pb2.Empty())
        err = excinfo.value
        assert err.code() == grpc.StatusCode.INTERNAL
        assert err.details() == errors.UNKNOWN_ERROR

150 

151 

def test_logging_interceptor_raise():
    """A bare ``Exception`` from a handler reaches the client only as INTERNAL with ``errors.UNKNOWN_ERROR``."""

    def raising_rpc(request, context):
        raise Exception()

    sanitizer = ErrorSanitizationInterceptor()
    with interceptor_dummy_api(raising_rpc, interceptors=[sanitizer]) as call_rpc:
        with pytest.raises(grpc.RpcError) as excinfo:
            call_rpc(empty_pb2.Empty())
        err = excinfo.value
        assert err.code() == grpc.StatusCode.INTERNAL
        assert err.details() == errors.UNKNOWN_ERROR

161 

162 

def test_logging_interceptor_raise_custom():
    """A custom exception's message is hidden: the client sees INTERNAL with ``errors.UNKNOWN_ERROR``."""

    class _TestingException(Exception):
        pass

    def raising_rpc(request, context):
        raise _TestingException("This is a custom exception")

    sanitizer = ErrorSanitizationInterceptor()
    with interceptor_dummy_api(raising_rpc, interceptors=[sanitizer]) as call_rpc:
        with pytest.raises(grpc.RpcError) as excinfo:
            call_rpc(empty_pb2.Empty())
        err = excinfo.value
        assert err.code() == grpc.StatusCode.INTERNAL
        assert err.details() == errors.UNKNOWN_ERROR

175 

176 

def test_tracing_interceptor_ok_open(db):
    """A successful anonymous call is stored as a single clean APICall row and one histogram sample."""
    expected_method = "/testing.Test/TestRpc"

    def succeeding_rpc(request, context):
        return empty_pb2.Empty()

    with interceptor_dummy_api(succeeding_rpc, interceptors=[TracingInterceptor()]) as call_rpc:
        call_rpc(empty_pb2.Empty())

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == expected_method
        # no status code, no user, empty payloads and no traceback for an anonymous success
        assert not trace.status_code
        assert not trace.user_id
        assert len(trace.request) == 0
        assert len(trace.response) == 0
        assert not trace.traceback

    _check_histogram_labels(expected_method, "False", "", "", 1)

194 

195 

def test_tracing_interceptor_sensitive(db):
    """Password fields are stripped from the stored request and response, while non-secret fields are kept."""
    expected_method = "/testing.Test/TestRpc"

    def answering_rpc(request, context):
        return auth_pb2.AuthReq(user="this is not secret", password="this is secret")

    outgoing = auth_pb2.SignupFlowReq(
        account=auth_pb2.SignupAccount(password="should be removed", username="not removed")
    )
    with interceptor_dummy_api(
        answering_rpc,
        interceptors=[TracingInterceptor()],
        request_type=auth_pb2.SignupFlowReq,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        call_rpc(outgoing)

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == expected_method
        assert not trace.status_code
        assert not trace.user_id
        assert not trace.traceback

        logged_req = auth_pb2.SignupFlowReq.FromString(trace.request)
        assert not logged_req.account.password
        assert logged_req.account.username == "not removed"

        logged_res = auth_pb2.AuthReq.FromString(trace.response)
        assert logged_res.user == "this is not secret"
        assert not logged_res.password

    _check_histogram_labels(expected_method, "False", "", "", 1)

224 

225 

def test_tracing_interceptor_sensitive_ping(db):
    """A real servicer (``API.GetUser``) works behind the tracing, auth and session interceptors.

    Checks that the call returns the requested user and that it was traced against
    the cookie-authenticated user without a traceback.
    """
    user, token = generate_user()

    with interceptor_dummy_api(
        API().GetUser,
        interceptors=[TracingInterceptor(), AuthValidatorInterceptor(), SessionInterceptor()],
        request_type=api_pb2.GetUserReq,
        response_type=api_pb2.User,
        service_name="org.couchers.api.core.API",
        method_name="GetUser",
    ) as call_rpc:
        res = call_rpc(api_pb2.GetUserReq(user=user.username), metadata=(("cookie", f"couchers-sesh={token}"),))

    # the call must return the requested user, not merely avoid raising
    assert res.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.core.API/GetUser"
        assert trace.user_id == user.id
        assert not trace.traceback

238 

239 

def test_tracing_interceptor_exception(db):
    """An unhandled handler exception is traced with its traceback and a sanitized request.

    The client must receive a ``grpc.RpcError`` mentioning the message. The stored
    APICall has no status code and no response, the password is stripped from the
    request, and the histogram records the exception label.
    """

    def TestRpc(request, context):
        raise Exception("Some error message")

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[TracingInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        # the client side always surfaces server failures as grpc.RpcError; don't accept any Exception
        with pytest.raises(grpc.RpcError, match="Some error message"):
            call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/testing.Test/TestRpc"
        assert not trace.status_code
        assert not trace.user_id
        assert "Some error message" in trace.traceback
        req = auth_pb2.SignupAccount.FromString(trace.request)
        assert not req.password
        assert req.username == "not removed"
        assert not trace.response

    _check_histogram_labels("/testing.Test/TestRpc", "False", "Exception", "", 1)

265 

266 

def test_tracing_interceptor_abort(db):
    """A ``context.abort`` is traced with its status code and traceback and reaches the client intact.

    The client must receive a ``grpc.RpcError`` with FAILED_PRECONDITION. The stored
    APICall records that code, strips the password from the request and has no
    response, and the histogram records both the exception and code labels.
    """

    def TestRpc(request, context):
        context.abort(grpc.StatusCode.FAILED_PRECONDITION, "now a grpc abort")

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[TracingInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        with pytest.raises(grpc.RpcError, match="now a grpc abort") as e:
            call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))
        # the abort code must reach the client unchanged
        assert e.value.code() == grpc.StatusCode.FAILED_PRECONDITION

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/testing.Test/TestRpc"
        assert trace.status_code == "FAILED_PRECONDITION"
        assert not trace.user_id
        assert "now a grpc abort" in trace.traceback
        req = auth_pb2.SignupAccount.FromString(trace.request)
        assert not req.password
        assert req.username == "not removed"
        assert not trace.response

    _check_histogram_labels("/testing.Test/TestRpc", "False", "Exception", "FAILED_PRECONDITION", 1)

292 

293 

def test_auth_interceptor(db):
    """Session cookies and API keys authenticate ``Account.GetAccountInfo``; bad or conflicting creds do not.

    Covers:
    - no credentials;
    - a valid and a random session cookie;
    - a valid and a random bearer API key;
    - the API key sent via gRPC access-token call credentials;
    - cookie and bearer sent together;
    - a lowercase ``bearer`` scheme.
    """
    _, super_token = generate_user(is_superuser=True)
    user, token = generate_user()

    with real_admin_session(super_token) as api:
        api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        api_key = session.execute(select(UserSession).where(UserSession.is_api_key == True)).scalar_one().token

    account = Account()
    rpc_def = dict(
        rpc=account.GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[AuthValidatorInterceptor(), CookieInterceptor(), SessionInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
    )

    def call(metadata=None, creds=None):
        # start a fresh server for a single call and return its response
        with interceptor_dummy_api(**rpc_def, creds=creds or grpc.local_channel_credentials()) as call_rpc:
            return call_rpc(empty_pb2.Empty(), metadata=metadata)

    def assert_unauthenticated(metadata, details):
        with pytest.raises(grpc.RpcError) as excinfo:
            call(metadata)
        assert excinfo.value.code() == grpc.StatusCode.UNAUTHENTICATED
        assert excinfo.value.details() == details

    cookie = ("cookie", f"couchers-sesh={token}")
    bearer = ("authorization", f"Bearer {api_key}")

    # no creds, no go for secure APIs
    assert_unauthenticated(None, "Unauthorized")

    # can auth with cookie
    assert call((cookie,)).username == user.username

    # can't auth with wrong cookie
    assert_unauthenticated((("cookie", f"couchers-sesh={random_hex(32)}"),), "Unauthorized")

    # can auth with api key
    assert call((bearer,)).username == user.username

    # can't auth with wrong api key
    assert_unauthenticated((("authorization", f"Bearer {random_hex(32)}"),), "Unauthorized")

    # can auth with grpc helper (they do the same as above)
    token_creds = grpc.composite_channel_credentials(
        grpc.local_channel_credentials(), grpc.access_token_call_credentials(api_key)
    )
    assert call(creds=token_creds).username == user.username

    # can't auth with both
    assert_unauthenticated((cookie, bearer), 'Both "cookie" and "authorization" in request')

    # malformed bearer
    assert_unauthenticated((("authorization", f"bearer {api_key}"),), "Unauthorized")

374 

375 

def test_tracing_interceptor_auth_cookies(db):
    """A cookie-authenticated call is traced against the user and is not flagged as an API-key call."""
    user, token = generate_user()

    rpc_def = dict(
        rpc=Account().GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[TracingInterceptor(), AuthValidatorInterceptor(), SessionInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
    )
    cookie_metadata = (("cookie", f"couchers-sesh={token}"),)

    # with cookies
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        response = call_rpc(empty_pb2.Empty(), metadata=cookie_metadata)
    assert response.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert not trace.is_api_key
        assert len(trace.request) == 0
        assert not trace.traceback

403 

404 

def test_tracing_interceptor_auth_api_key(db):
    """A bearer-API-key call is traced against the key's owner and flagged as an API-key call."""
    _, super_token = generate_user(is_superuser=True)
    user, _ = generate_user()

    with real_admin_session(super_token) as api:
        api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        api_key = session.execute(select(UserSession).where(UserSession.is_api_key == True)).scalar_one().token

    rpc_def = dict(
        rpc=Account().GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[TracingInterceptor(), AuthValidatorInterceptor(), SessionInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
    )
    bearer_metadata = (("authorization", f"Bearer {api_key}"),)

    # with api key
    with interceptor_dummy_api(**rpc_def, creds=grpc.local_channel_credentials()) as call_rpc:
        response = call_rpc(empty_pb2.Empty(), metadata=bearer_metadata)
    assert response.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert trace.is_api_key
        assert len(trace.request) == 0
        assert not trace.traceback

440 

441 

def test_auth_levels(db):
    """AuthValidatorInterceptor admits each session type only to RPCs at the matching auth level.

    Checks every (session, auth level) pair across open, jailed, normal and superuser
    sessions and open, jailed, secure and admin RPCs. Also checks two special cases:
    - an RPC that does not exist is rejected as UNIMPLEMENTED;
    - an RPC with no service level fails as INTERNAL.
    """

    def TestRpc(request, context):
        return empty_pb2.Empty()

    def gen_args(service, method):
        # interceptor_dummy_api kwargs that serve TestRpc under service/method behind the auth validator only
        return {
            "rpc": TestRpc,
            "service_name": service,
            "method_name": method,
            "interceptors": [AuthValidatorInterceptor()],
            "request_type": empty_pb2.Empty,
            "response_type": empty_pb2.Empty,
        }

    # superuser
    _, super_token = generate_user(is_superuser=True)
    # normal user
    _, normal_token = generate_user()
    # jailed user
    _, jailed_token = generate_user(accepted_tos=0)
    # open user
    open_token = ""

    # pick some rpcs here with the right auth levels
    open_args = gen_args("org.couchers.resources.Resources", "GetTermsOfService")
    jailed_args = gen_args("org.couchers.jail.Jail", "JailInfo")
    secure_args = gen_args("org.couchers.api.account.Account", "GetAccountInfo")
    admin_args = gen_args("org.couchers.admin.Admin", "GetUserDetails")

    # pairs to check
    checks = [
        # name, token, args, should work?, expected code, expected details
        # (code and details are only checked for pairs that should fail)
        # open token only works on open servicers
        ("open x open", open_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x jailed", open_token, jailed_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x secure", open_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x admin", open_token, admin_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        # jailed works on jailed and open
        ("jailed x open", jailed_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("jailed x jailed", jailed_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("jailed x secure", jailed_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Permission denied"),
        ("jailed x admin", jailed_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # normal works on all but admin
        ("normal x open", normal_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x jailed", normal_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x secure", normal_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x admin", normal_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # superuser works on all
        ("super x open", super_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x jailed", super_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x secure", super_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x admin", super_token, admin_args, True, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
    ]

    for name, token, args, should_work, code, message in checks:
        print(f"Testing (token x args) = ({name}), {should_work=}")
        metadata = (("cookie", f"couchers-sesh={token}"),)
        with interceptor_dummy_api(**args) as call_rpc:
            if should_work:
                call_rpc(empty_pb2.Empty(), metadata=metadata)
            else:
                with pytest.raises(grpc.RpcError) as e:
                    call_rpc(empty_pb2.Empty(), metadata=metadata)
                assert e.value.code() == code
                assert e.value.details() == message

    # a non-existent RPC
    nonexistent = gen_args("org.couchers.nonexistent.NA", "GetNothing")

    with interceptor_dummy_api(**nonexistent) as call_rpc:
        # .code()/.details() only exist on grpc.RpcError, so expect exactly that
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.UNIMPLEMENTED
        assert e.value.details() == "API call does not exist. Please refresh and try again."

    # an RPC without a service level
    invalid_args = gen_args("org.couchers.media.Media", "UploadConfirmation")

    with interceptor_dummy_api(**invalid_args) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.INTERNAL
        assert e.value.details() == "Internal authentication error."