Coverage for app/backend/src/couchers/db.py: 78%

122 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-21 09:29 +0000

1import functools 

2import inspect 

3import logging 

4import os 

5from collections.abc import Generator, Sequence 

6from contextlib import contextmanager 

7from os import getpid 

8from threading import get_ident 

9from typing import TYPE_CHECKING 

10 

11from alembic import command 

12from alembic.config import Config 

13from geoalchemy2 import WKBElement 

14from opentelemetry import trace 

15from sqlalchemy import Engine, Row, Subquery, create_engine, select, text, true 

16from sqlalchemy.dialects import registry 

17from sqlalchemy.orm.session import Session 

18from sqlalchemy.pool import QueuePool 

19from sqlalchemy.sql import and_, func, literal, or_ 

20 

21from couchers.config import config 

22from couchers.constants import SERVER_THREADS 

23from couchers.models import ( 

24 Cluster, 

25 ClusterRole, 

26 ClusterSubscription, 

27 FriendRelationship, 

28 FriendStatus, 

29 Geom, 

30 Node, 

31 TimezoneArea, 

32 User, 

33) 

34from couchers.perf import register_perf_listeners 

35from couchers.sql import where_users_column_visible 

36 

37if TYPE_CHECKING: 

38 from couchers.context import CouchersContext 

39 

# Make psycopg (psycopg3) the driver used for plain "postgresql://" URLs.
# This registration has to run before any engine is created.
registry.register("postgresql", "sqlalchemy.dialects.postgresql.psycopg", "PGDialect_psycopg")

tracer = trace.get_tracer(__name__)

logger = logging.getLogger(__name__)

47 

48 

def apply_migrations() -> None:
    """
    Upgrade the database schema to the alembic "head" revision.

    alembic.ini is loaded relative to the directory two levels above this module, so the process working directory
    is switched there for the duration of the upgrade and always restored afterwards (even if the upgrade fails).

    NOTE(review): os.chdir affects the whole process, so this is presumably only safe to call while nothing else
    depends on the working directory (e.g. at startup) — confirm with callers.
    """
    previous_cwd = os.getcwd()
    migrations_root = f"{os.path.dirname(__file__)}/../.."
    try:
        os.chdir(migrations_root)
        alembic_config = Config("alembic.ini")
        # alembic reconfigures logging by default; this option is presumably read by the alembic env script so that
        # running migrations at startup leaves the application's logging configuration intact
        alembic_config.set_main_option("dont_mess_up_logging", "False")
        command.upgrade(alembic_config, "head")
    finally:
        os.chdir(previous_cwd)

60 

61 

@functools.cache
def _get_base_engine() -> Engine:
    """
    Create the shared SQLAlchemy engine on first call and return the memoized instance afterwards.

    Performance listeners are registered on the engine before it is returned.
    """
    # Every process keeps its own pool, so total connections ~= process count * pool_size, which has to stay below
    # postgres max_connections. Two connections per server thread, because a single request can hold two at once
    # (handler + _store_log); the extra 4 is presumably headroom for non-request work — confirm.
    connections_in_pool = 2 * SERVER_THREADS + 4
    shared_engine = create_engine(
        config.DATABASE_CONNECTION_STRING,
        poolclass=QueuePool,
        pool_size=connections_in_pool,
        # never grow past pool_size
        max_overflow=0,
        # test pooled connections before handing them out, avoiding "server closed the connection unexpectedly"
        # errors on connections that went stale while idle
        pool_pre_ping=True,
    )
    register_perf_listeners(shared_engine)
    return shared_engine

78 

79 

@contextmanager
def session_scope() -> Generator[Session]:
    """
    Provide a Session with an open transaction on the shared engine.

    The transaction is committed when the `with` block exits normally. On any exception it is rolled back and the
    exception is re-raised. When DEBUG logging is enabled, the postgres backend pid and the caller's file:line are
    logged and attached to the "session_scope" tracing span.
    """
    with tracer.start_as_current_span("session_scope") as rollspan:
        with Session(_get_base_engine()) as session:
            session.begin()
            # Bound before the try so the finally clause can always reference it. Previously it was only bound
            # inside the DEBUG branch, so a failing pid query (or DEBUG being switched on while the session was in
            # use) made the finally raise UnboundLocalError and mask the real exception.
            backend_pid = None
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        # [0] is this generator, [1] is contextlib's __enter__, [2] is the code running the `with`;
                        # context=0 skips reading source lines for every frame on the stack
                        frame = inspect.stack(context=0)[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar_one()
                    logger.debug(f"SScope: got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            except BaseException:
                # roll back on anything (including GeneratorExit / KeyboardInterrupt), then propagate
                session.rollback()
                raise
            finally:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"SScope: closed {backend_pid=}")

107 

108 

@contextmanager
def worker_repeatable_read_session_scope() -> Generator[Session]:
    """
    Session scope for the background worker, kept separate from session_scope.

    Keeping it separate means the worker's transaction is not nested inside another one; the two scopes use two
    different connections.

    Runs at the `REPEATABLE READ` isolation level so the background worker can do a
    `SELECT ... FOR UPDATE SKIP LOCKED`, effectively using postgres as a queueing system.

    Commits on normal exit; rolls back and re-raises on any exception. With DEBUG logging enabled, the backend pid and
    the caller's file:line are logged and attached to the "worker_session_scope" tracing span.
    """
    with tracer.start_as_current_span("worker_session_scope") as rollspan:
        with Session(_get_base_engine().execution_options(isolation_level="REPEATABLE READ")) as session:
            session.begin()
            # Bound before the try so the finally clause can always reference it. Previously an exception from the
            # pid query, or DEBUG being enabled mid-session, turned the finally into an UnboundLocalError that masked
            # the original exception.
            backend_pid = None
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        # [0] is this generator, [1] is contextlib's __enter__, [2] is the code running the `with`;
                        # context=0 skips reading source lines for every frame on the stack
                        frame = inspect.stack(context=0)[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar_one()
                    logger.debug(f"SScope (worker): got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            except BaseException:
                # roll back on anything (including GeneratorExit / KeyboardInterrupt), then propagate
                session.rollback()
                raise
            finally:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"SScope (worker): closed {backend_pid=}")

143 

144 

def db_post_fork() -> None:
    """
    Reset the connection pool inherited from the parent after an os.fork().

    dispose(close=False) discards the pooled connections in this process without closing them, so the parent's
    connections are not disturbed.
    See https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
    """
    inherited_engine = _get_base_engine()
    inherited_engine.dispose(close=False)

151 

152 

def are_friends(session: Session, context: CouchersContext, other_user: int) -> bool:
    """
    Return True if context.user_id and other_user have an accepted friend relationship.

    The relationship may be stored in either direction (from/to). Both endpoints are additionally filtered through
    where_users_column_visible. Presumably a friendship with a user hidden from the requester therefore counts as
    absent — confirm against that helper.
    """
    query = select(FriendRelationship)
    query = where_users_column_visible(query, context, FriendRelationship.from_user_id)
    query = where_users_column_visible(query, context, FriendRelationship.to_user_id)
    query = query.where(
        or_(
            and_(FriendRelationship.from_user_id == context.user_id, FriendRelationship.to_user_id == other_user),
            and_(FriendRelationship.from_user_id == other_user, FriendRelationship.to_user_id == context.user_id),
        )
    ).where(FriendRelationship.status == FriendStatus.accepted)
    # Only existence matters. limit(1) stops scalar_one_or_none() from raising MultipleResultsFound when accepted
    # rows exist in both directions (or are duplicated), and avoids loading more than one row.
    return session.execute(query.limit(1)).scalar_one_or_none() is not None

164 

165 

def get_parent_node_at_location(session: Session, shape: WKBElement) -> Node | None:
    """
    Return the smallest Node whose geometry contains shape, or None if no Node contains it.

    shape may be any PostGIS geo object, e.g. the output of create_coordinate.
    """
    # By construction a sub-node always has a smaller area than its parent Node, so the containing Node with the
    # smallest area is the lowest one in the Node tree -- no need to actually walk the tree.
    smallest_containing_node = (
        select(Node)
        .where(func.ST_Contains(Node.geom, shape))
        .order_by(func.ST_Area(Node.geom))
        .limit(1)
    )
    return session.execute(smallest_containing_node).scalar_one_or_none()

182 

183 

def _get_node_parents_recursive_cte_subquery(node_id: int) -> Subquery:
    """
    Build a subquery yielding node_id and every ancestor of it.

    Columns are (id, parent_node_id, level). level is 0 for node_id itself and grows by 1 for each step up the
    parent_node_id chain.
    """
    starting_node = select(Node.id, Node.parent_node_id, literal(0).label("level")).where(Node.id == node_id)
    parents = starting_node.cte("parents", recursive=True)

    # recursive step: the parent of every row found so far, one level further up
    one_level_up = select(Node.id, Node.parent_node_id, (parents.c.level + 1).label("level")).join(
        parents, Node.id == parents.c.parent_node_id
    )

    return select(parents.union(one_level_up)).subquery()

198 

199 

def get_node_parents_recursively(session: Session, node_id: int) -> Sequence[Row[tuple[int, int, int, Cluster]]]:
    """
    Return node_id and its ancestors, each paired with its official Cluster.

    Each row is (id, parent_node_id, level, Cluster). Rows are ordered by level descending, i.e. furthest ancestor
    first. Nodes without an official cluster are left out by the inner join.
    """
    lineage = _get_node_parents_recursive_cte_subquery(node_id)
    statement = (
        select(lineage, Cluster)
        .join(Cluster, Cluster.parent_node_id == lineage.c.id)
        .where(Cluster.is_official_cluster)
        .order_by(lineage.c.level.desc())
    )
    return session.execute(statement).all()

208 

209 

def _can_moderate_any_cluster(session: Session, user_id: int, cluster_ids: list[int]) -> bool:
    """
    Return True if user_id has an admin ClusterSubscription to at least one of the clusters in cluster_ids.
    """
    admin_subscription = (
        select(true())
        .select_from(ClusterSubscription)
        .where(
            ClusterSubscription.role == ClusterRole.admin,
            ClusterSubscription.user_id == user_id,
            ClusterSubscription.cluster_id.in_(cluster_ids),
        )
    )
    return session.execute(select(admin_subscription.exists())).scalar_one()

221 

222 

def can_moderate_node(session: Session, user_id: int, node_id: int) -> bool:
    """
    Return True if user_id can moderate node_id.

    This holds when the user is admin of the official cluster of node_id itself or of any of its ancestor nodes.
    """
    lineage = _get_node_parents_recursive_cte_subquery(node_id)
    admin_along_lineage = (
        select(true())
        .select_from(ClusterSubscription)
        .join(Cluster, Cluster.id == ClusterSubscription.cluster_id)
        .where(
            ClusterSubscription.role == ClusterRole.admin,
            ClusterSubscription.user_id == user_id,
            Cluster.is_official_cluster,
            Cluster.parent_node_id == lineage.c.id,
        )
    )
    return session.execute(select(admin_along_lineage.exists())).scalar_one()

240 

241 

def can_moderate_at(session: Session, user_id: int, shape: Geom) -> bool:
    """
    Return True if user_id can moderate the given geo-shape.

    This holds when shape is contained in the geometry of a Node whose official cluster the user is admin of.
    """
    admin_over_shape = (
        select(true())
        .select_from(ClusterSubscription)
        .join(Cluster, Cluster.id == ClusterSubscription.cluster_id)
        .join(Node, and_(Cluster.is_official_cluster, Node.id == Cluster.parent_node_id))
        .where(
            ClusterSubscription.role == ClusterRole.admin,
            ClusterSubscription.user_id == user_id,
            func.ST_Contains(Node.geom, shape),
        )
    )
    return session.execute(select(admin_over_shape.exists())).scalar_one()

258 

259 

def is_user_in_node_geography(session: Session, user_id: int, node_id: int) -> bool:
    """
    Return True if the user's location (User.geom) lies within the node's boundary (Node.geom).

    Used to decide whether a user may leave a community: users cannot leave communities that contain their
    geographic location.
    """
    user_inside_node = (
        select(true())
        .select_from(User)
        .join(Node, func.ST_Contains(Node.geom, User.geom))
        .where(
            User.id == user_id,
            Node.id == node_id,
        )
    )
    return session.execute(select(user_inside_node.exists())).scalar_one()

276 

277 

def timezone_at_coordinate(session: Session, geom: WKBElement) -> str | None:
    """
    Return the tzid of a TimezoneArea whose geometry contains geom, or None if none does.

    If several areas contain geom, the alphabetically first tzid is returned rather than raising.
    """
    return session.execute(
        select(TimezoneArea.tzid)
        .where(func.ST_Contains(TimezoneArea.geom, geom))
        # Timezone polygons are not guaranteed to be disjoint (NOTE(review): confirm against the imported dataset).
        # Pick one deterministically instead of letting scalar_one_or_none() raise MultipleResultsFound.
        .order_by(TimezoneArea.tzid)
        .limit(1)
    ).scalar_one_or_none()