Coverage for src/couchers/db.py: 94%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

78 statements  

1import functools 

2import logging 

3import os 

4from contextlib import contextmanager 

5from time import perf_counter_ns 

6 

7from alembic import command 

8from alembic.config import Config 

9from sqlalchemy import create_engine, event 

10from sqlalchemy.engine import Engine 

11from sqlalchemy.orm.session import Session 

12from sqlalchemy.pool import NullPool, SingletonThreadPool 

13from sqlalchemy.sql import and_, func, literal, or_ 

14 

15from couchers import config 

16from couchers.constants import SERVER_THREADS 

17from couchers.models import ( 

18 Cluster, 

19 ClusterRole, 

20 ClusterSubscription, 

21 FriendRelationship, 

22 FriendStatus, 

23 Node, 

24 TimezoneArea, 

25) 

26from couchers.profiler import add_sql_statement 

27from couchers.sql import couchers_select as select 

28from couchers.utils import now 

29 

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

31 

32 

def apply_migrations():
    """
    Upgrades the database schema to the latest alembic revision ("head").

    Alembic resolves `alembic.ini` relative to the working directory, so this temporarily changes into the
    directory two levels above this module. It always restores the caller's working directory, even if the
    upgrade raises.
    """
    original_dir = os.getcwd()
    project_root = f"{os.path.dirname(__file__)}/../.."
    try:
        os.chdir(project_root)
        migration_config = Config("alembic.ini")
        # Alembic reconfigures logging by default. This flag tells it not to when migrations run at startup.
        migration_config.set_main_option("dont_mess_up_logging", "False")
        command.upgrade(migration_config, "head")
    finally:
        os.chdir(original_dir)

44 

45 

@functools.lru_cache
def _get_base_engine():
    """
    Builds the process-wide SQLAlchemy engine.

    The result is memoized by `lru_cache`, so the same engine is returned until `clear_base_engine_cache()`
    is called.
    - In tests (`IN_TEST`), connections are not pooled (`NullPool`).
    - Otherwise, `SingletonThreadPool` keeps one connection per thread, sized for the server threads plus
      12 extra (for background workers, etc).
    """
    if not config.config["IN_TEST"]:
        pool_opts = dict(
            poolclass=SingletonThreadPool,
            pool_size=SERVER_THREADS + 12,
        )
    else:
        pool_opts = dict(poolclass=NullPool)

    # `future=True` opts into SQLAlchemy 2.0-style behaviour.
    # `pool_pre_ping=True` checks pooled connections before use, which avoids "server closed the connection
    # unexpectedly" errors on stale connections.
    return create_engine(
        config.config["DATABASE_CONNECTION_STRING"],
        future=True,
        pool_pre_ping=True,
        **pool_opts,
    )

67 

68 

def clear_base_engine_cache():
    """
    Drops the memoized engine so the next `_get_base_engine()` call builds a fresh one.

    Required after the public schema has been dropped.
    """
    reset_engine_cache = _get_base_engine.cache_clear
    reset_engine_cache()

74 

75 

def get_engine(isolation_level=None):
    """
    Returns the shared engine, optionally with a specific transaction isolation level.

    With a falsy `isolation_level`, the cached base engine itself is returned. Otherwise, a shallow copy of it
    is returned via `execution_options`, carrying that isolation level.
    """
    base_engine = _get_base_engine()
    if isolation_level:
        return base_engine.execution_options(isolation_level=isolation_level)
    return base_engine

85 

86 

@contextmanager
def session_scope(isolation_level=None):
    """
    Provides a transactional Session as a context manager.

    The session is bound to `get_engine(isolation_level=isolation_level)`.
    - If the `with` block finishes normally, the session is committed.
    - If anything is raised (in the block or during the commit), the session is rolled back and the exception
      re-raised.
    - The session is always closed.
    """
    session = Session(get_engine(isolation_level=isolation_level), future=True)
    try:
        yield session
        session.commit()
    except BaseException:
        # Deliberately BaseException rather than Exception: KeyboardInterrupt, SystemExit and GeneratorExit
        # must also roll back. This is the explicit form of the previous bare `except:`. It always re-raises.
        session.rollback()
        raise
    finally:
        session.close()

98 

99 

@event.listens_for(Engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    """
    Records the start of a query for profiling.

    Pushes (statement, parameters, wall-clock start, perf_counter_ns start) onto a per-connection stack stored
    in `conn.info["query_profiler_info"]`. `after_cursor_execute` pops it again.
    """
    pending_queries = conn.info.setdefault("query_profiler_info", [])
    pending_queries.append((statement, parameters, now(), perf_counter_ns()))

103 

104 

@event.listens_for(Engine, "after_cursor_execute")
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    """
    Finishes the profiling record started by `before_cursor_execute`.

    Pops the most recently pushed entry from `conn.info["query_profiler_info"]`, takes end timestamps, and
    reports the statement with start/end times to `add_sql_statement`. The popped statement and parameters
    are the ones reported, not the callback arguments.
    """
    pending_queries = conn.info["query_profiler_info"]
    statement, parameters, start, start_ns = pending_queries.pop()
    end = now()
    end_ns = perf_counter_ns()
    add_sql_statement(statement, parameters, start, start_ns, end, end_ns)

110 

111 

def are_friends(session, context, other_user):
    """
    Returns True if an accepted FriendRelationship exists between the current user and `other_user`.

    The current user is `context.user_id`. The relationship may be in either direction. Both endpoints must
    pass the `where_users_column_visible` visibility filter for `context`.
    """
    me = context.user_id
    sent_by_me = and_(FriendRelationship.from_user_id == me, FriendRelationship.to_user_id == other_user)
    sent_to_me = and_(FriendRelationship.from_user_id == other_user, FriendRelationship.to_user_id == me)

    query = (
        select(FriendRelationship)
        .where_users_column_visible(context, FriendRelationship.from_user_id)
        .where_users_column_visible(context, FriendRelationship.to_user_id)
        .where(or_(sent_by_me, sent_to_me))
        .where(FriendRelationship.status == FriendStatus.accepted)
    )
    relationship = session.execute(query).scalar_one_or_none()
    return relationship is not None

132 

133 

def get_parent_node_at_location(session, shape):
    """
    Finds the smallest node containing the shape, or None if no node contains it.

    Shape can be any PostGIS geo object, e.g. output from create_coordinate
    """
    # Find the lowest Node (in the Node tree) that contains the shape. By construction of nodes, a sub-node's
    # area is always less than its parent's, so sorting the containing nodes by area replaces a tree traversal.
    containing_nodes = select(Node).where(func.ST_Contains(Node.geom, shape))
    smallest_first = containing_nodes.order_by(func.ST_Area(Node.geom))
    return session.execute(smallest_first).scalars().first()

148 

149 

def get_node_parents_recursively(session, node_id):
    """
    Gets the upwards hierarchy of `node_id` (the node itself plus its ancestors) that have official clusters.

    The starting node has level 0, its parent level 1, and so on. Only nodes joined to a Cluster with
    `is_official_cluster` are returned, one row per (node, official cluster) pair; nodes without one are
    dropped by the inner join. Rows are ordered by level descending: the top-most ancestor comes first and
    `node_id` itself (if it has an official cluster) comes last.

    Returns SQLAlchemy rows of (id, parent_node_id, level, Cluster)
    """
    # Anchor of the recursive CTE: the starting node at level 0.
    parents = (
        select(Node.id, Node.parent_node_id, literal(0).label("level"))
        .where(Node.id == node_id)
        .cte("parents", recursive=True)
    )

    # Recursive step: join each row's parent_node_id back to Node to climb one level, incrementing `level`.
    # Recursion stops when no Node matches a parent_node_id (presumably NULL at the root). It assumes the
    # node tree is acyclic — TODO confirm.
    subquery = select(
        parents.union(
            select(Node.id, Node.parent_node_id, (parents.c.level + 1).label("level")).join(
                parents, Node.id == parents.c.parent_node_id
            )
        )
    ).subquery()

    # Attach each node's official cluster(s) and order from the top of the hierarchy downwards.
    return session.execute(
        select(subquery, Cluster)
        .join(Cluster, Cluster.parent_node_id == subquery.c.id)
        .where(Cluster.is_official_cluster)
        .order_by(subquery.c.level.desc())
    ).all()

176 

177 

178def _can_moderate_any_cluster(session, user_id, cluster_ids): 

179 return ( 

180 session.execute( 

181 select(func.count()) 

182 .select_from(ClusterSubscription) 

183 .where(ClusterSubscription.role == ClusterRole.admin) 

184 .where(ClusterSubscription.user_id == user_id) 

185 .where(ClusterSubscription.cluster_id.in_(cluster_ids)) 

186 ).scalar_one() 

187 > 0 

188 ) 

189 

190 

191def can_moderate_at(session, user_id, shape): 

192 """ 

193 Returns True if the user_id can moderate a given geo-shape (i.e., if the shape is contained in any Node that the user is an admin of) 

194 """ 

195 cluster_ids = [ 

196 cluster_id 

197 for (cluster_id,) in session.execute( 

198 select(Cluster.id) 

199 .join(Node, Node.id == Cluster.parent_node_id) 

200 .where(Cluster.is_official_cluster) 

201 .where(func.ST_Contains(Node.geom, shape)) 

202 ).all() 

203 ] 

204 return _can_moderate_any_cluster(session, user_id, cluster_ids) 

205 

206 

def can_moderate_node(session, user_id, node_id):
    """
    Returns True if `user_id` can moderate the given node.

    That is, the user is an admin of an official cluster attached to the node itself or any of its ancestors.
    """
    hierarchy = get_node_parents_recursively(session, node_id)
    # Each row is (id, parent_node_id, level, Cluster); only the Cluster is needed.
    cluster_ids = [cluster.id for *_ignored, cluster in hierarchy]
    return _can_moderate_any_cluster(session, user_id, cluster_ids)

214 

215 

def timezone_at_coordinate(session, geom):
    """
    Returns the tzid of the TimezoneArea whose geometry contains `geom`, or None if no area contains it.

    `geom` is a PostGIS geometry expression (e.g. a point). Raises SQLAlchemy's MultipleResultsFound if more
    than one TimezoneArea contains `geom`, because `scalar_one_or_none()` is used.
    """
    # The query selects only the `tzid` column, so `scalar_one_or_none()` already yields the tzid value itself
    # (or None). The previous `area.tzid` dereference raised AttributeError on that value.
    tzid = session.execute(
        select(TimezoneArea.tzid).where(func.ST_Contains(TimezoneArea.geom, geom))
    ).scalar_one_or_none()
    # Keep the original "falsy -> None" semantics.
    return tzid if tzid else None