Coverage for src/couchers/db.py: 75%

104 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-12-20 18:03 +0000

1import functools 

2import inspect 

3import logging 

4import os 

5from contextlib import contextmanager 

6from os import getpid 

7from threading import get_ident 

8 

9from alembic import command 

10from alembic.config import Config 

11from opentelemetry import trace 

12from sqlalchemy import create_engine, text 

13from sqlalchemy.orm.session import Session 

14from sqlalchemy.pool import QueuePool 

15from sqlalchemy.sql import and_, func, literal, or_ 

16 

17from couchers.config import config 

18from couchers.constants import SERVER_THREADS, WORKER_THREADS 

19from couchers.models import ( 

20 Cluster, 

21 ClusterRole, 

22 ClusterSubscription, 

23 FriendRelationship, 

24 FriendStatus, 

25 Node, 

26 TimezoneArea, 

27) 

28from couchers.sql import couchers_select as select 

29 

# Module-scoped logger and OpenTelemetry tracer; both are named after this module.
logger = logging.getLogger(name=__name__)

tracer = trace.get_tracer(__name__)

33 

34 

def apply_migrations():
    """
    Upgrade the database schema to the latest alembic revision ("head").

    The alembic configuration (`alembic.ini`) is looked up two directories above this module, so the process working
    directory is temporarily switched there. The previous working directory is always restored, even if the upgrade
    raises.
    """
    previous_cwd = os.getcwd()
    migrations_root = os.path.dirname(__file__) + "/../.."
    try:
        os.chdir(migrations_root)
        alembic_config = Config("alembic.ini")
        # alembic screws up logging config by default; this option (presumably read by our alembic env.py — confirm)
        # tells it not to touch logging when migrations are run at startup like this
        alembic_config.set_main_option("dont_mess_up_logging", "False")
        command.upgrade(alembic_config, "head")
    finally:
        os.chdir(previous_cwd)

46 

47 

@functools.cache
def _get_base_engine():
    """
    Build the SQLAlchemy engine shared by every session scope.

    Memoized with functools.cache, so the engine (and its connection pool) is only created on the first call.
    """
    # enough connections for the main server + worker threads, plus a few spare in case
    connection_count = SERVER_THREADS + WORKER_THREADS + 12
    return create_engine(
        config["DATABASE_CONNECTION_STRING"],
        # ping pooled connections before handing them out, which avoids the
        # "server closed the connection unexpectedly" errors
        pool_pre_ping=True,
        # one connection per thread
        poolclass=QueuePool,
        pool_size=connection_count,
    )

60 

61 

@contextmanager
def session_scope():
    """
    Provide a transactional SQLAlchemy session on the shared base engine.

    The transaction is committed when the `with` block exits normally. If anything escapes the block, or the commit
    itself fails, it is rolled back and the exception is re-raised. The whole scope runs inside a "session_scope"
    tracing span. When DEBUG logging is enabled, the postgres backend pid and the caller's file:line are also logged
    and recorded on the span.

    Yields:
        sqlalchemy.orm.Session: a session with an already-begun transaction
    """
    with tracer.start_as_current_span("session_scope") as rollspan:
        with Session(_get_base_engine()) as session:
            session.begin()
            # Bound before the `try` so the `finally` below can always reference it. Otherwise a failing debug query,
            # or the log level switching to DEBUG while the session is open, raises UnboundLocalError and masks the
            # real exception.
            backend_pid = None
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        # [0] is this generator, [1] is contextlib's __enter__, [2] is whoever entered session_scope().
                        # context=0: only filename/lineno are needed, so skip reading source lines for every frame
                        frame = inspect.stack(context=0)[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar()
                    logger.debug(f"SScope: got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            except BaseException:
                # roll back on anything (including KeyboardInterrupt/GeneratorExit), then propagate it
                session.rollback()
                raise
            finally:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"SScope: closed {backend_pid=}")

89 

90 

@contextmanager
def worker_repeatable_read_session_scope():
    """
    This is a separate session scope that is isolated from the main one, since otherwise we end up nesting
    transactions; this causes two different connections to be used.

    This operates in a `REPEATABLE READ` isolation level so that we can do a `SELECT ... FOR UPDATE SKIP LOCKED` in the
    background worker, effectively using postgres as a queueing system.

    The transaction is committed when the `with` block exits normally. If anything escapes the block, or the commit
    itself fails, it is rolled back and the exception is re-raised. Debug logging and span attributes mirror
    `session_scope`, under a "worker_session_scope" span.

    Yields:
        sqlalchemy.orm.Session: a REPEATABLE READ session with an already-begun transaction
    """
    with tracer.start_as_current_span("worker_session_scope") as rollspan:
        with Session(_get_base_engine().execution_options(isolation_level="REPEATABLE READ")) as session:
            session.begin()
            # Bound before the `try` so the `finally` below can always reference it. Otherwise a failing debug query,
            # or the log level switching to DEBUG while the session is open, raises UnboundLocalError and masks the
            # real exception.
            backend_pid = None
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        # [0] is this generator, [1] is contextlib's __enter__, [2] is whoever entered this scope.
                        # context=0: only filename/lineno are needed, so skip reading source lines for every frame
                        frame = inspect.stack(context=0)[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar()
                    logger.debug(f"SScope (worker): got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            except BaseException:
                # roll back on anything (including KeyboardInterrupt/GeneratorExit), then propagate it
                session.rollback()
                raise
            finally:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"SScope (worker): closed {backend_pid=}")

125 

126 

def db_post_fork():
    """
    Fix post-fork issues with sqlalchemy.

    Call this in a freshly forked child process. It disposes the memoized base engine with `close=False`, so the child
    starts with a brand-new connection pool. Connections inherited from the parent are de-referenced rather than
    closed, so the parent's connections are left alone.
    See https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
    """
    engine = _get_base_engine()
    engine.dispose(close=False)

133 

134 

def are_friends(session, context, other_user):
    """
    Return True if the calling user and `other_user` have an accepted friend relationship, in either direction.

    Relationship rows where either user is not visible to the caller (as filtered by `where_users_column_visible`)
    are ignored.

    Args:
        session: SQLAlchemy session to query with
        context: request context; `context.user_id` is the calling user
        other_user: user id of the other user
    """
    friendship = session.execute(
        select(FriendRelationship)
        .where_users_column_visible(context, FriendRelationship.from_user_id)
        .where_users_column_visible(context, FriendRelationship.to_user_id)
        .where(
            or_(
                and_(FriendRelationship.from_user_id == context.user_id, FriendRelationship.to_user_id == other_user),
                and_(FriendRelationship.from_user_id == other_user, FriendRelationship.to_user_id == context.user_id),
            )
        )
        .where(FriendRelationship.status == FriendStatus.accepted)
        # This is only an existence check. Limiting to one row (with .scalar() below) means several matching rows,
        # e.g. one in each direction, can't make this raise MultipleResultsFound as scalar_one_or_none() would.
        .limit(1)
    ).scalar()
    return friendship is not None

155 

156 

def get_parent_node_at_location(session, shape):
    """
    Finds the smallest node containing the shape.

    Shape can be any PostGIS geo object, e.g. output from create_coordinate.

    Returns the Node, or None if no node contains the shape.
    """
    # Find the lowest Node (in the Node tree) that contains the shape. By construction of nodes, the area of a sub-node
    # must always be less than its parent Node, so no need to actually traverse the tree!
    return (
        session.execute(
            select(Node)
            .where(func.ST_Contains(Node.geom, shape))
            .order_by(func.ST_Area(Node.geom))
            # only the smallest match is used; without the limit postgres returns every containing node,
            # geometry included, just for all but the first to be discarded
            .limit(1)
        )
        .scalars()
        .first()
    )

171 

172 

def get_node_parents_recursively(session, node_id):
    """
    Gets the upwards hierarchy of parents for a given node, ordered by level.

    The node itself is level 0, and each step up to a parent adds one to the level. Only nodes that have an official
    cluster are returned. Rows are ordered by level descending, i.e. from the top-most ancestor down towards `node_id`.

    Returns SQLAlchemy rows of (node_id, parent_node_id, level, cluster)
    """
    # anchor of the recursive CTE: the starting node itself, at level 0
    ancestry = (
        select(Node.id, Node.parent_node_id, literal(0).label("level"))
        .where(Node.id == node_id)
        .cte("parents", recursive=True)
    )
    # recursive step: the parent of each node found so far, one level higher
    one_level_up = select(Node.id, Node.parent_node_id, (ancestry.c.level + 1).label("level")).join(
        ancestry, Node.id == ancestry.c.parent_node_id
    )
    hierarchy = select(ancestry.union(one_level_up)).subquery()

    official_clusters_query = (
        select(hierarchy, Cluster)
        .join(Cluster, Cluster.parent_node_id == hierarchy.c.id)
        .where(Cluster.is_official_cluster)
        .order_by(hierarchy.c.level.desc())
    )
    return session.execute(official_clusters_query).all()

199 

200 

def _can_moderate_any_cluster(session, user_id, cluster_ids):
    """
    Return True if `user_id` has an admin ClusterSubscription to at least one of the clusters in `cluster_ids`.
    """
    admin_subscription_count = session.execute(
        select(func.count())
        .select_from(ClusterSubscription)
        .where(ClusterSubscription.role == ClusterRole.admin)
        .where(ClusterSubscription.user_id == user_id)
        .where(ClusterSubscription.cluster_id.in_(cluster_ids))
    ).scalar_one()
    return admin_subscription_count > 0

212 

213 

def can_moderate_at(session, user_id, shape):
    """
    Returns True if the user_id can moderate a given geo-shape, i.e. if they are an admin of the official cluster of
    any Node whose geometry contains the shape.
    """
    containing_official_clusters = (
        select(Cluster.id)
        .join(Node, Node.id == Cluster.parent_node_id)
        .where(Cluster.is_official_cluster)
        .where(func.ST_Contains(Node.geom, shape))
    )
    cluster_ids = session.execute(containing_official_clusters).scalars().all()
    return _can_moderate_any_cluster(session, user_id, cluster_ids)

228 

229 

def can_moderate_node(session, user_id, node_id):
    """
    Returns True if the user_id can moderate the given node, i.e. if they are an admin of the official cluster
    (community) of the node itself or of any of its ancestor nodes.
    """
    ancestor_rows = get_node_parents_recursively(session, node_id)
    # each row is (node_id, parent_node_id, level, cluster); only the cluster matters here
    cluster_ids = []
    for _node_id, _parent_node_id, _level, cluster in ancestor_rows:
        cluster_ids.append(cluster.id)
    return _can_moderate_any_cluster(session, user_id, cluster_ids)

237 

238 

def timezone_at_coordinate(session, geom):
    """
    Return the tzid of the TimezoneArea containing `geom`, or None if no area contains it.

    Raises sqlalchemy's MultipleResultsFound if more than one area contains `geom`.
    """
    # A single column is selected, so scalar_one_or_none() returns the tzid value itself (or None), not a row or
    # entity; there is no `.tzid` attribute to read off it.
    tzid = session.execute(
        select(TimezoneArea.tzid).where(func.ST_Contains(TimezoneArea.geom, geom))
    ).scalar_one_or_none()
    # keep returning None for an empty/falsy tzid, as before
    return tzid or None