Coverage for src/couchers/sql.py: 98%

40 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-11-21 04:21 +0000

1from sqlalchemy.orm import aliased 

2from sqlalchemy.sql import Select, union 

3 

4from couchers.models import User, UserBlock 

5from couchers.utils import is_valid_email, is_valid_user_id, is_valid_username 

6 

7 

def _relevant_user_blocks(user_id):
    """
    Build a select of user IDs that must be hidden from ``user_id``.

    Blocking is treated symmetrically: the result combines (via ``union``)
    the IDs of users that ``user_id`` has blocked and the IDs of users
    that have blocked ``user_id``.
    """
    # users this user has blocked
    i_blocked = couchers_select(UserBlock.blocked_user_id).where(
        UserBlock.blocking_user_id == user_id
    )
    # users who have blocked this user
    blocked_me = couchers_select(UserBlock.blocking_user_id).where(
        UserBlock.blocked_user_id == user_id
    )
    combined = union(i_blocked, blocked_me)
    return couchers_select(combined.subquery())

17 

18 

19""" 

20This method construct provided directly by the developers 

21They intend to implement a better option in the near future 

22See issue here: https://github.com/sqlalchemy/sqlalchemy/issues/6700 

23""" 

24 

25 

def couchers_select(*expr):
    """
    Create a ``CouchersSelect`` statement over the given column/entity expressions.

    Use this in place of ``select(...)`` to get the extra ``where_*`` helpers.
    """
    statement = CouchersSelect(*expr)
    return statement

28 

29 

class CouchersSelect(Select):
    """
    Project-specific ``Select`` subclass adding reusable user-filtering helpers.

    Each helper returns the statement produced by ``Select.where`` /
    ``Select.join``, so the helpers chain like the built-in ``Select`` methods.
    """

    # This subclass adds only methods, no state that affects the generated SQL,
    # so it reuses the parent's cache-key generation.
    inherit_cache = True

    def where_username_or_email(self, field, table=User):
        """
        Filter ``table`` by ``field``, interpreted as a username or an email.

        If ``field`` is a valid username it is compared with ``table.username``;
        otherwise, if it is a valid email, with ``table.email``. If it is
        neither, a ``False`` criterion is applied so the query matches no rows.

        :param field: the user-supplied identifier string.
        :param table: the user entity to filter (``User`` or an alias of it).
        """
        if is_valid_username(field):
            return self.where(table.username == field)
        elif is_valid_email(field):
            return self.where(table.email == field)
        # no fields match, this will return no rows
        return self.where(False)

    def where_username_or_id(self, field, table=User):
        """
        Filter ``table`` by ``field``, interpreted as a username or a user ID.

        If ``field`` is a valid username it is compared with ``table.username``;
        otherwise, if it is a valid user ID, with ``table.id``. If it is
        neither, a ``False`` criterion is applied so the query matches no rows.

        :param field: the user-supplied identifier.
        :param table: the user entity to filter (``User`` or an alias of it).
        """
        if is_valid_username(field):
            return self.where(table.username == field)
        elif is_valid_user_id(field):
            return self.where(table.id == field)
        # no fields match, this will return no rows
        return self.where(False)

    def where_username_or_email_or_id(self, field, table=User):
        """
        Filter ``table`` by ``field``, interpreted as a username, email, or user ID.

        The checks are tried in that order: username, then email, then user ID.
        The first one that validates decides which column is compared. If none
        validates, a ``False`` criterion is applied so the query matches no rows.

        Should only be used for admin APIs, etc.

        :param field: the user-supplied identifier.
        :param table: the user entity to filter (``User`` or an alias of it);
            the parameter is accepted for consistency with the sibling helpers.
        """
        if is_valid_username(field):
            return self.where(table.username == field)
        elif is_valid_email(field):
            return self.where(table.email == field)
        elif is_valid_user_id(field):
            return self.where(table.id == field)
        # no fields match, this will return no rows
        return self.where(False)

    def where_users_visible(self, context, table=User):
        """
        Filters out users that should not be visible: blocked, deleted, or banned

        Filters the given table, assuming it's already joined/selected from.
        Rows are kept only when ``table.is_visible`` holds and ``table.id`` is
        not among the users in a block relationship with ``context.user_id``.
        """
        hidden_users = _relevant_user_blocks(context.user_id)
        return self.where(table.is_visible).where(~table.id.in_(hidden_users))

    def where_users_column_visible(self, context, column):
        """
        Filters the given column, not yet joined/selected from

        Joins a fresh alias of ``User`` on ``column`` and applies the same
        visibility and block filters as ``where_users_visible`` to that alias.
        """
        hidden_users = _relevant_user_blocks(context.user_id)
        # a fresh alias avoids clashing with any User entity already in the query
        aliased_user = aliased(User)
        return (
            self.join(aliased_user, aliased_user.id == column)
            .where(aliased_user.is_visible)
            .where(~aliased_user.id.in_(hidden_users))
        )