Coverage for src/tests/test_interceptors.py: 100%

271 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-07-20 21:46 +0000

1from concurrent import futures 

2from contextlib import contextmanager 

3 

4import grpc 

5import pytest 

6from google.protobuf import empty_pb2 

7 

8from couchers import errors 

9from couchers.crypto import random_hex 

10from couchers.db import session_scope 

11from couchers.interceptors import AuthValidatorInterceptor, ErrorSanitizationInterceptor, TracingInterceptor 

12from couchers.metrics import CODE_LABEL, EXCEPTION_LABEL, METHOD_LABEL, servicer_duration_histogram 

13from couchers.models import APICall, UserSession 

14from couchers.servicers.account import Account 

15from couchers.sql import couchers_select as select 

16from proto import account_pb2, admin_pb2, auth_pb2 

17from tests.test_fixtures import db, generate_user, real_admin_session, testconfig # noqa 

18 

19 

@pytest.fixture(autouse=True)
def _(testconfig):
    """Pull the ``testconfig`` fixture into every test in this module.

    Depending on ``testconfig`` is the only purpose of this fixture; it has no
    body of its own.
    """

23 

24 

@contextmanager
def interceptor_dummy_api(
    rpc,
    interceptors,
    service_name="testing.Test",
    method_name="TestRpc",
    request_type=empty_pb2.Empty,
    response_type=empty_pb2.Empty,
    creds=None,
):
    """Serve ``rpc`` behind ``interceptors`` on a throwaway server and yield a client callable for it.

    A gRPC server backed by a one-thread pool is bound to an ephemeral
    ``localhost`` port with local server credentials. ``rpc`` is registered as
    the unary-unary method ``/<service_name>/<method_name>``, using
    ``request_type``/``response_type`` for (de)serialization.

    The yielded callable invokes that method over a secure channel built from
    ``creds``, defaulting to ``grpc.local_channel_credentials()``. The server is
    stopped when the ``with`` block exits, including on error.
    """
    full_method = f"/{service_name}/{method_name}"

    with futures.ThreadPoolExecutor(1) as pool:
        srv = grpc.server(pool, interceptors=interceptors)
        bound_port = srv.add_secure_port("localhost:0", grpc.local_server_credentials())

        # register the single handler by hand instead of via generated servicer glue
        handler = grpc.unary_unary_rpc_method_handler(
            rpc,
            request_deserializer=request_type.FromString,
            response_serializer=response_type.SerializeToString,
        )
        srv.add_generic_rpc_handlers(
            (grpc.method_handlers_generic_handler(service_name, {method_name: handler}),)
        )
        srv.start()

        try:
            channel_creds = creds or grpc.local_channel_credentials()
            with grpc.secure_channel(f"localhost:{bound_port}", channel_creds) as chan:
                yield chan.unary_unary(
                    full_method,
                    request_serializer=request_type.SerializeToString,
                    response_deserializer=response_type.FromString,
                )
        finally:
            srv.stop(None).wait()

61 

62 

def _check_histogram_labels(method, exception, code, count):
    """Assert the ``servicer_duration`` histogram counted ``count`` calls for one label set, then reset it.

    Finds the ``servicer_duration_count`` sample whose method, exception and
    code labels equal the given values, and asserts its value equals ``count``.

    The histogram is cleared in a ``finally`` so that a failing check does not
    leave recorded samples behind for later tests.

    Raises:
        AssertionError: if no sample carries the requested labels, or its
            count differs from ``count``.
    """
    try:
        metrics = servicer_duration_histogram.collect()
        servicer_histogram = [m for m in metrics if m.name == "servicer_duration"][0]
        histogram_count = next(
            (
                s
                for s in servicer_histogram.samples
                if s.name == "servicer_duration_count"
                and s.labels[METHOD_LABEL] == method
                and s.labels[EXCEPTION_LABEL] == exception
                and s.labels[CODE_LABEL] == code
            ),
            None,
        )
        assert histogram_count is not None, (
            f"no servicer_duration_count sample for method={method!r}, exception={exception!r}, code={code!r}"
        )
        assert histogram_count.value == count
    finally:
        # always reset, otherwise a failed check pollutes subsequent tests
        servicer_duration_histogram.clear()

76 

77 

def test_logging_interceptor_ok():
    """A handler that returns normally completes without error through ErrorSanitizationInterceptor."""

    def TestRpc(request, context):
        return empty_pb2.Empty()

    sanitizers = [ErrorSanitizationInterceptor()]
    with interceptor_dummy_api(TestRpc, interceptors=sanitizers) as call_rpc:
        call_rpc(empty_pb2.Empty())

84 

85 

def test_logging_interceptor_all_ignored():
    """ErrorSanitizationInterceptor passes explicit ``context.abort`` codes and messages through unchanged.

    For every listed status code, a fresh server is started whose handler
    aborts with that code and a random message. The client must observe
    exactly that code and message.
    """
    # error codes that should not be touched by the interceptor
    # (OK is not listed because a handler cannot abort with OK)
    pass_through_status_codes = [
        grpc.StatusCode.CANCELLED,
        grpc.StatusCode.UNKNOWN,
        grpc.StatusCode.INVALID_ARGUMENT,
        grpc.StatusCode.DEADLINE_EXCEEDED,
        grpc.StatusCode.NOT_FOUND,
        grpc.StatusCode.ALREADY_EXISTS,
        grpc.StatusCode.PERMISSION_DENIED,
        grpc.StatusCode.UNAUTHENTICATED,
        grpc.StatusCode.RESOURCE_EXHAUSTED,
        grpc.StatusCode.FAILED_PRECONDITION,
        grpc.StatusCode.ABORTED,
        grpc.StatusCode.OUT_OF_RANGE,
        grpc.StatusCode.UNIMPLEMENTED,
        grpc.StatusCode.INTERNAL,
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.DATA_LOSS,
    ]

    for status_code in pass_through_status_codes:
        message = random_hex()

        # Bind this iteration's values as defaults. The handler then no longer
        # relies on late binding of the loop variables, which is what the old
        # `noqa: B023` suppressed.
        def TestRpc(request, context, status_code=status_code, message=message):
            context.abort(status_code, message)

        with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
            with pytest.raises(grpc.RpcError) as e:
                call_rpc(empty_pb2.Empty())
        assert e.value.code() == status_code
        assert e.value.details() == message

120 

121 

def test_logging_interceptor_assertion():
    """An AssertionError in a handler reaches the client as INTERNAL with ``errors.UNKNOWN_ERROR``."""

    def TestRpc(request, context):
        raise AssertionError()

    with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as exc_info:
            call_rpc(empty_pb2.Empty())

    rpc_error = exc_info.value
    assert rpc_error.code() == grpc.StatusCode.INTERNAL
    assert rpc_error.details() == errors.UNKNOWN_ERROR

131 

132 

def test_logging_interceptor_div0():
    """A ZeroDivisionError in a handler reaches the client as INTERNAL with ``errors.UNKNOWN_ERROR``."""

    def TestRpc(request, context):
        1 / 0  # noqa: B018

    with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as exc_info:
            call_rpc(empty_pb2.Empty())

    rpc_error = exc_info.value
    assert rpc_error.code() == grpc.StatusCode.INTERNAL
    assert rpc_error.details() == errors.UNKNOWN_ERROR

142 

143 

def test_logging_interceptor_raise():
    """A bare ``Exception`` in a handler reaches the client as INTERNAL with ``errors.UNKNOWN_ERROR``."""

    def TestRpc(request, context):
        raise Exception()

    with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as exc_info:
            call_rpc(empty_pb2.Empty())

    rpc_error = exc_info.value
    assert rpc_error.code() == grpc.StatusCode.INTERNAL
    assert rpc_error.details() == errors.UNKNOWN_ERROR

153 

154 

def test_logging_interceptor_raise_custom():
    """A user-defined exception's own message is hidden behind INTERNAL / ``errors.UNKNOWN_ERROR``."""

    class _TestingException(Exception):
        pass

    def TestRpc(request, context):
        raise _TestingException("This is a custom exception")

    with interceptor_dummy_api(TestRpc, interceptors=[ErrorSanitizationInterceptor()]) as call_rpc:
        with pytest.raises(grpc.RpcError) as exc_info:
            call_rpc(empty_pb2.Empty())

    rpc_error = exc_info.value
    assert rpc_error.code() == grpc.StatusCode.INTERNAL
    assert rpc_error.details() == errors.UNKNOWN_ERROR

167 

168 

def test_tracing_interceptor_ok_open(db):
    """A successful call without credentials is traced as a single, empty APICall row.

    The stored trace has the full method path and no status code, user,
    payloads or traceback. Exactly one histogram sample is recorded with empty
    exception and code labels.
    """
    expected_method = "/testing.Test/TestRpc"

    def TestRpc(request, context):
        return empty_pb2.Empty()

    with interceptor_dummy_api(TestRpc, interceptors=[TracingInterceptor()]) as call_rpc:
        call_rpc(empty_pb2.Empty())

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == expected_method
        assert not trace.status_code
        assert not trace.user_id
        # Empty messages serialize to zero bytes.
        assert len(trace.request) == 0
        assert len(trace.response) == 0
        assert not trace.traceback

    _check_histogram_labels(expected_method, "", "", 1)

186 

187 

def test_tracing_interceptor_sensitive(db):
    """Traced request/response payloads have ``password`` blanked while other fields are kept.

    Both the ``SignupAccount`` request and the ``AuthReq`` response carry a
    password. The stored APICall must contain neither password, but must keep
    ``username``/``user``.
    """
    expected_method = "/testing.Test/TestRpc"

    def TestRpc(request, context):
        return auth_pb2.AuthReq(user="this is not secret", password="this is secret")

    api_kwargs = dict(
        interceptors=[TracingInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    )
    with interceptor_dummy_api(TestRpc, **api_kwargs) as call_rpc:
        call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == expected_method
        assert not trace.status_code
        assert not trace.user_id
        assert not trace.traceback

        stored_req = auth_pb2.SignupAccount.FromString(trace.request)
        assert not stored_req.password
        assert stored_req.username == "not removed"

        stored_res = auth_pb2.AuthReq.FromString(trace.response)
        assert stored_res.user == "this is not secret"
        assert not stored_res.password

    _check_histogram_labels(expected_method, "", "", 1)

214 

215 

def test_tracing_interceptor_exception(db):
    """An unhandled handler exception is traced with its traceback and counted in the histogram.

    The client must get a gRPC error mentioning the exception message. The
    stored APICall must have:

    - no status code (the handler raised rather than aborting);
    - a traceback containing the message;
    - a password-stripped request;
    - no response.

    One histogram sample is labelled with exception "Exception" and an empty code.
    """

    def TestRpc(request, context):
        raise Exception("Some error message")

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[TracingInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        # The client surfaces server-side failures as grpc.RpcError. Catching only
        # that type stops an unrelated client-side error from satisfying the check.
        with pytest.raises(grpc.RpcError, match="Some error message"):
            call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/testing.Test/TestRpc"
        assert not trace.status_code
        assert not trace.user_id
        assert "Some error message" in trace.traceback
        req = auth_pb2.SignupAccount.FromString(trace.request)
        assert not req.password
        assert req.username == "not removed"
        assert not trace.response

    _check_histogram_labels("/testing.Test/TestRpc", "Exception", "", 1)

241 

242 

def test_tracing_interceptor_abort(db):
    """A ``context.abort`` is traced with its status code name and recorded in the histogram.

    The client must get a gRPC error mentioning the abort message. The stored
    APICall must have:

    - status code "FAILED_PRECONDITION";
    - a traceback containing the message;
    - a password-stripped request;
    - no response.

    One histogram sample is labelled with exception "Exception" and code
    "FAILED_PRECONDITION".
    """

    def TestRpc(request, context):
        context.abort(grpc.StatusCode.FAILED_PRECONDITION, "now a grpc abort")

    with interceptor_dummy_api(
        TestRpc,
        interceptors=[TracingInterceptor()],
        request_type=auth_pb2.SignupAccount,
        response_type=auth_pb2.AuthReq,
    ) as call_rpc:
        # Only a grpc.RpcError carrying the abort message counts as the expected failure.
        with pytest.raises(grpc.RpcError, match="now a grpc abort"):
            call_rpc(auth_pb2.SignupAccount(password="should be removed", username="not removed"))

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/testing.Test/TestRpc"
        assert trace.status_code == "FAILED_PRECONDITION"
        assert not trace.user_id
        assert "now a grpc abort" in trace.traceback
        req = auth_pb2.SignupAccount.FromString(trace.request)
        assert not req.password
        assert req.username == "not removed"
        assert not trace.response

    _check_histogram_labels("/testing.Test/TestRpc", "Exception", "FAILED_PRECONDITION", 1)

268 

269 

def test_auth_interceptor(db):
    """AuthValidatorInterceptor accepts one valid credential per request and rejects everything else.

    GetAccountInfo is called through the interceptor with each of:

    - no credentials;
    - a valid and an invalid session cookie;
    - a valid and an invalid ``Bearer`` API key;
    - the API key supplied via gRPC access-token call credentials;
    - a cookie and an API key together;
    - a lowercase ``bearer`` scheme.

    The valid single credentials succeed and return the key/cookie owner's
    account. Every rejection is UNAUTHENTICATED with the listed details.
    """
    _, super_token = generate_user(is_superuser=True)
    user, token = generate_user()

    # an admin mints an API key for the regular user
    with real_admin_session(super_token) as api:
        api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        key_session = session.execute(select(UserSession).where(UserSession.is_api_key == True)).scalar_one()
        api_key = key_session.token

    account = Account()
    rpc_def = dict(
        rpc=account.GetAccountInfo,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        interceptors=[AuthValidatorInterceptor()],
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
    )

    def call(metadata=None, creds=None):
        # every attempt gets its own freshly started server
        channel_creds = creds if creds is not None else grpc.local_channel_credentials()
        with interceptor_dummy_api(**rpc_def, creds=channel_creds) as call_rpc:
            return call_rpc(empty_pb2.Empty(), metadata=metadata)

    def expect_unauthenticated(metadata=None, details="Unauthorized"):
        with pytest.raises(grpc.RpcError) as e:
            call(metadata)
        assert e.value.code() == grpc.StatusCode.UNAUTHENTICATED
        assert e.value.details() == details

    # no creds, no go for secure APIs
    expect_unauthenticated()

    # can auth with cookie
    assert call((("cookie", f"couchers-sesh={token}"),)).username == user.username

    # can't auth with wrong cookie
    expect_unauthenticated((("cookie", f"couchers-sesh={random_hex(32)}"),))

    # can auth with api key
    assert call((("authorization", f"Bearer {api_key}"),)).username == user.username

    # can't auth with wrong api key
    expect_unauthenticated((("authorization", f"Bearer {random_hex(32)}"),))

    # can auth with grpc helper (they do the same as above)
    comp_creds = grpc.composite_channel_credentials(
        grpc.local_channel_credentials(), grpc.access_token_call_credentials(api_key)
    )
    assert call(creds=comp_creds).username == user.username

    # can't auth with both
    expect_unauthenticated(
        (
            ("cookie", f"couchers-sesh={token}"),
            ("authorization", f"Bearer {api_key}"),
        ),
        details='Both "cookie" and "authorization" in request',
    )

    # malformed bearer (scheme must be capitalised "Bearer")
    expect_unauthenticated((("authorization", f"bearer {api_key}"),))

350 

351 

def test_tracing_interceptor_auth_cookies(db):
    """A cookie-authenticated call is traced against its user and not flagged as an API-key call.

    TracingInterceptor is listed before AuthValidatorInterceptor. The stored
    APICall must carry the full method path, the user's id and
    ``is_api_key`` false, with no status code, request bytes or traceback.
    """
    user, token = generate_user()

    account_servicer = Account()
    interceptors = [TracingInterceptor(), AuthValidatorInterceptor()]

    # with cookies
    with interceptor_dummy_api(
        account_servicer.GetAccountInfo,
        interceptors=interceptors,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
        creds=grpc.local_channel_credentials(),
    ) as call_rpc:
        reply = call_rpc(empty_pb2.Empty(), metadata=(("cookie", f"couchers-sesh={token}"),))
        assert reply.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert not trace.is_api_key
        assert len(trace.request) == 0
        assert not trace.traceback

379 

380 

def test_tracing_interceptor_auth_api_key(db):
    """An API-key-authenticated call is traced against its user and flagged as an API-key call.

    An admin creates an API key for a regular user, and the key is sent as a
    ``Bearer`` authorization header. The stored APICall must carry the full
    method path, the user's id and ``is_api_key`` true, with no status code,
    request bytes or traceback.
    """
    _, super_token = generate_user(is_superuser=True)
    user, _ = generate_user()

    with real_admin_session(super_token) as api:
        api.CreateApiKey(admin_pb2.CreateApiKeyReq(user=user.username))

    with session_scope() as session:
        key_session = session.execute(select(UserSession).where(UserSession.is_api_key == True)).scalar_one()
        api_key = key_session.token

    account_servicer = Account()
    interceptors = [TracingInterceptor(), AuthValidatorInterceptor()]

    # with api key
    with interceptor_dummy_api(
        account_servicer.GetAccountInfo,
        interceptors=interceptors,
        service_name="org.couchers.api.account.Account",
        method_name="GetAccountInfo",
        request_type=empty_pb2.Empty,
        response_type=account_pb2.GetAccountInfoRes,
        creds=grpc.local_channel_credentials(),
    ) as call_rpc:
        reply = call_rpc(empty_pb2.Empty(), metadata=(("authorization", f"Bearer {api_key}"),))
        assert reply.username == user.username

    with session_scope() as session:
        trace = session.execute(select(APICall)).scalar_one()
        assert trace.method == "/org.couchers.api.account.Account/GetAccountInfo"
        assert not trace.status_code
        assert trace.user_id == user.id
        assert trace.is_api_key
        assert len(trace.request) == 0
        assert not trace.traceback

416 

417 

def test_auth_levels(db):
    """AuthValidatorInterceptor enforces each service's auth level for every kind of caller.

    Four callers (no session, jailed user, normal user, superuser) are tried
    against one RPC from each auth level (open, jailed, secure, admin). The test
    checks which combinations succeed, and the status code and details of those
    that fail.

    Two further cases are checked:

    - a non-existent RPC is rejected with UNIMPLEMENTED;
    - an RPC without a service level is rejected with INTERNAL.
    """

    def TestRpc(request, context):
        return empty_pb2.Empty()

    def gen_args(service, method):
        return {
            "rpc": TestRpc,
            "service_name": service,
            "method_name": method,
            "interceptors": [AuthValidatorInterceptor()],
            "request_type": empty_pb2.Empty,
            "response_type": empty_pb2.Empty,
        }

    # superuser
    _, super_token = generate_user(is_superuser=True)
    # normal user
    _, normal_token = generate_user()
    # jailed user
    _, jailed_token = generate_user(accepted_tos=0)
    # open user
    open_token = ""

    # pick some rpcs here with the right auth levels
    open_args = gen_args("org.couchers.resources.Resources", "GetTermsOfService")
    jailed_args = gen_args("org.couchers.jail.Jail", "JailInfo")
    secure_args = gen_args("org.couchers.api.account.Account", "GetAccountInfo")
    admin_args = gen_args("org.couchers.admin.Admin", "GetUserDetails")

    # pairs to check
    checks = [
        # name, token, args, works?, code, message
        # (code and message are only checked when works? is False)
        # open token only works on open servicers
        ("open x open", open_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x jailed", open_token, jailed_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x secure", open_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("open x admin", open_token, admin_args, False, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        # jailed works on jailed and open
        ("jailed x open", jailed_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("jailed x jailed", jailed_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("jailed x secure", jailed_token, secure_args, False, grpc.StatusCode.UNAUTHENTICATED, "Permission denied"),
        ("jailed x admin", jailed_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # normal works on all but admin
        ("normal x open", normal_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x jailed", normal_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x secure", normal_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("normal x admin", normal_token, admin_args, False, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
        # superuser works on all
        ("super x open", super_token, open_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x jailed", super_token, jailed_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x secure", super_token, secure_args, True, grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"),
        ("super x admin", super_token, admin_args, True, grpc.StatusCode.PERMISSION_DENIED, "Permission denied"),
    ]

    for name, token, args, should_work, code, message in checks:
        print(f"Testing (token x args) = ({name}), {should_work=}")
        metadata = (("cookie", f"couchers-sesh={token}"),)
        with interceptor_dummy_api(**args) as call_rpc:
            if should_work:
                call_rpc(empty_pb2.Empty(), metadata=metadata)
            else:
                with pytest.raises(grpc.RpcError) as e:
                    call_rpc(empty_pb2.Empty(), metadata=metadata)
                # tag failures with the case name so the failing pair is obvious
                assert e.value.code() == code, name
                assert e.value.details() == message, name

    # a non-existent RPC
    nonexistent = gen_args("org.couchers.nonexistent.NA", "GetNothing")

    with interceptor_dummy_api(**nonexistent) as call_rpc:
        # .code()/.details() below exist only on grpc.RpcError, so expect exactly that
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.UNIMPLEMENTED
        assert e.value.details() == "API call does not exist. Please refresh and try again."

    # an RPC without a service level
    invalid_args = gen_args("org.couchers.media.Media", "UploadConfirmation")

    with interceptor_dummy_api(**invalid_args) as call_rpc:
        with pytest.raises(grpc.RpcError) as e:
            call_rpc(empty_pb2.Empty())
        assert e.value.code() == grpc.StatusCode.INTERNAL
        assert e.value.details() == "Internal authentication error."