Coverage for src/couchers/rate_limits/definitions.py: 96%

23 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-07-12 05:54 +0000

1"""Rate limit definitions: 

2In order to add a new rate limit definition, extend RateLimitAction and RATE_LIMIT_DEFINITIONS and call 

3rate_limits.check.process_rate_limits_and_check_abort in the relevant endpoint. 

4""" 

5 

6from dataclasses import dataclass 

7from datetime import timedelta 

8from typing import TYPE_CHECKING, Callable 

9 

10from sqlalchemy import func, select 

11 

12from couchers.models import ( 

13 Conversation, 

14 FriendRelationship, 

15 GroupChat, 

16 GroupChatSubscription, 

17 HostRequest, 

18 RateLimitAction, 

19 User, 

20) 

21from couchers.utils import now 

22 

23if TYPE_CHECKING: 

24 from sqlalchemy.orm import Session 

25 

26 

@dataclass
class RateLimitDefinition:
    """Thresholds and queries that together define one rate-limited action.

    Instances are registered in ``RATE_LIMIT_DEFINITIONS`` keyed by
    ``RateLimitAction``. How the two limits are enforced lives outside this
    module (see ``rate_limits.check.process_rate_limits_and_check_abort``
    mentioned in the module docstring) and is not visible here.
    """

    # Action count at which the "warning" stage applies (enforcement semantics
    # are defined by the checker, not here).
    warning_limit: int
    # Action count at which the "hard" stage applies (again, enforced elsewhere).
    hard_limit: int
    # Called as (session, user_id); returns how many of the user's actions fall
    # inside the rate-limit window.
    count_actions_query: Callable[["Session", int], int]
    # Called as (session, user_id); returns row mappings describing those
    # actions, presumably for inclusion in a moderator email (per the name).
    mod_email_information_query: Callable[["Session", int], list[dict]]

33 

34 

# Sliding window for every rate limit: the queries in this module only count
# rows whose timestamp is >= now() - RATE_LIMIT_INTERVAL.
RATE_LIMIT_INTERVAL = timedelta(days=1)
# Human-readable rendering of RATE_LIMIT_INTERVAL (presumably for user- or
# moderator-facing text). Not derived automatically: update it by hand
# whenever RATE_LIMIT_INTERVAL changes.
RATE_LIMIT_INTERVAL_STRING = "24 hours"

37 

38 

def _get_user_host_requests_in_past_time_interval(session, user_id) -> list[dict]:
    """List the host requests ``user_id`` sent (as surfer) within the window.

    A request is in the window when its conversation was created at or after
    ``now() - RATE_LIMIT_INTERVAL``. Each returned row mapping has the keys
    "created", "host id", "host username" and "host city".
    """
    window_start = now() - RATE_LIMIT_INTERVAL

    query = (
        select(
            Conversation.created.label("created"),
            HostRequest.host_user_id.label("host id"),
            User.username.label("host username"),
            User.city.label("host city"),
        )
        # The conversation row supplies the request's creation timestamp.
        .join(Conversation, HostRequest.conversation_id == Conversation.id)
        # The user row supplies the host's username and city.
        .join(User, HostRequest.host_user_id == User.id)
        .where(
            HostRequest.surfer_user_id == user_id,
            Conversation.created >= window_start,
        )
    )

    return session.execute(query).mappings().all()

56 

57 

def _get_user_friend_requests_in_past_time_interval(session, user_id) -> list[dict]:
    """List friend requests sent by ``user_id`` within the rate-limit window.

    Rows are included regardless of their current status. Each row mapping
    carries the send time, the recipient's id ("to_user (ID)") and username
    ("to_user (username)"), and the relationship status.
    """
    since = now() - RATE_LIMIT_INTERVAL

    stmt = select(
        FriendRelationship.time_sent,
        User.id.label("to_user (ID)"),
        User.username.label("to_user (username)"),
        FriendRelationship.status,
    )
    # Resolve the recipient so the row can show who the request went to.
    stmt = stmt.join(User, FriendRelationship.to_user_id == User.id)
    stmt = stmt.where(
        FriendRelationship.from_user_id == user_id,
        FriendRelationship.time_sent >= since,
    )

    return session.execute(stmt).mappings().all()

74 

75 

def _get_user_initiated_chats_in_past_time_interval(session, user_id) -> list[dict]:
    """List the group chats / DMs created by ``user_id`` within the window.

    One row per chat: conversation id, creation time, title, the ``is_dm``
    flag, and "participants", an array of usernames of subscribers whose
    ``left`` is unset (presumably members who have not left the chat).

    NOTE(review): the subscription join is an inner join restricted to rows
    with ``left`` unset. A chat with no such subscription is therefore omitted
    here, although the chat_initiation count query (which does not join
    subscriptions) still counts it. Confirm whether that mismatch is intended.
    """
    cutoff = now() - RATE_LIMIT_INTERVAL
    # Per-chat columns: selected as-is and used as the GROUP BY key for array_agg.
    chat_columns = (Conversation.id, Conversation.created, GroupChat.title, GroupChat.is_dm)

    stmt = (
        select(*chat_columns, func.array_agg(User.username).label("participants"))
        .join(Conversation, GroupChat.conversation_id == Conversation.id)
        .join(GroupChatSubscription, Conversation.id == GroupChatSubscription.group_chat_id)
        .join(User, GroupChatSubscription.user_id == User.id)
        .where(
            GroupChat.creator_id == user_id,
            Conversation.created >= cutoff,
            GroupChatSubscription.left.is_(None),
        )
        .group_by(*chat_columns)
    )

    return session.execute(stmt).mappings().all()

97 

98 

def _count_user_host_requests_in_past_time_interval(session, user_id) -> int:
    """Count host requests sent by ``user_id`` whose conversation was created in the window."""
    return session.execute(
        select(func.count())
        .select_from(HostRequest)
        # The conversation row carries the request's creation timestamp.
        .join(Conversation, HostRequest.conversation_id == Conversation.id)
        .where(HostRequest.surfer_user_id == user_id)
        .where(Conversation.created >= now() - RATE_LIMIT_INTERVAL)
    ).scalar_one()


def _count_user_friend_requests_in_past_time_interval(session, user_id) -> int:
    """Count friend requests (any status) sent by ``user_id`` within the window."""
    return session.execute(
        select(func.count())
        .select_from(FriendRelationship)
        .where(FriendRelationship.from_user_id == user_id)
        .where(FriendRelationship.time_sent >= now() - RATE_LIMIT_INTERVAL)
    ).scalar_one()


def _count_user_initiated_chats_in_past_time_interval(session, user_id) -> int:
    """Count group chats created by ``user_id`` whose conversation was created in the window."""
    return session.execute(
        select(func.count())
        .select_from(GroupChat)
        .join(Conversation, GroupChat.conversation_id == Conversation.id)
        .where(GroupChat.creator_id == user_id)
        .where(Conversation.created >= now() - RATE_LIMIT_INTERVAL)
    ).scalar_one()


# Registry of rate-limited actions. Each entry pairs the warning/hard thresholds
# with a named counting query and a named query that lists the counted actions.
RATE_LIMIT_DEFINITIONS = {
    RateLimitAction.host_request: RateLimitDefinition(
        warning_limit=20,
        hard_limit=80,
        count_actions_query=_count_user_host_requests_in_past_time_interval,
        mod_email_information_query=_get_user_host_requests_in_past_time_interval,
    ),
    RateLimitAction.friend_request: RateLimitDefinition(
        warning_limit=10,
        hard_limit=40,
        count_actions_query=_count_user_friend_requests_in_past_time_interval,
        mod_email_information_query=_get_user_friend_requests_in_past_time_interval,
    ),
    RateLimitAction.chat_initiation: RateLimitDefinition(
        warning_limit=15,
        hard_limit=150,
        count_actions_query=_count_user_initiated_chats_in_past_time_interval,
        mod_email_information_query=_get_user_initiated_chats_in_past_time_interval,
    ),
}