Coverage for src/couchers/sql.py: 98%
40 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-11-21 04:21 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-11-21 04:21 +0000
1from sqlalchemy.orm import aliased
2from sqlalchemy.sql import Select, union
4from couchers.models import User, UserBlock
5from couchers.utils import is_valid_email, is_valid_user_id, is_valid_username
def _relevant_user_blocks(user_id):
    """
    Build a select over the IDs of users in a block relationship with ``user_id``.

    Both directions are included: users that ``user_id`` has blocked, and users
    that have blocked ``user_id``. Those users should be hidden from ``user_id``.
    """
    # users this user has blocked
    blocked_by_user = couchers_select(UserBlock.blocked_user_id).where(UserBlock.blocking_user_id == user_id)
    # users who have blocked this user
    blocking_the_user = couchers_select(UserBlock.blocking_user_id).where(UserBlock.blocked_user_id == user_id)

    # combine both directions and wrap the union so it can be selected from
    either_direction = union(blocked_by_user, blocking_the_user).subquery()
    return couchers_select(either_direction)
"""
The select construct below was provided directly by the SQLAlchemy developers.
They intend to implement a better option in the near future.
See the issue here: https://github.com/sqlalchemy/sqlalchemy/issues/6700
"""
def couchers_select(*expr):
    """
    Construct a CouchersSelect from the given expressions.

    All positional arguments are forwarded unchanged to the CouchersSelect
    constructor, so the returned statement exposes the extra ``where_*`` helpers.
    """
    statement = CouchersSelect(*expr)
    return statement
class CouchersSelect(Select):
    """
    Select subclass that adds user lookup and user visibility filters.

    Each helper returns the statement produced by ``self.where(...)`` (or
    ``self.join(...)``), so helpers can be chained like the built-in Select methods.
    """

    # NOTE(review): per the SQLAlchemy docs this opts the subclass into the parent's
    # SQL compilation caching; it adds no state that would affect the cache key.
    inherit_cache = True

    def where_username_or_email(self, field, table=User):
        """
        Filter ``table`` by username or email, depending on what ``field`` looks like.

        ``field`` is compared against ``table.username`` if it is a valid username,
        otherwise against ``table.email`` if it is a valid email. If it is neither,
        the returned statement matches no rows.
        """
        if is_valid_username(field):
            return self.where(table.username == field)
        if is_valid_email(field):
            return self.where(table.email == field)
        # no fields match, this will return no rows
        return self.where(False)

    def where_username_or_id(self, field, table=User):
        """
        Filter ``table`` by username or user ID, depending on what ``field`` looks like.

        ``field`` is compared against ``table.username`` if it is a valid username,
        otherwise against ``table.id`` if it is a valid user ID. If it is neither,
        the returned statement matches no rows.
        """
        if is_valid_username(field):
            return self.where(table.username == field)
        if is_valid_user_id(field):
            return self.where(table.id == field)
        # no fields match, this will return no rows
        return self.where(False)

    def where_username_or_email_or_id(self, field, table=User):
        """
        Filter ``table`` by username, email or user ID, depending on what ``field`` looks like.

        Should only be used for admin APIs, etc.

        The checks are tried in order: username, then email, then user ID. If
        ``field`` is none of these, the returned statement matches no rows.

        ``table`` defaults to ``User``; pass an aliased ``User`` to filter on that
        instead, as with the other lookup helpers.
        """
        if is_valid_username(field):
            return self.where(table.username == field)
        if is_valid_email(field):
            return self.where(table.email == field)
        if is_valid_user_id(field):
            return self.where(table.id == field)
        # no fields match, this will return no rows
        return self.where(False)

    def where_users_visible(self, context, table=User):
        """
        Filters out users that should not be visible: blocked, deleted, or banned

        Filters the given table, assuming it's already joined/selected from.
        Blocks are looked up in both directions for ``context.user_id``.
        """
        hidden_users = _relevant_user_blocks(context.user_id)
        only_visible = self.where(table.is_visible)
        return only_visible.where(~table.id.in_(hidden_users))

    def where_users_column_visible(self, context, column):
        """
        Filters the given column, not yet joined/selected from

        Joins a fresh alias of ``User`` on ``column`` and then applies the same
        visibility filters as ``where_users_visible`` to that alias.
        """
        aliased_user = aliased(User)
        # join() returns a statement of this same class, so the visibility helper
        # can be reused on the alias instead of repeating its where clauses
        joined = self.join(aliased_user, aliased_user.id == column)
        return joined.where_users_visible(context, table=aliased_user)