Coverage for src/couchers/rate_limits/definitions.py: 96%
23 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-07-12 05:54 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-07-12 05:54 +0000
1"""Rate limit definitions:
2In order to add a new rate limit definition, extend RateLimitAction and RATE_LIMIT_DEFINITIONS and call
3rate_limits.check.process_rate_limits_and_check_abort in the relevant endpoint.
4"""
6from dataclasses import dataclass
7from datetime import timedelta
8from typing import TYPE_CHECKING, Callable
10from sqlalchemy import func, select
12from couchers.models import (
13 Conversation,
14 FriendRelationship,
15 GroupChat,
16 GroupChatSubscription,
17 HostRequest,
18 RateLimitAction,
19 User,
20)
21from couchers.utils import now
23if TYPE_CHECKING:
24 from sqlalchemy.orm import Session
# Both callbacks are invoked as (session, user_id). "Session" is a forward reference because
# sqlalchemy.orm.Session is only imported under TYPE_CHECKING.
_CountActionsQuery = Callable[["Session", int], int]
# NOTE(review): the helpers in this module actually return `.mappings().all()` rows
# (dict-like mappings), so `list[dict]` is a loose description of that result.
_ModEmailInformationQuery = Callable[["Session", int], list[dict]]


@dataclass
class RateLimitDefinition:
    """Thresholds plus query callbacks describing the rate limit for one RateLimitAction.

    Fields:
        warning_limit: the lower of the two thresholds on the action count.
            NOTE(review): presumably triggers a warning/moderator notification rather than
            blocking — confirm in rate_limits.check.
        hard_limit: the upper threshold on the action count.
            NOTE(review): presumably the point at which the action is refused — confirm in
            rate_limits.check.
        count_actions_query: callback returning how many times the user performed the
            action (in this module's definitions: within RATE_LIMIT_INTERVAL).
        mod_email_information_query: callback returning one row per recent action;
            going by its name, these rows feed the moderator email.
    """

    warning_limit: int
    hard_limit: int
    count_actions_query: _CountActionsQuery
    mod_email_information_query: _ModEmailInformationQuery
# Rolling window over which a user's actions are counted: every query in this module
# filters on timestamps >= now() - RATE_LIMIT_INTERVAL.
RATE_LIMIT_INTERVAL = timedelta(hours=24)
# Human-readable form of RATE_LIMIT_INTERVAL (presumably for user/moderator-facing text).
# Derived from the timedelta instead of hard-coded so the two cannot drift apart; `:g`
# renders whole hours without a decimal point ("24 hours") and keeps fractions ("1.5 hours").
RATE_LIMIT_INTERVAL_STRING = f"{RATE_LIMIT_INTERVAL / timedelta(hours=1):g} hours"
def _get_user_host_requests_in_past_time_interval(session, user_id) -> list[dict]:
    """Collect the host requests ``user_id`` sent (as surfer) within RATE_LIMIT_INTERVAL.

    Serves as the host-request ``mod_email_information_query``. Each returned row is a
    mapping with the keys "created" (the conversation's creation time), "host id",
    "host username" and "host city".
    """
    window_start = now() - RATE_LIMIT_INTERVAL
    columns = (
        Conversation.created.label("created"),
        HostRequest.host_user_id.label("host id"),
        User.username.label("host username"),
        User.city.label("host city"),
    )
    statement = (
        select(*columns)
        .join(Conversation, HostRequest.conversation_id == Conversation.id)
        # The User joined here is the request's host, not the requesting surfer.
        .join(User, HostRequest.host_user_id == User.id)
        .where(
            HostRequest.surfer_user_id == user_id,
            Conversation.created >= window_start,
        )
    )
    result = session.execute(statement)
    return result.mappings().all()
def _get_user_friend_requests_in_past_time_interval(session, user_id) -> list[dict]:
    """Collect the friend requests sent by ``user_id`` within RATE_LIMIT_INTERVAL.

    Serves as the friend-request ``mod_email_information_query``. Each returned row is a
    mapping with the request's ``time_sent`` and ``status`` plus the recipient's id and
    username under the keys "to_user (ID)" and "to_user (username)".
    """
    window_start = now() - RATE_LIMIT_INTERVAL
    statement = (
        select(
            FriendRelationship.time_sent,
            User.id.label("to_user (ID)"),
            User.username.label("to_user (username)"),
            FriendRelationship.status,
        )
        # The User joined here is the recipient of the request.
        .join(User, FriendRelationship.to_user_id == User.id)
        .where(
            FriendRelationship.from_user_id == user_id,
            FriendRelationship.time_sent >= window_start,
        )
    )
    result = session.execute(statement)
    return result.mappings().all()
def _get_user_initiated_chats_in_past_time_interval(session, user_id) -> list[dict]:
    """Collect the group chats/DMs created by ``user_id`` within RATE_LIMIT_INTERVAL.

    Serves as the chat-initiation ``mod_email_information_query``. Returns one row per chat
    (keys ``id``, ``created``, ``title``, ``is_dm`` and ``participants``). ``participants``
    is an array of the usernames of members whose subscription has no ``left`` timestamp.

    NOTE(review): because of the inner joins, a chat in which every member has left
    produces no row, so this list can be shorter than the chat-initiation count, which
    only joins GroupChat to Conversation.
    """
    return (
        session.execute(
            select(
                Conversation.id,
                Conversation.created,
                GroupChat.title,
                GroupChat.is_dm,
                func.array_agg(User.username).label("participants"),
            )
            .join(Conversation, GroupChat.conversation_id == Conversation.id)
            .join(GroupChatSubscription, Conversation.id == GroupChatSubscription.group_chat_id)
            .join(User, GroupChatSubscription.user_id == User.id)
            .where(GroupChat.creator_id == user_id)
            .where(Conversation.created >= now() - RATE_LIMIT_INTERVAL)
            # `.is_(None)` renders the same `IS NULL` as `== None` but is the idiomatic
            # SQLAlchemy spelling and does not trip E711 linters.
            .where(GroupChatSubscription.left.is_(None))
            .group_by(Conversation.id, Conversation.created, GroupChat.title, GroupChat.is_dm)
        )
        .mappings()
        .all()
    )
def _count_host_requests_in_past_time_interval(session, user_id) -> int:
    """Count host requests sent by ``user_id`` (as surfer) whose conversation is within RATE_LIMIT_INTERVAL."""
    window_start = now() - RATE_LIMIT_INTERVAL
    statement = (
        select(func.count())
        .select_from(HostRequest)
        .join(Conversation, HostRequest.conversation_id == Conversation.id)
        .where(
            HostRequest.surfer_user_id == user_id,
            Conversation.created >= window_start,
        )
    )
    return session.execute(statement).scalar_one()


def _count_friend_requests_in_past_time_interval(session, user_id) -> int:
    """Count friend requests sent by ``user_id`` within RATE_LIMIT_INTERVAL."""
    window_start = now() - RATE_LIMIT_INTERVAL
    statement = (
        select(func.count())
        .select_from(FriendRelationship)
        .where(
            FriendRelationship.from_user_id == user_id,
            FriendRelationship.time_sent >= window_start,
        )
    )
    return session.execute(statement).scalar_one()


def _count_initiated_chats_in_past_time_interval(session, user_id) -> int:
    """Count group chats created by ``user_id`` whose conversation is within RATE_LIMIT_INTERVAL."""
    window_start = now() - RATE_LIMIT_INTERVAL
    statement = (
        select(func.count())
        .select_from(GroupChat)
        .join(Conversation, GroupChat.conversation_id == Conversation.id)
        .where(
            GroupChat.creator_id == user_id,
            Conversation.created >= window_start,
        )
    )
    return session.execute(statement).scalar_one()


# One RateLimitDefinition per rate-limited RateLimitAction; when adding a new action,
# extend this mapping as well (see the module docstring).
RATE_LIMIT_DEFINITIONS = {
    RateLimitAction.host_request: RateLimitDefinition(
        warning_limit=20,
        hard_limit=80,
        count_actions_query=_count_host_requests_in_past_time_interval,
        mod_email_information_query=_get_user_host_requests_in_past_time_interval,
    ),
    RateLimitAction.friend_request: RateLimitDefinition(
        warning_limit=10,
        hard_limit=40,
        count_actions_query=_count_friend_requests_in_past_time_interval,
        mod_email_information_query=_get_user_friend_requests_in_past_time_interval,
    ),
    RateLimitAction.chat_initiation: RateLimitDefinition(
        warning_limit=15,
        hard_limit=150,
        count_actions_query=_count_initiated_chats_in_past_time_interval,
        mod_email_information_query=_get_user_initiated_chats_in_past_time_interval,
    ),
}