Coverage for src/couchers/db.py: 92%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

76 statements  

1import functools 

2import logging 

3import os 

4from contextlib import contextmanager 

5from time import perf_counter_ns 

6 

7from alembic import command 

8from alembic.config import Config 

9from sqlalchemy import create_engine, event 

10from sqlalchemy.engine import Engine 

11from sqlalchemy.orm.session import Session 

12from sqlalchemy.pool import NullPool, SingletonThreadPool 

13from sqlalchemy.sql import and_, func, literal, or_ 

14 

15from couchers import config 

16from couchers.constants import SERVER_THREADS 

17from couchers.models import ( 

18 Cluster, 

19 ClusterRole, 

20 ClusterSubscription, 

21 FriendRelationship, 

22 FriendStatus, 

23 Node, 

24 TimezoneArea, 

25) 

26from couchers.profiler import add_sql_statement 

27from couchers.sql import couchers_select as select 

28from couchers.utils import now 

29 

# module-level logger, named after this module
logger = logging.getLogger(name=__name__)

31 

32 

def apply_migrations():
    """
    Upgrades the database schema to the latest alembic revision ("head").

    alembic is configured through "alembic.ini", which is looked up relative to the directory two levels above this
    module. The process working directory is therefore switched there for the duration of the upgrade and always
    restored afterwards, even if the upgrade raises.

    NOTE(review): os.chdir is process-wide, so this should not run while other threads depend on the cwd.
    """
    migrations_root = f"{os.path.dirname(__file__)}/../.."
    previous_dir = os.getcwd()
    try:
        os.chdir(migrations_root)
        cfg = Config("alembic.ini")
        # alembic screws up logging config by default, this tells it not to screw it up if being run at startup like this
        cfg.set_main_option("dont_mess_up_logging", "False")
        command.upgrade(cfg, "head")
    finally:
        os.chdir(previous_dir)

44 

45 

@functools.lru_cache
def _get_base_engine():
    """
    Builds the shared SQLAlchemy engine. It is memoised by lru_cache, so it is created once per process.

    Pooling depends on config.config["IN_TEST"]:
      * in tests: NullPool (no pooling, so every checkout opens a new DBAPI connection);
      * otherwise: SingletonThreadPool (one connection per thread), sized for the server threads plus headroom.

    `future=True` enables SQLAlchemy 2.0-style behaviour. `pool_pre_ping=True` tests pooled connections before use,
    which avoids "server closed the connection unexpectedly" errors.
    """
    in_test = config.config["IN_TEST"]
    pool_opts = (
        {"poolclass": NullPool}
        if in_test
        else {
            # one connection per thread: the main server threads + extra for bg workers, etc
            "poolclass": SingletonThreadPool,
            "pool_size": SERVER_THREADS + 12,
        }
    )
    connection_string = config.config["DATABASE_CONNECTION_STRING"]
    return create_engine(connection_string, future=True, pool_pre_ping=True, **pool_opts)

67 

68 

def get_engine(isolation_level=None):
    """
    Returns an engine, optionally bound to the given isolation level.

    With a falsy `isolation_level` (the default), this returns the shared base engine itself. Otherwise it returns a
    shallow copy of the base engine created via `execution_options(isolation_level=...)`.
    """
    base = _get_base_engine()
    if isolation_level:
        return base.execution_options(isolation_level=isolation_level)
    return base

78 

79 

@contextmanager
def session_scope(isolation_level=None):
    """
    Context manager that yields a `Session` on an engine with the given isolation level.

    If the `with` body finishes normally, the session is committed. If anything raises (including the commit itself),
    the session is rolled back and the exception propagates. The session is always closed on exit.
    """
    engine = get_engine(isolation_level=isolation_level)
    session = Session(engine, future=True)
    try:
        yield session
        session.commit()
    except BaseException:
        # deliberately catches everything (incl. GeneratorExit/KeyboardInterrupt): roll back, then re-raise
        session.rollback()
        raise
    finally:
        session.close()

91 

92 

@event.listens_for(Engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    """
    SQLAlchemy hook that records the start of a statement for the query profiler.

    Pushes (statement, parameters, now(), perf_counter_ns()) onto a list stored under "query_profiler_info" in
    `conn.info`. The matching entry is popped by `after_cursor_execute`.
    """
    pending = conn.info.setdefault("query_profiler_info", [])
    pending.append((statement, parameters, now(), perf_counter_ns()))

96 

97 

@event.listens_for(Engine, "after_cursor_execute")
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    """
    SQLAlchemy hook that completes the profiling record started in `before_cursor_execute`.

    Pops the most recent entry from `conn.info["query_profiler_info"]`, takes end timestamps, and hands the whole
    timing record to `add_sql_statement`.

    NOTE(review): if a statement raises, this hook presumably does not fire, so the pushed entry stays on the list —
    confirm whether that can accumulate on long-lived connections.
    """
    pending = conn.info["query_profiler_info"]
    statement, parameters, start, start_ns = pending.pop()
    end = now()
    end_ns = perf_counter_ns()
    add_sql_statement(statement, parameters, start, start_ns, end, end_ns)

103 

104 

def are_friends(session, context, other_user):
    """
    Returns True if an accepted FriendRelationship exists between `context.user_id` and `other_user`, in either
    direction.

    Both user columns are filtered with `where_users_column_visible(context, ...)`.

    NOTE(review): if accepted rows existed in both directions, scalar_one_or_none() would raise; presumably that is
    prevented elsewhere — confirm.
    """
    me = context.user_id
    i_added_them = and_(FriendRelationship.from_user_id == me, FriendRelationship.to_user_id == other_user)
    they_added_me = and_(FriendRelationship.from_user_id == other_user, FriendRelationship.to_user_id == me)

    query = (
        select(FriendRelationship)
        .where_users_column_visible(context, FriendRelationship.from_user_id)
        .where_users_column_visible(context, FriendRelationship.to_user_id)
        .where(or_(i_added_them, they_added_me))
        .where(FriendRelationship.status == FriendStatus.accepted)
    )
    relationship = session.execute(query).scalar_one_or_none()
    return relationship is not None

125 

126 

def get_parent_node_at_location(session, shape):
    """
    Finds the smallest node containing the shape, or None if no node contains it.

    Shape can be any PostGIS geo object, e.g. output from create_coordinate
    """
    # Find the lowest Node (in the Node tree) that contains the shape. By construction of nodes, the area of a
    # sub-node must always be less than its parent Node, so sorting containing nodes by area (smallest first) is
    # enough and there's no need to actually traverse the tree!
    containing = select(Node).where(func.ST_Contains(Node.geom, shape))
    smallest_first = containing.order_by(func.ST_Area(Node.geom))
    return session.execute(smallest_first).scalars().first()

141 

142 

def get_node_parents_recursively(session, node_id):
    """
    Gets the upwards hierarchy of a node: the node itself (level 0) plus its ancestors (level 1, 2, ...).

    Only nodes that have an official cluster are returned, because of the join on Cluster. Rows are ordered by level
    descending, so the topmost ancestor comes first and the given node comes last.

    Returns SQLAlchemy rows of (id, parent_node_id, level, Cluster).
    """
    # anchor of the recursive CTE: the starting node itself at level 0
    ancestors = (
        select(Node.id, Node.parent_node_id, literal(0).label("level"))
        .where(Node.id == node_id)
        .cte("parents", recursive=True)
    )

    # recursive step: the parent of every node found so far, one level higher
    one_level_up = select(Node.id, Node.parent_node_id, (ancestors.c.level + 1).label("level")).join(
        ancestors, Node.id == ancestors.c.parent_node_id
    )
    hierarchy = select(ancestors.union(one_level_up)).subquery()

    query = (
        select(hierarchy, Cluster)
        .join(Cluster, Cluster.parent_node_id == hierarchy.c.id)
        .where(Cluster.is_official_cluster)
        .order_by(hierarchy.c.level.desc())
    )
    return session.execute(query).all()

169 

170 

def _can_moderate_any_cluster(session, user_id, cluster_ids):
    """
    Returns True if `user_id` has a ClusterSubscription with role ClusterRole.admin in at least one of `cluster_ids`.
    """
    admin_query = (
        select(func.count())
        .select_from(ClusterSubscription)
        .where(ClusterSubscription.role == ClusterRole.admin)
        .where(ClusterSubscription.user_id == user_id)
        .where(ClusterSubscription.cluster_id.in_(cluster_ids))
    )
    admin_subscription_count = session.execute(admin_query).scalar_one()
    return admin_subscription_count > 0

182 

183 

def can_moderate_at(session, user_id, shape):
    """
    Returns True if the user_id can moderate a given geo-shape.

    That is, the user is an admin of an official cluster whose parent Node's geometry contains the shape.
    """
    official_clusters_containing_shape = (
        select(Cluster.id)
        .join(Node, Node.id == Cluster.parent_node_id)
        .where(Cluster.is_official_cluster)
        .where(func.ST_Contains(Node.geom, shape))
    )
    cluster_ids = session.execute(official_clusters_containing_shape).scalars().all()
    return _can_moderate_any_cluster(session, user_id, cluster_ids)

198 

199 

def can_moderate_node(session, user_id, node_id):
    """
    Returns True if the user_id can moderate the given node.

    That is, the user is an admin of an official cluster attached to the node itself or to any of its ancestors.
    """
    hierarchy_rows = get_node_parents_recursively(session, node_id)
    cluster_ids = [cluster.id for _node_id, _parent_node_id, _level, cluster in hierarchy_rows]
    return _can_moderate_any_cluster(session, user_id, cluster_ids)

207 

208 

def timezone_at_coordinate(session, geom):
    """
    Returns the `tzid` of the TimezoneArea whose geometry contains `geom`, or None if no area contains it.

    `geom` can be any PostGIS geo object accepted by ST_Contains.

    Raises sqlalchemy's MultipleResultsFound if more than one TimezoneArea contains `geom`, as scalar_one_or_none()
    does.
    """
    tzid = session.execute(
        select(TimezoneArea.tzid).where(func.ST_Contains(TimezoneArea.geom, geom))
    ).scalar_one_or_none()
    # Only the tzid column is selected, so scalar_one_or_none() already yields the tzid value itself (not a row or
    # model instance). The previous `area.tzid` dereference raised AttributeError whenever an area matched.
    return tzid or None