Coverage for src/couchers/db.py: 77%

116 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-14 00:52 +0000

1import functools 

2import inspect 

3import logging 

4import os 

5from collections.abc import Generator, Sequence 

6from contextlib import contextmanager 

7from os import getpid 

8from threading import get_ident 

9from typing import cast 

10 

11from alembic import command 

12from alembic.config import Config 

13from geoalchemy2 import WKBElement 

14from opentelemetry import trace 

15from sqlalchemy import Engine, Row, Subquery, create_engine, text 

16from sqlalchemy.orm.session import Session 

17from sqlalchemy.pool import QueuePool 

18from sqlalchemy.sql import and_, func, literal, or_ 

19 

20from couchers.config import config 

21from couchers.constants import SERVER_THREADS, WORKER_THREADS 

22from couchers.context import CouchersContext 

23from couchers.models import ( 

24 Cluster, 

25 ClusterRole, 

26 ClusterSubscription, 

27 FriendRelationship, 

28 FriendStatus, 

29 Node, 

30 TimezoneArea, 

31 User, 

32) 

33from couchers.sql import couchers_select as select 

34 

35logger = logging.getLogger(__name__) 

36 

37tracer = trace.get_tracer(__name__) 

38 

39 

def apply_migrations() -> None:
    """
    Upgrades the database schema to the latest alembic revision ("head").

    Alembic is run from the directory two levels above this file so that the relative "alembic.ini" path resolves.
    The previous working directory is always restored afterwards, even if the upgrade fails.
    """
    original_cwd = os.getcwd()
    migrations_root = os.path.dirname(__file__) + "/../.."
    try:
        os.chdir(migrations_root)
        cfg = Config("alembic.ini")
        # alembic reconfigures logging by default; setting this option tells it to leave our logging setup alone
        # when migrations are run at startup like this
        cfg.set_main_option("dont_mess_up_logging", "False")
        command.upgrade(cfg, "head")
    finally:
        os.chdir(original_cwd)

51 

52 

@functools.cache
def _get_base_engine() -> Engine:
    """
    Creates the SQLAlchemy engine for the configured database.

    The result is cached, so repeated calls return the same Engine (and therefore share its connection pool).
    """
    connection_string = config["DATABASE_CONNECTION_STRING"]
    # main threads + a few extra in case
    max_pooled_connections = SERVER_THREADS + WORKER_THREADS + 12
    return create_engine(
        connection_string,
        # one connection per thread
        poolclass=QueuePool,
        pool_size=max_pooled_connections,
        # checks that the connections in the pool are alive before using them, which avoids the "server closed the
        # connection unexpectedly" errors
        pool_pre_ping=True,
    )

65 

66 

@contextmanager
def session_scope() -> Generator[Session]:
    """
    Context manager that yields a Session with a transaction already begun.

    The transaction is committed when the managed block exits normally. If anything is raised, the transaction is
    rolled back and the exception is re-raised.

    When DEBUG logging is enabled, the postgres backend pid and the caller's file/line are logged and attached to the
    tracing span, and a matching "closed" message is logged on exit.
    """
    with tracer.start_as_current_span("session_scope") as rollspan:
        with Session(_get_base_engine()) as session:
            session.begin()
            # Stays None unless the debug block below manages to look it up. This guarantees the `finally` clause
            # never reads an unbound name (which would mask the real exception, e.g. if the pid query itself fails).
            backend_pid = None
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        # with @contextmanager, frames 0 and 1 are this generator and contextlib's __enter__, so
                        # index 2 should be the code that opened the `with` block
                        frame = inspect.stack()[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        # best effort only: never fail a session because the caller could not be determined
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar_one()
                    logger.debug(f"SScope: got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            except BaseException:
                # deliberately catches everything (same as a bare `except:`): always roll back, then re-raise
                session.rollback()
                raise
            finally:
                if backend_pid is not None:
                    logger.debug(f"SScope: closed {backend_pid=}")

94 

95 

@contextmanager
def worker_repeatable_read_session_scope() -> Generator[Session]:
    """
    This is a separate session scope that is isolated from the main one since otherwise we end up nesting transactions,
    this causes two different connections to be used

    This operates in a `REPEATABLE READ` isolation level so that we can do a `SELECT ... FOR UPDATE SKIP LOCKED` in the
    background worker, effectively using postgres as a queueing system.

    The transaction is committed when the managed block exits normally. If anything is raised, the transaction is
    rolled back and the exception is re-raised.
    """
    with tracer.start_as_current_span("worker_session_scope") as rollspan:
        with Session(_get_base_engine().execution_options(isolation_level="REPEATABLE READ")) as session:
            session.begin()
            # Stays None unless the debug block below manages to look it up. This guarantees the `finally` clause
            # never reads an unbound name (which would mask the real exception, e.g. if the pid query itself fails).
            backend_pid = None
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        # with @contextmanager, frames 0 and 1 are this generator and contextlib's __enter__, so
                        # index 2 should be the code that opened the `with` block
                        frame = inspect.stack()[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        # best effort only: never fail a session because the caller could not be determined
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar_one()
                    logger.debug(f"SScope (worker): got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            except BaseException:
                # deliberately catches everything (same as a bare `except:`): always roll back, then re-raise
                session.rollback()
                raise
            finally:
                if backend_pid is not None:
                    logger.debug(f"SScope (worker): closed {backend_pid=}")

130 

131 

def db_post_fork() -> None:
    """
    Fix post-fork issues with sqlalchemy by disposing the engine's pool without closing its connections
    """
    # see https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
    engine = _get_base_engine()
    engine.dispose(close=False)

138 

139 

def are_friends(session: Session, context: CouchersContext, other_user: int) -> bool:
    """
    Returns True if an accepted FriendRelationship exists, in either direction, between the context's user and
    other_user.

    Both user columns are filtered through `where_users_column_visible` for the given context.
    """
    visible_relationships = (
        select(FriendRelationship)
        .where_users_column_visible(context, FriendRelationship.from_user_id)
        .where_users_column_visible(context, FriendRelationship.to_user_id)
    )
    # the relationship may have been initiated by either of the two users
    from_me_to_them = and_(
        FriendRelationship.from_user_id == context.user_id, FriendRelationship.to_user_id == other_user
    )
    from_them_to_me = and_(
        FriendRelationship.from_user_id == other_user, FriendRelationship.to_user_id == context.user_id
    )
    query = visible_relationships.where(or_(from_me_to_them, from_them_to_me)).where(
        FriendRelationship.status == FriendStatus.accepted
    )
    relationship = session.execute(query).scalar_one_or_none()
    return relationship is not None

160 

161 

def get_parent_node_at_location(session: Session, shape: WKBElement) -> Node | None:
    """
    Returns the smallest Node whose geometry contains the shape, or None if no Node contains it.

    Shape can be any PostGIS geo object, e.g. output from create_coordinate
    """
    # By construction of nodes, the area of a sub-node must always be less than its parent Node. So the lowest Node in
    # the tree that contains the shape is simply the containing Node with the smallest area: no tree traversal needed.
    smallest_containing_node = (
        select(Node).where(func.ST_Contains(Node.geom, shape)).order_by(func.ST_Area(Node.geom)).limit(1)
    )
    result = session.execute(smallest_containing_node)
    return result.scalars().one_or_none()

178 

179 

def _get_node_parents_recursive_cte_subquery(node_id: int) -> Subquery:
    """
    Builds a subquery with columns (id, parent_node_id, level) for the given node and its chain of parent nodes.

    The node itself has level 0, and the level increases by one for each step up to a parent.
    """
    # anchor term: the starting node, at level 0
    starting_node = select(Node.id, Node.parent_node_id, literal(0).label("level")).where(Node.id == node_id)
    parents = starting_node.cte("parents", recursive=True)

    # recursive term: the parent of every node already collected, one level further up
    next_level_up = select(Node.id, Node.parent_node_id, (parents.c.level + 1).label("level")).join(
        parents, Node.id == parents.c.parent_node_id
    )

    return select(parents.union(next_level_up)).subquery()

194 

195 

def get_node_parents_recursively(session: Session, node_id: int) -> Sequence[Row[tuple[int, int, int, Cluster]]]:
    """
    Returns rows of (node id, parent node id, level, official Cluster) for the given node and each of its parents.

    Rows are ordered by level descending, so the node with the highest level (furthest up the tree) comes first and
    the given node (level 0) comes last.
    """
    parents = _get_node_parents_recursive_cte_subquery(node_id)
    query = (
        select(parents, Cluster)
        .join(Cluster, Cluster.parent_node_id == parents.c.id)
        .where(Cluster.is_official_cluster)
        .order_by(parents.c.level.desc())
    )
    return session.execute(query).all()

204 

205 

def _can_moderate_any_cluster(session: Session, user_id: int, cluster_ids: list[int]) -> bool:
    """
    Returns True if the user_id has an admin subscription to at least one of the given clusters
    """
    admin_subscription = (
        select(True)
        .select_from(ClusterSubscription)
        .where(ClusterSubscription.role == ClusterRole.admin)
        .where(ClusterSubscription.user_id == user_id)
        .where(ClusterSubscription.cluster_id.in_(cluster_ids))
    )
    is_admin = session.execute(select(admin_subscription.exists())).scalar_one()
    return cast(bool, is_admin)

217 

218 

def can_moderate_node(session: Session, user_id: int, node_id: int) -> bool:
    """
    Returns True if the user_id can moderate the given node, i.e. if they are admin of the official cluster of the
    node itself or of any of its parent nodes
    """
    node_and_parents = _get_node_parents_recursive_cte_subquery(node_id)
    admin_subscription = (
        select(True)
        .select_from(ClusterSubscription)
        .where(ClusterSubscription.role == ClusterRole.admin)
        .where(ClusterSubscription.user_id == user_id)
        .join(Cluster, Cluster.id == ClusterSubscription.cluster_id)
        .where(Cluster.is_official_cluster)
        # ties the cluster to the node itself (level 0) or any node further up its parent chain
        .where(Cluster.parent_node_id == node_and_parents.c.id)
    )
    is_admin = session.execute(select(admin_subscription.exists())).scalar_one()
    return cast(bool, is_admin)

236 

237 

def can_moderate_at(session: Session, user_id: int, shape: WKBElement) -> bool:
    """
    Returns True if the user_id can moderate a given geo-shape, i.e. if the shape is contained in the geometry of any
    Node whose official cluster the user is an admin of
    """
    admin_subscription = (
        select(True)
        .select_from(ClusterSubscription)
        .where(ClusterSubscription.role == ClusterRole.admin)
        .where(ClusterSubscription.user_id == user_id)
        .join(Cluster, Cluster.id == ClusterSubscription.cluster_id)
        # only official clusters are joined to their parent node
        .join(Node, and_(Cluster.is_official_cluster, Node.id == Cluster.parent_node_id))
        .where(func.ST_Contains(Node.geom, shape))
    )
    is_admin = session.execute(select(admin_subscription.exists())).scalar_one()
    return cast(bool, is_admin)

254 

255 

def is_user_in_node_geography(session: Session, user_id: int, node_id: int) -> bool:
    """
    Returns True if the user's location is geographically contained within the node's boundary.

    This is used to check if a user can leave a community - users cannot leave communities
    that contain their geographic location.
    """
    user_inside_node = (
        select(True)
        .select_from(User)
        .join(Node, func.ST_Contains(Node.geom, User.geom))
        .where(User.id == user_id)
        .where(Node.id == node_id)
    )
    is_inside = session.execute(select(user_inside_node.exists())).scalar_one()
    return cast(bool, is_inside)

272 

273 

def timezone_at_coordinate(session: Session, geom: WKBElement) -> str | None:
    """
    Returns the tzid of the TimezoneArea whose geometry contains geom, or None if no area contains it.

    Raises whatever `scalar_one_or_none()` raises if more than one TimezoneArea matches.
    """
    tzid = session.execute(
        select(TimezoneArea.tzid).where(func.ST_Contains(TimezoneArea.geom, geom))
    ).scalar_one_or_none()
    # Only the tzid column is selected, so the scalar is already the tzid value itself rather than a TimezoneArea
    # row; it must not be dereferenced with `.tzid` again. Falsy results (no match) are normalized to None.
    return cast(str | None, tzid or None)