Coverage for src/couchers/sql.py: 99%

67 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-20 11:53 +0000

1from typing import TYPE_CHECKING, Any, Self 

2 

3from sqlalchemy import and_, false, or_ 

4from sqlalchemy.orm import InstrumentedAttribute, aliased 

5from sqlalchemy.sql import Select, exists, union 

6 

7from couchers.context import CouchersContext 

8from couchers.models import GroupChat, HostRequest, ModerationState, ModerationVisibility, SignupFlow, User, UserBlock 

9from couchers.utils import is_valid_email, is_valid_user_id, is_valid_username 

10 

if TYPE_CHECKING:
    # Everything in this block exists only for static type checking; none of it is imported at runtime.
    from typing import Protocol

    from couchers.materialized_views import LiteUser

    # Model classes that CouchersSelect.where_username_or_* accept; they are queried via ``.username``
    # (and ``.email`` / ``.id`` where the method checks for or relies on them).
    type _UserLike = type[User | LiteUser | SignupFlow]
    # Model classes that the user-visibility filters accept; they are queried via ``.id`` and ``.is_visible``.
    type _User = type[User | LiteUser]

    class ModeratedContent(Protocol):
        """
        Structural type for models that the moderated-content visibility filters operate on.
        """

        # Joined against ModerationState.id by the moderation visibility filters.
        moderation_state_id: InstrumentedAttribute[int]
        # Name of the attribute holding the author's user id; it is resolved with getattr(table, ...).
        __moderation_author_column__: str

22 

23 

class CouchersSelect(Select[Any]):
    """
    A ``Select`` subclass that adds Couchers-specific filter helpers as chainable methods.

    Subclassing ``Select`` like this is the construct suggested directly by the SQLAlchemy developers;
    they intend to provide a better option in the near future.
    See issue here: https://github.com/sqlalchemy/sqlalchemy/issues/6700

    Every ``where_*`` method returns a new statement of the same type (``Self``), so calls can be chained.
    """

    # Lets SQLAlchemy's compiled-statement cache treat this subclass the same as the parent Select.
    inherit_cache = True

    def where_username_or_email(self, value: str, table: "_UserLike" = User) -> Self:
        """
        Filter ``table`` by username if ``value`` is a valid username, otherwise by email.

        The email branch is only taken when ``table`` has an ``email`` attribute.
        If ``value`` matches neither format, the statement is filtered to return no rows.
        """
        if is_valid_username(value):
            return self.where(table.username == value)
        elif is_valid_email(value) and hasattr(table, "email"):
            return self.where(table.email == value)
        # no fields match, this will return no rows
        return self.where(false())

    def where_username_or_id(self, value: str, table: "_UserLike" = User) -> Self:
        """
        Filter ``table`` by username if ``value`` is a valid username, otherwise by id if it is a valid user id.

        If ``value`` matches neither format, the statement is filtered to return no rows.
        """
        if is_valid_username(value):
            return self.where(table.username == value)
        elif is_valid_user_id(value):
            return self.where(table.id == value)
        # no fields match, this will return no rows
        return self.where(false())

    def where_username_or_email_or_id(self, value: str) -> Self:
        """
        Filter ``User`` by username, then email, then id, using the first format ``value`` is valid for.

        If ``value`` matches none of them, the statement is filtered to return no rows.
        """
        # Should only be used for admin APIs, etc.
        if is_valid_username(value):
            return self.where(User.username == value)
        elif is_valid_email(value):
            return self.where(User.email == value)
        elif is_valid_user_id(value):
            return self.where(User.id == value)
        # no fields match, this will return no rows
        return self.where(false())

    def where_users_visible(self, context: CouchersContext, table: "_User" = User) -> Self:
        """
        Filters out users that should not be visible: blocked, deleted, or banned.

        Keeps rows of ``table`` whose ``is_visible`` is true and that neither blocked nor were blocked by
        ``context.user_id``. Assumes ``table`` is already joined/selected from.
        """
        hidden_users = _relevant_user_blocks(context.user_id)
        return self.where(table.is_visible).where(~table.id.in_(hidden_users))

    def where_users_column_visible(self, context: CouchersContext, column: InstrumentedAttribute[int]) -> Self:
        """
        Same filter as ``where_users_visible``, applied to a user-id ``column`` that is not yet joined.

        Joins a fresh alias of ``User`` on ``column`` and filters that alias.
        """
        hidden_users = _relevant_user_blocks(context.user_id)
        aliased_user = aliased(User)
        return (
            self.join(aliased_user, aliased_user.id == column)
            .where(aliased_user.is_visible)
            .where(~aliased_user.id.in_(hidden_users))
        )

    def where_users_visible_to_each_other(self, user1: "_User", user2: "_User") -> Self:
        """
        Filters to ensure two users are mutually visible to each other.

        Checks that:
        - Both users are visible (``is_visible``; not deleted/banned)
        - Neither user has blocked the other (bidirectional check against UserBlock)

        Use this when both User tables are already joined/selected in the query.
        """
        return (
            self.where(user1.is_visible)
            .where(user2.is_visible)
            .where(
                ~exists(
                    couchers_select(1)
                    .select_from(UserBlock)
                    .where(
                        or_(
                            and_(UserBlock.blocking_user_id == user1.id, UserBlock.blocked_user_id == user2.id),
                            and_(UserBlock.blocking_user_id == user2.id, UserBlock.blocked_user_id == user1.id),
                        )
                    )
                )
            )
        )

    def where_user_columns_visible_to_each_other(
        self, column1: InstrumentedAttribute[int], column2: InstrumentedAttribute[int]
    ) -> Self:
        """
        Filters to ensure the users referenced by two user-id columns are mutually visible to each other.

        Use this when you have two user_id columns that haven't been joined yet: this joins a fresh
        ``User`` alias for each column, then applies ``where_users_visible_to_each_other`` to the aliases.
        """
        user1 = aliased(User)
        user2 = aliased(User)
        # Delegate instead of duplicating the visibility + bidirectional-block clauses.
        return (
            self.join(user1, user1.id == column1)
            .join(user2, user2.id == column2)
            .where_users_visible_to_each_other(user1, user2)
        )

    def _where_moderation_visible(
        self,
        table: "type[ModeratedContent]",
        is_list_operation: bool,
        viewer_id: InstrumentedAttribute[int] | int | None,
    ) -> Self:
        """
        Shared implementation of the moderated-content visibility filters.

        Joins a fresh alias of ModerationState on ``table.moderation_state_id`` and keeps rows whose state is:
        - VISIBLE, or
        - UNLISTED, only when ``is_list_operation`` is False, or
        - SHADOWED and authored by ``viewer_id`` (this clause is omitted when ``viewer_id`` is None).

        ``viewer_id`` may be a literal user id or a SQL column holding one.
        """
        mod_state = aliased(ModerationState)
        conditions = [mod_state.visibility == ModerationVisibility.VISIBLE]

        # UNLISTED content is visible in single-item operations but not in lists
        if not is_list_operation:
            conditions.append(mod_state.visibility == ModerationVisibility.UNLISTED)

        # Authors can always see their own SHADOWED content.
        # `is not None` is a Python identity check, so it is safe even when viewer_id is a SQL column.
        if viewer_id is not None:
            author_column = getattr(table, table.__moderation_author_column__)
            conditions.append(
                and_(
                    mod_state.visibility == ModerationVisibility.SHADOWED,
                    author_column == viewer_id,
                )
            )

        return self.join(mod_state, mod_state.id == table.moderation_state_id).where(or_(*conditions))

    def where_moderated_content_visible_to_user_column(
        self,
        table: "type[ModeratedContent]",
        user_id_column: InstrumentedAttribute[int],
        is_list_operation: bool = False,
    ) -> Self:
        """
        Filter moderated ``table`` rows to those visible to the user whose id is in ``user_id_column``.

        VISIBLE content is always kept, UNLISTED content only in single-item operations, and SHADOWED
        content only when its author column equals ``user_id_column``.
        """
        return self._where_moderation_visible(table, is_list_operation, user_id_column)

    def where_moderated_content_visible(
        self,
        context: CouchersContext,
        table: "type[ModeratedContent]",
        is_list_operation: bool = False,
    ) -> Self:
        """
        Filter moderated ``table`` rows to those visible to the requesting user in ``context``.

        VISIBLE content is always kept, and UNLISTED content is kept only in single-item operations.
        SHADOWED content is kept only for its author, and only when the request is logged in.
        """
        # Only read context.user_id when logged in; logged-out viewers never see SHADOWED content.
        viewer_id = context.user_id if context.is_logged_in() else None
        return self._where_moderation_visible(table, is_list_operation, viewer_id)

    def where_moderation_state_column_visible(
        self,
        context: CouchersContext,
        column: InstrumentedAttribute[int | None],
    ) -> Self:
        """
        Filters based on whether the moderation state referenced by the column is visible.

        Use this when you have a moderation_state_id column on a table that's not the moderated
        content itself (e.g., Notification.moderation_state_id).

        The condition evaluates to True when:
        - The column is NULL (non-moderated content), OR
        - The linked content (HostRequest/GroupChat) is visible per where_moderated_content_visible
          (called with its default single-item semantics, so UNLISTED content counts as visible)

        TODO: if you use this with a non-null column, check what's going on
        """
        hr_visible = exists(
            couchers_select(HostRequest)
            .where(HostRequest.moderation_state_id == column)
            .where_moderated_content_visible(context, HostRequest)
        )
        gc_visible = exists(
            couchers_select(GroupChat)
            .where(GroupChat.moderation_state_id == column)
            .where_moderated_content_visible(context, GroupChat)
        )
        return self.where(
            or_(
                column.is_(None),
                hr_visible,
                gc_visible,
            )
        )

223 

224 

def couchers_select(*expr: Any) -> CouchersSelect:
    """
    Drop-in replacement for ``select(...)`` that returns a CouchersSelect.

    All positional arguments are forwarded unchanged to the CouchersSelect constructor.
    """
    statement = CouchersSelect(*expr)
    return statement

227 

228 

def _relevant_user_blocks(user_id: int) -> CouchersSelect:
    """
    Build a SELECT of user ids that should be hidden from ``user_id`` because of blocking.

    The result is the UNION of users that ``user_id`` has blocked and users that have blocked ``user_id``,
    wrapped as a subquery so it can be used with ``column.in_(...)``.
    """
    # users this user has blocked
    users_blocked_by_me = couchers_select(UserBlock.blocked_user_id).where(UserBlock.blocking_user_id == user_id)
    # users who have blocked this user
    users_blocking_me = couchers_select(UserBlock.blocking_user_id).where(UserBlock.blocked_user_id == user_id)

    combined_ids = union(users_blocked_by_me, users_blocking_me).subquery()
    return couchers_select(combined_ids)