Coverage for src/couchers/db.py: 78%

116 statements  

coverage.py v7.5.0, created at 2024-10-15 13:03 +0000

import functools
import inspect
import logging
import os
from contextlib import contextmanager
from os import getpid
from threading import get_ident
from time import perf_counter_ns

from alembic import command
from alembic.config import Config
from opentelemetry import trace
from sqlalchemy import create_engine, event, text
from sqlalchemy.engine import Engine
from sqlalchemy.orm.session import Session
from sqlalchemy.pool import QueuePool
from sqlalchemy.sql import and_, func, literal, or_

from couchers.config import config
from couchers.constants import SERVER_THREADS, WORKER_THREADS
from couchers.models import (
    Cluster,
    ClusterRole,
    ClusterSubscription,
    FriendRelationship,
    FriendStatus,
    Node,
    TimezoneArea,
)
from couchers.profiler import add_sql_statement
from couchers.sql import couchers_select as select
from couchers.utils import now

logger = logging.getLogger(__name__)

tracer = trace.get_tracer(__name__)


def apply_migrations():
    alembic_dir = os.path.dirname(__file__) + "/../.."
    cwd = os.getcwd()
    try:
        os.chdir(alembic_dir)
        alembic_cfg = Config("alembic.ini")
        # alembic screws up logging config by default, this tells it not to screw it up if being run at startup like this
        alembic_cfg.set_main_option("dont_mess_up_logging", "False")
        command.upgrade(alembic_cfg, "head")
    finally:
        os.chdir(cwd)


@functools.cache
def _get_base_engine():
    return create_engine(
        config["DATABASE_CONNECTION_STRING"],
        # checks that the connections in the pool are alive before using them, which avoids the "server closed the
        # connection unexpectedly" errors
        pool_pre_ping=True,
        # one connection per thread
        poolclass=QueuePool,
        # main threads + a few extra in case
        pool_size=SERVER_THREADS + WORKER_THREADS + 12,
    )


@contextmanager
def session_scope():
    with tracer.start_as_current_span("session_scope") as rollspan:
        with Session(_get_base_engine()) as session:
            session.begin()
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        frame = inspect.stack()[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception as e:
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar()
                    logger.debug(f"SScope: got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            except:
                session.rollback()
                raise
            finally:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"SScope: closed {backend_pid=}")
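
# Illustrative sketch, not part of the original module: how servicer code typically
# consumes session_scope. The User model exists in couchers.models, but the lookup
# shown here is purely an example.
def _example_session_scope_usage(user_id):
    from couchers.models import User

    with session_scope() as session:
        # the transaction commits when the block exits cleanly and rolls back if an
        # exception propagates out of it
        user = session.execute(select(User).where(User.id == user_id)).scalar_one_or_none()
        return user.username if user else None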

@contextmanager
def worker_repeatable_read_session_scope():
    """
    This is a separate session scope that is isolated from the main one: nesting it inside the main scope would
    otherwise nest transactions, so this ends up using a different connection.

    This operates at the `REPEATABLE READ` isolation level so that we can do a `SELECT ... FOR UPDATE SKIP LOCKED` in
    the background worker, effectively using postgres as a queueing system.
    """
    with tracer.start_as_current_span("worker_session_scope") as rollspan:
        with Session(_get_base_engine().execution_options(isolation_level="REPEATABLE READ")) as session:
            session.begin()
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        frame = inspect.stack()[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception as e:
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar()
                    logger.debug(f"SScope (worker): got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            except:
                session.rollback()
                raise
            finally:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"SScope (worker): closed {backend_pid=}")
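
# Illustrative sketch, not part of the original module: the queueing pattern the
# docstring above describes. Under REPEATABLE READ, a worker claims one pending row
# with `SELECT ... FOR UPDATE SKIP LOCKED`, so concurrent workers never grab the same
# job. `BackgroundJob` and its `ready_for_retry` property are assumed to match the
# background job model used elsewhere; treat the exact names as illustrative.
def _example_worker_queue_poll():
    from couchers.models import BackgroundJob

    with worker_repeatable_read_session_scope() as session:
        job = (
            session.execute(
                select(BackgroundJob).where(BackgroundJob.ready_for_retry).with_for_update(skip_locked=True)
            )
            .scalars()
            .first()
        )
        # any mutation of `job` here is committed when the scope exits without error
        return job.id if job else None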

def db_post_fork():
    """
    Fix post-fork issues with sqlalchemy
    """
    # see https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
    _get_base_engine().dispose(close=False)
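
# Illustrative sketch, not part of the original module: db_post_fork is meant to be
# called in the child process right after a fork, before it touches the database, so
# the child does not reuse pooled connections inherited from the parent. The bare
# os.fork() here is just for demonstration.
def _example_fork_then_query():
    pid = os.fork()
    if pid == 0:
        # child: drop inherited connections without closing them out from under the
        # parent, then open fresh ones as needed
        db_post_fork()
        with session_scope() as session:
            session.execute(text("SELECT 1"))
        os._exit(0)
    os.waitpid(pid, 0)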

@event.listens_for(Engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    conn.info.setdefault("query_profiler_info", []).append((statement, parameters, now(), perf_counter_ns()))


@event.listens_for(Engine, "after_cursor_execute")
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    statement, parameters, start, start_ns = conn.info["query_profiler_info"].pop(-1)
    end, end_ns = now(), perf_counter_ns()
    add_sql_statement(statement, parameters, start, start_ns, end, end_ns)


def are_friends(session, context, other_user):
    return (
        session.execute(
            select(FriendRelationship)
            .where_users_column_visible(context, FriendRelationship.from_user_id)
            .where_users_column_visible(context, FriendRelationship.to_user_id)
            .where(
                or_(
                    and_(
                        FriendRelationship.from_user_id == context.user_id, FriendRelationship.to_user_id == other_user
                    ),
                    and_(
                        FriendRelationship.from_user_id == other_user, FriendRelationship.to_user_id == context.user_id
                    ),
                )
            )
            .where(FriendRelationship.status == FriendStatus.accepted)
        ).scalar_one_or_none()
        is not None
    )

def get_parent_node_at_location(session, shape):
    """
    Finds the smallest node containing the shape.

    Shape can be any PostGIS geo object, e.g. output from create_coordinate
    """

    # Find the lowest Node (in the Node tree) that contains the shape. By construction of nodes, the area of a sub-node
    # must always be less than its parent Node, so no need to actually traverse the tree!
    return (
        session.execute(select(Node).where(func.ST_Contains(Node.geom, shape)).order_by(func.ST_Area(Node.geom)))
        .scalars()
        .first()
    )
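
# Illustrative sketch, not part of the original module: resolving the community that
# contains a point, with the point built via create_coordinate as the docstring
# suggests. create_coordinate is assumed to live in couchers.utils; the coordinates
# are arbitrary example values.
def _example_node_at_point(session):
    from couchers.utils import create_coordinate

    node = get_parent_node_at_location(session, create_coordinate(40.7128, -74.0060))
    return node.id if node else None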

def get_node_parents_recursively(session, node_id):
    """
    Gets the upwards hierarchy of parents, ordered by level, for a given node

    Returns SQLAlchemy rows of (node_id, parent_node_id, level, cluster)
    """
    parents = (
        select(Node.id, Node.parent_node_id, literal(0).label("level"))
        .where(Node.id == node_id)
        .cte("parents", recursive=True)
    )

    subquery = select(
        parents.union(
            select(Node.id, Node.parent_node_id, (parents.c.level + 1).label("level")).join(
                parents, Node.id == parents.c.parent_node_id
            )
        )
    ).subquery()

    return session.execute(
        select(subquery, Cluster)
        .join(Cluster, Cluster.parent_node_id == subquery.c.id)
        .where(Cluster.is_official_cluster)
        .order_by(subquery.c.level.desc())
    ).all()
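
# Illustrative sketch, not part of the original module: the rows returned above
# unpack as (node_id, parent_node_id, level, cluster) and are ordered from the root
# community down to the node itself; cluster.name is assumed here for display only.
def _example_print_node_ancestry(session, node_id):
    for ancestor_id, parent_id, level, cluster in get_node_parents_recursively(session, node_id):
        print(f"level {level}: node {ancestor_id} (parent {parent_id}) has official cluster {cluster.name!r}")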

def _can_moderate_any_cluster(session, user_id, cluster_ids):
    return (
        session.execute(
            select(func.count())
            .select_from(ClusterSubscription)
            .where(ClusterSubscription.role == ClusterRole.admin)
            .where(ClusterSubscription.user_id == user_id)
            .where(ClusterSubscription.cluster_id.in_(cluster_ids))
        ).scalar_one()
        > 0
    )


def can_moderate_at(session, user_id, shape):
    """
    Returns True if the user_id can moderate a given geo-shape (i.e., if the shape is contained in any Node that the user is an admin of)
    """
    cluster_ids = [
        cluster_id
        for (cluster_id,) in session.execute(
            select(Cluster.id)
            .join(Node, Node.id == Cluster.parent_node_id)
            .where(Cluster.is_official_cluster)
            .where(func.ST_Contains(Node.geom, shape))
        ).all()
    ]
    return _can_moderate_any_cluster(session, user_id, cluster_ids)


def can_moderate_node(session, user_id, node_id):
    """
    Returns True if the user_id can moderate the given node (i.e., if they are admin of any community that is a parent of the node)
    """
    return _can_moderate_any_cluster(
        session, user_id, [cluster.id for _, _, _, cluster in get_node_parents_recursively(session, node_id)]
    )

def timezone_at_coordinate(session, geom):
    # the query selects the tzid column directly, so the scalar result is the tzid itself (or None)
    tzid = session.execute(
        select(TimezoneArea.tzid).where(func.ST_Contains(TimezoneArea.geom, geom))
    ).scalar_one_or_none()
    if tzid:
        return tzid
    return None
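
# Illustrative sketch, not part of the original module: combining create_coordinate
# (assumed to live in couchers.utils) with timezone_at_coordinate to look up the
# timezone id (tzid) at a point; the coordinates are arbitrary example values.
def _example_timezone_for_point(session):
    from couchers.utils import create_coordinate

    return timezone_at_coordinate(session, create_coordinate(52.52, 13.405))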