Coverage for app/backend/src/couchers/db.py: 77%

119 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 09:44 +0000

1import functools 

2import inspect 

3import logging 

4import os 

5from collections.abc import Generator, Sequence 

6from contextlib import contextmanager 

7from os import getpid 

8from threading import get_ident 

9 

10from alembic import command 

11from alembic.config import Config 

12from geoalchemy2 import WKBElement 

13from opentelemetry import trace 

14from sqlalchemy import Engine, Row, Subquery, create_engine, select, text, true 

15from sqlalchemy.dialects import registry 

16from sqlalchemy.orm.session import Session 

17from sqlalchemy.pool import QueuePool 

18from sqlalchemy.sql import and_, func, literal, or_ 

19 

20from couchers.config import config 

21from couchers.constants import SERVER_THREADS, WORKER_THREADS 

22from couchers.context import CouchersContext 

23from couchers.models import ( 

24 Cluster, 

25 ClusterRole, 

26 ClusterSubscription, 

27 FriendRelationship, 

28 FriendStatus, 

29 Geom, 

30 Node, 

31 TimezoneArea, 

32 User, 

33) 

34from couchers.sql import where_users_column_visible 

35 

# Register psycopg (psycopg3) as the default driver for postgresql:// URLs.
# This must happen before any engine is created (the dialect is resolved from
# the URL scheme when create_engine() runs), hence it lives at import time.
registry.register("postgresql", "sqlalchemy.dialects.postgresql.psycopg", "PGDialect_psycopg")

# Module-level logger, named after this module per logging convention
logger = logging.getLogger(__name__)

# OpenTelemetry tracer used by the session scopes below to record spans
tracer = trace.get_tracer(__name__)

43 

44 

def apply_migrations() -> None:
    """
    Apply all outstanding alembic migrations (upgrade the database to "head").

    Alembic resolves its config relative to the working directory, so we
    temporarily chdir to the directory containing alembic.ini (two levels above
    this file) and always restore the original cwd, even if the upgrade fails.
    """
    # os.path.join instead of string concatenation: portable and normalizes separators
    alembic_dir = os.path.join(os.path.dirname(__file__), "..", "..")
    cwd = os.getcwd()
    try:
        os.chdir(alembic_dir)
        alembic_cfg = Config("alembic.ini")
        # alembic screws up logging config by default, this tells it not to screw it up if being run at startup like this
        alembic_cfg.set_main_option("dont_mess_up_logging", "False")
        command.upgrade(alembic_cfg, "head")
    finally:
        os.chdir(cwd)

56 

57 

@functools.cache
def _get_base_engine() -> Engine:
    """
    Create the app-wide SQLAlchemy engine (memoized: one engine per process).

    pool_pre_ping checks that pooled connections are alive before handing them
    out, which avoids "server closed the connection unexpectedly" errors. The
    QueuePool is sized for one connection per thread plus a little headroom.
    """
    # main threads + a few extra in case
    pool_connections = SERVER_THREADS + WORKER_THREADS + 12
    return create_engine(
        config["DATABASE_CONNECTION_STRING"],
        pool_pre_ping=True,
        poolclass=QueuePool,
        pool_size=pool_connections,
    )

70 

71 

@contextmanager
def session_scope() -> Generator[Session]:
    """
    Context manager yielding a Session on the base engine inside a transaction.

    Commits on clean exit; rolls back on any exception and re-raises it. When
    DEBUG logging is enabled, also records the postgres backend pid, the
    caller's file:line, and thread/process ids on the current trace span.
    """
    with tracer.start_as_current_span("session_scope") as rollspan:
        with Session(_get_base_engine()) as session:
            session.begin()
            # initialized up-front so the `finally` debug log below can never hit a
            # NameError (e.g. if the log level is raised mid-session, or execute() raises)
            backend_pid = None
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        # stack()[2]: skip this frame and the contextmanager machinery to
                        # report the caller's location — assumes call depth, TODO confirm
                        frame = inspect.stack()[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar_one()
                    logger.debug(f"SScope: got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            # bare except on purpose: roll back on BaseException too (e.g. GeneratorExit
            # thrown into the generator at the yield), then re-raise unchanged
            except:
                session.rollback()
                raise
            finally:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"SScope: closed {backend_pid=}")

99 

100 

@contextmanager
def worker_repeatable_read_session_scope() -> Generator[Session]:
    """
    This is a separate session scope that is isolated from the main one since otherwise we end up nesting transactions,
    this causes two different connections to be used

    This operates in a `REPEATABLE READ` isolation level so that we can do a `SELECT ... FOR UPDATE SKIP LOCKED` in the
    background worker, effectively using postgres as a queueing system.

    Commits on clean exit; rolls back on any exception and re-raises it.
    """
    with tracer.start_as_current_span("worker_session_scope") as rollspan:
        with Session(_get_base_engine().execution_options(isolation_level="REPEATABLE READ")) as session:
            session.begin()
            # initialized up-front so the `finally` debug log below can never hit a
            # NameError (e.g. if the log level is raised mid-session, or execute() raises)
            backend_pid = None
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        # stack()[2]: skip this frame and the contextmanager machinery to
                        # report the caller's location — assumes call depth, TODO confirm
                        frame = inspect.stack()[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar_one()
                    logger.debug(f"SScope (worker): got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            # bare except on purpose: roll back on BaseException too (e.g. GeneratorExit
            # thrown into the generator at the yield), then re-raise unchanged
            except:
                session.rollback()
                raise
            finally:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"SScope (worker): closed {backend_pid=}")

135 

136 

def db_post_fork() -> None:
    """
    Fix post-fork issues with sqlalchemy.

    After a fork, the child would otherwise share pooled connections with the
    parent; disposing the pool with close=False discards them on the child side
    so fresh connections get created.
    See https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
    """
    engine = _get_base_engine()
    engine.dispose(close=False)

143 

144 

def are_friends(session: Session, context: CouchersContext, other_user: int) -> bool:
    """
    Whether there is an accepted friend relationship between the context user and
    other_user (in either direction), with both endpoints visible to the context user.
    """
    in_either_direction = or_(
        and_(FriendRelationship.from_user_id == context.user_id, FriendRelationship.to_user_id == other_user),
        and_(FriendRelationship.from_user_id == other_user, FriendRelationship.to_user_id == context.user_id),
    )
    query = select(FriendRelationship)
    query = where_users_column_visible(query, context, FriendRelationship.from_user_id)
    query = where_users_column_visible(query, context, FriendRelationship.to_user_id)
    query = query.where(in_either_direction).where(FriendRelationship.status == FriendStatus.accepted)
    return session.execute(query).scalar_one_or_none() is not None

156 

157 

def get_parent_node_at_location(session: Session, shape: WKBElement) -> Node | None:
    """
    Finds the smallest node containing the shape.

    Shape can be any PostGIS geo object, e.g., output from create_coordinate
    """
    # Find the lowest Node (in the Node tree) that contains the shape. By construction of
    # nodes, the area of a sub-node must always be less than its parent Node, so no need
    # to actually traverse the tree: sort containing nodes by area and take the smallest.
    query = (
        select(Node)
        .where(func.ST_Contains(Node.geom, shape))
        .order_by(func.ST_Area(Node.geom))
        .limit(1)
    )
    return session.execute(query).scalars().one_or_none()

174 

175 

def _get_node_parents_recursive_cte_subquery(node_id: int) -> Subquery:
    """
    Build a recursive-CTE subquery of node_id and all its ancestors.

    Columns: (id, parent_node_id, level), where level 0 is the node itself and
    level increases by one per step towards the root.
    """
    # Anchor: the node itself at level 0
    anchor = select(Node.id, Node.parent_node_id, literal(0).label("level")).where(Node.id == node_id)
    parents = anchor.cte("parents", recursive=True)

    # Recursive step: each iteration joins in the parent of the previous level
    step = select(Node.id, Node.parent_node_id, (parents.c.level + 1).label("level")).join(
        parents, Node.id == parents.c.parent_node_id
    )
    return select(parents.union(step)).subquery()

190 

191 

def get_node_parents_recursively(session: Session, node_id: int) -> Sequence[Row[tuple[int, int, int, Cluster]]]:
    """
    Return (id, parent_node_id, level, Cluster) rows for node_id and each ancestor
    that has an official cluster, ordered by descending level (root-most first,
    since level grows towards the root).
    """
    ancestors = _get_node_parents_recursive_cte_subquery(node_id)
    query = (
        select(ancestors, Cluster)
        .join(Cluster, Cluster.parent_node_id == ancestors.c.id)
        .where(Cluster.is_official_cluster)
        .order_by(ancestors.c.level.desc())
    )
    return session.execute(query).all()

200 

201 

def _can_moderate_any_cluster(session: Session, user_id: int, cluster_ids: list[int]) -> bool:
    """True if user_id has an admin subscription to at least one of cluster_ids."""
    is_admin_of_any = (
        select(true())
        .select_from(ClusterSubscription)
        .where(ClusterSubscription.role == ClusterRole.admin)
        .where(ClusterSubscription.user_id == user_id)
        .where(ClusterSubscription.cluster_id.in_(cluster_ids))
    ).exists()
    return session.execute(select(is_admin_of_any)).scalar_one()

213 

214 

def can_moderate_node(session: Session, user_id: int, node_id: int) -> bool:
    """
    Returns True if the user_id can moderate the given node (i.e., if they are admin of any community that is a parent of the node)
    """
    # All ancestors of the node (including the node itself), as a recursive CTE
    ancestors = _get_node_parents_recursive_cte_subquery(node_id)

    # EXISTS: an admin subscription of user_id on the official cluster of any ancestor
    is_admin_of_ancestor = (
        select(true())
        .select_from(ClusterSubscription)
        .where(ClusterSubscription.role == ClusterRole.admin)
        .where(ClusterSubscription.user_id == user_id)
        .join(Cluster, Cluster.id == ClusterSubscription.cluster_id)
        .where(Cluster.is_official_cluster)
        .where(Cluster.parent_node_id == ancestors.c.id)
    ).exists()
    return session.execute(select(is_admin_of_ancestor)).scalar_one()

232 

233 

def can_moderate_at(session: Session, user_id: int, shape: Geom) -> bool:
    """
    Returns True if the user_id can moderate a given geo-shape (i.e., if the shape is contained in any Node that the user is an admin of)
    """
    # EXISTS: an admin subscription of user_id on an official cluster whose
    # node geographically contains the shape
    is_admin_of_containing_node = (
        select(true())
        .select_from(ClusterSubscription)
        .where(ClusterSubscription.role == ClusterRole.admin)
        .where(ClusterSubscription.user_id == user_id)
        .join(Cluster, Cluster.id == ClusterSubscription.cluster_id)
        .join(Node, and_(Cluster.is_official_cluster, Node.id == Cluster.parent_node_id))
        .where(func.ST_Contains(Node.geom, shape))
    ).exists()
    return session.execute(select(is_admin_of_containing_node)).scalar_one()

250 

251 

def is_user_in_node_geography(session: Session, user_id: int, node_id: int) -> bool:
    """
    Returns True if the user's location is geographically contained within the node's boundary.

    This is used to check if a user can leave a community - users cannot leave
    communities that contain their geographic location.
    """
    user_inside_node = (
        select(true())
        .select_from(User)
        .join(Node, func.ST_Contains(Node.geom, User.geom))
        .where(User.id == user_id)
        .where(Node.id == node_id)
    ).exists()
    return session.execute(select(user_inside_node)).scalar_one()

268 

269 

def timezone_at_coordinate(session: Session, geom: WKBElement) -> str | None:
    """
    Return the tzid of the timezone area containing geom, or None if no
    timezone area contains it.
    """
    query = select(TimezoneArea.tzid).where(func.ST_Contains(TimezoneArea.geom, geom))
    return session.execute(query).scalar_one_or_none()