Coverage for src/tests/test_threads.py: 100%

86 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-12-20 18:03 +0000

1import string 

2 

3import grpc 

4import pytest 

5 

6from couchers import errors 

7from couchers.db import session_scope 

8from couchers.models import Thread 

9from couchers.servicers.threads import pack_thread_id 

10from proto import threads_pb2 

11from tests.test_fixtures import db, generate_user, testconfig, threads_session # noqa 

12 

13 

@pytest.fixture(autouse=True)
def _(testconfig):
    """Pull the ``testconfig`` fixture into every test in this module.

    Being ``autouse``, this runs for each test automatically; it does nothing
    itself beyond forcing ``testconfig`` to be set up.
    """

17 

18 

def test_threads_basic(db):
    """Build a two-level reply tree and check GetThread lists it in reverse posting order.

    Three top-level replies ("bat", "cat", "dog") are posted under a dummy
    thread, then three children each under "dog" and "cat". The top level is
    expected newest-first with correct author and reply counts; each second
    level is expected to equal its posted ids reversed.
    """
    user1, token1 = generate_user()

    # Create a dummy Thread (should be replaced by pages later on)
    with session_scope() as session:
        root = Thread()
        session.add(root)
        session.flush()  # populate root.id before packing it
        root_thread_id = pack_thread_id(database_id=root.id, depth=0)

    with threads_session(token1) as api:

        def post(thread_id, content):
            # Post one reply and hand back the thread id the server assigned to it.
            response = api.PostReply(threads_pb2.PostReplyReq(thread_id=thread_id, content=content))
            return response.thread_id

        bat_id = post(root_thread_id, "bat")
        cat_id = post(root_thread_id, "cat")
        dog_id = post(root_thread_id, "dog")

        dogs = [post(dog_id, animal) for animal in ["hyena", "wolf", "prariewolf"]]
        cats = [post(cat_id, animal) for animal in ["cheetah", "lynx", "panther"]]

        # Top level: newest first, each with its own number of child replies.
        top = api.GetThread(threads_pb2.GetThreadReq(thread_id=root_thread_id))
        assert len(top.replies) == 3  # checked first so the zip below cannot truncate
        assert top.next_page_token == ""
        expected_top = [(dog_id, "dog", 3), (cat_id, "cat", 3), (bat_id, "bat", 0)]
        for got, (want_id, want_content, want_children) in zip(top.replies, expected_top):
            assert got.thread_id == want_id
            assert got.content == want_content
            assert got.author_user_id == user1.id
            assert got.num_replies == want_children

        # Second level: children of "cat" then "dog", again newest first.
        for parent_id, children in ((cat_id, cats), (dog_id, dogs)):
            sub = api.GetThread(threads_pb2.GetThreadReq(thread_id=parent_id))
            assert len(sub.replies) == 3
            assert sub.next_page_token == ""
            assert [reply.thread_id for reply in sub.replies] == children[::-1]

73 

74 

def test_threads_errors(db):
    """Reading or replying to thread ids 11 and 19 on a fresh DB fails with NOT_FOUND.

    Each failing call must raise ``grpc.RpcError`` with status ``NOT_FOUND``
    and details equal to ``errors.THREAD_NOT_FOUND``.
    """
    _, token1 = generate_user()
    with threads_session(token1) as api:
        # Each entry is deferred so the request is built and sent inside pytest.raises.
        failing_calls = [
            # request non-existing comment
            lambda: api.GetThread(threads_pb2.GetThreadReq(thread_id=11)),
            # request non-existing depth digit
            lambda: api.GetThread(threads_pb2.GetThreadReq(thread_id=19)),
            # post on non-existing comment
            lambda: api.PostReply(threads_pb2.PostReplyReq(thread_id=11, content="foo")),
            # post on non-existing depth
            lambda: api.PostReply(threads_pb2.PostReplyReq(thread_id=19, content="foo")),
        ]
        for call in failing_calls:
            with pytest.raises(grpc.RpcError) as e:
                call()
            assert e.value.code() == grpc.StatusCode.NOT_FOUND
            assert e.value.details() == errors.THREAD_NOT_FOUND

101 

102 

def pagination_test(api, parent_id):
    """Post the letters a-z under ``parent_id`` and page through them five at a time.

    Replies are posted in the order z..a, and each page of GetThread is
    expected to list them a..z, i.e. in reverse posting order. That gives six
    pages: "abcde", "fghij", "klmno", "pqrst", "uvwxy", "z". After the last
    page the returned ``next_page_token`` must be empty.

    Args:
        api: a threads service stub, as yielded by ``threads_session``.
        parent_id: packed thread id to post the replies under.

    Returns:
        The thread id of the first reply on the last page (the "z" reply), so
        the caller can rerun the test one level deeper.
    """
    alphabet = string.ascii_lowercase
    page_size = 5

    # Post some data
    for c in reversed(alphabet):
        api.PostReply(threads_pb2.PostReplyReq(thread_id=parent_id, content=c))

    # Fixed-width chunks built explicitly; textwrap.wrap is a word-wrapper and
    # only yields equal chunks here because the string contains no whitespace.
    expected_pages = [alphabet[i : i + page_size] for i in range(0, len(alphabet), page_size)]

    # Get it with pagination
    token = ""
    for expected_page in expected_pages:
        ret = api.GetThread(
            threads_pb2.GetThreadReq(thread_id=parent_id, page_size=page_size, page_token=token)
        )
        assert "".join(x.content for x in ret.replies) == expected_page
        token = ret.next_page_token

    # The final page must not advertise a further page.
    assert token == ""

    # expected_pages is never empty, so ret is always bound here.
    return ret.replies[0].thread_id  # to be used as a test one level deeper

120 

121 

def test_threads_pagination(db):
    """Check GetThread pagination under a root thread and under one of its replies.

    The parent id is packed with ``pack_thread_id``, the same way
    ``test_threads_basic`` does it. This avoids forcing a primary key and
    hard-coding the packed id.
    """
    _, token1 = generate_user()

    # Create a dummy Thread (should be replaced by pages later on)
    with session_scope() as session:
        dummy_thread = Thread()
        session.add(dummy_thread)
        session.flush()  # populate dummy_thread.id before packing it
        parent_thread_id = pack_thread_id(database_id=dummy_thread.id, depth=0)

    with threads_session(token1) as api:
        comment_id = pagination_test(api, parent_thread_id)
        # Repeat one level down, under the comment returned by the first pass.
        pagination_test(api, comment_id)