Coverage for app/backend/src/couchers/db.py: 78%
122 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-21 09:29 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-21 09:29 +0000
1import functools
2import inspect
3import logging
4import os
5from collections.abc import Generator, Sequence
6from contextlib import contextmanager
7from os import getpid
8from threading import get_ident
9from typing import TYPE_CHECKING
11from alembic import command
12from alembic.config import Config
13from geoalchemy2 import WKBElement
14from opentelemetry import trace
15from sqlalchemy import Engine, Row, Subquery, create_engine, select, text, true
16from sqlalchemy.dialects import registry
17from sqlalchemy.orm.session import Session
18from sqlalchemy.pool import QueuePool
19from sqlalchemy.sql import and_, func, literal, or_
21from couchers.config import config
22from couchers.constants import SERVER_THREADS
23from couchers.models import (
24 Cluster,
25 ClusterRole,
26 ClusterSubscription,
27 FriendRelationship,
28 FriendStatus,
29 Geom,
30 Node,
31 TimezoneArea,
32 User,
33)
34from couchers.perf import register_perf_listeners
35from couchers.sql import where_users_column_visible
37if TYPE_CHECKING:
38 from couchers.context import CouchersContext
40# Register psycopg (psycopg3) as the default driver for postgresql:// URLs
41# This must happen before any engine is created
42registry.register("postgresql", "sqlalchemy.dialects.postgresql.psycopg", "PGDialect_psycopg")
44logger = logging.getLogger(__name__)
46tracer = trace.get_tracer(__name__)
def apply_migrations() -> None:
    """
    Upgrades the database schema to the latest alembic revision ("head").

    The working directory is temporarily switched to the folder two levels above this module so that the relative
    "alembic.ini" path can be used; the previous working directory is restored afterwards, even if the upgrade
    fails.
    """
    project_root = os.path.dirname(__file__) + "/../.."
    previous_cwd = os.getcwd()
    try:
        os.chdir(project_root)
        cfg = Config("alembic.ini")
        # alembic screws up logging config by default, this tells it not to screw it up if being run at startup like
        # this
        cfg.set_main_option("dont_mess_up_logging", "False")
        command.upgrade(cfg, "head")
    finally:
        os.chdir(previous_cwd)
@functools.cache
def _get_base_engine() -> Engine:
    """
    Builds the SQLAlchemy engine for the configured database and registers the perf listeners on it.

    The function is memoized with `functools.cache`, so repeated calls return the same engine object.
    """
    # each process keeps its own pool, so total connections ~= process count * pool_size, kept under postgres
    # max_connections. ~2 per thread since a request can hold two connections at once (handler + _store_log).
    connections_per_process = 2 * SERVER_THREADS + 4

    base_engine = create_engine(
        config.DATABASE_CONNECTION_STRING,
        poolclass=QueuePool,
        pool_size=connections_per_process,
        # never open connections beyond pool_size
        max_overflow=0,
        # checks that the connections in the pool are alive before using them, which avoids the "server closed the
        # connection unexpectedly" errors
        pool_pre_ping=True,
    )
    register_perf_listeners(base_engine)
    return base_engine
@contextmanager
def session_scope() -> Generator[Session]:
    """
    Yields a database session wrapped in a transaction on the shared engine.

    The transaction is committed when the `with` body finishes normally. If the body (or the commit) raises, the
    transaction is rolled back and the exception is re-raised.

    When DEBUG logging is enabled, the postgres backend pid and the location that opened the scope are logged and
    attached to the "session_scope" tracing span, along with the current thread and process ids.
    """
    with tracer.start_as_current_span("session_scope") as rollspan:
        with Session(_get_base_engine()) as session:
            session.begin()
            # Bound up front so that the `finally` log line can never hit an unbound name, e.g. if the backend pid
            # query below fails or the log level is switched to DEBUG while the scope is open. Without this, an
            # UnboundLocalError raised in `finally` would mask the real exception.
            backend_pid = None
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        # presumably index 2 skips this generator and the contextlib wrapper to reach the caller
                        frame = inspect.stack()[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        # best effort only: never fail the request because the caller could not be determined
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar_one()
                    logger.debug(f"SScope: got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            except BaseException:
                # roll back on anything, including GeneratorExit/KeyboardInterrupt, then propagate unchanged
                session.rollback()
                raise
            finally:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"SScope: closed {backend_pid=}")
@contextmanager
def worker_repeatable_read_session_scope() -> Generator[Session]:
    """
    This is a separate session scope that is isolated from the main one since otherwise we end up nesting transactions,
    this causes two different connections to be used

    This operates in a `REPEATABLE READ` isolation level so that we can do a `SELECT ... FOR UPDATE SKIP LOCKED` in the
    background worker, effectively using postgres as a queueing system.

    The transaction is committed when the `with` body finishes normally. If the body (or the commit) raises, the
    transaction is rolled back and the exception is re-raised.
    """
    with tracer.start_as_current_span("worker_session_scope") as rollspan:
        with Session(_get_base_engine().execution_options(isolation_level="REPEATABLE READ")) as session:
            session.begin()
            # Bound up front so that the `finally` log line can never hit an unbound name, e.g. if the backend pid
            # query below fails or the log level is switched to DEBUG while the scope is open. Without this, an
            # UnboundLocalError raised in `finally` would mask the real exception.
            backend_pid = None
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        # presumably index 2 skips this generator and the contextlib wrapper to reach the caller
                        frame = inspect.stack()[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        # best effort only: never fail the job because the caller could not be determined
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar_one()
                    logger.debug(f"SScope (worker): got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            except BaseException:
                # roll back on anything, including GeneratorExit/KeyboardInterrupt, then propagate unchanged
                session.rollback()
                raise
            finally:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"SScope (worker): closed {backend_pid=}")
def db_post_fork() -> None:
    """
    Fix post-fork issues with sqlalchemy

    Disposes the cached engine's connection pool without closing the connections it currently holds.
    """
    # see https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
    engine = _get_base_engine()
    engine.dispose(close=False)
def are_friends(session: Session, context: CouchersContext, other_user: int) -> bool:
    """
    Returns True if there is an accepted friend relationship, in either direction, between the user of the given
    context and `other_user`, with both ends of the relationship passing the user visibility filter.
    """
    statement = where_users_column_visible(select(FriendRelationship), context, FriendRelationship.from_user_id)
    statement = where_users_column_visible(statement, context, FriendRelationship.to_user_id)

    me = context.user_id
    sent_by_me = and_(FriendRelationship.from_user_id == me, FriendRelationship.to_user_id == other_user)
    sent_to_me = and_(FriendRelationship.from_user_id == other_user, FriendRelationship.to_user_id == me)

    statement = statement.where(or_(sent_by_me, sent_to_me)).where(FriendRelationship.status == FriendStatus.accepted)

    relationship = session.execute(statement).scalar_one_or_none()
    return relationship is not None
def get_parent_node_at_location(session: Session, shape: WKBElement) -> Node | None:
    """
    Finds the smallest node containing the shape, or None if no node contains it.

    Shape can be any PostGIS geo object, e.g., output from create_coordinate
    """
    # By construction of nodes, the area of a sub-node must always be less than its parent Node, so the lowest Node in
    # the tree that contains the shape is simply the containing Node with the smallest area: no tree traversal needed.
    smallest_containing_node = (
        select(Node).where(func.ST_Contains(Node.geom, shape)).order_by(func.ST_Area(Node.geom)).limit(1)
    )
    result = session.execute(smallest_containing_node)
    return result.scalars().one_or_none()
def _get_node_parents_recursive_cte_subquery(node_id: int) -> Subquery:
    """
    Builds a subquery listing the given node and all of its ancestors.

    Each row has the columns `id`, `parent_node_id` and `level`, where level is 0 for the given node and grows by one
    for every step up to a parent.
    """
    # anchor member: the starting node itself
    start = select(Node.id, Node.parent_node_id, literal(0).label("level")).where(Node.id == node_id)
    parents = start.cte("parents", recursive=True)

    # recursive member: the parent of every node found so far, one level further up
    one_level_up = select(Node.id, Node.parent_node_id, (parents.c.level + 1).label("level")).join(
        parents, Node.id == parents.c.parent_node_id
    )

    return select(parents.union(one_level_up)).subquery()
def get_node_parents_recursively(session: Session, node_id: int) -> Sequence[Row[tuple[int, int, int, Cluster]]]:
    """
    Returns the given node and its ancestors, each joined with the official cluster attached to that node.

    Rows consist of the node id, its parent node id, its level and the Cluster, ordered by descending level.
    """
    ancestors = _get_node_parents_recursive_cte_subquery(node_id)
    statement = (
        select(ancestors, Cluster)
        .join(Cluster, Cluster.parent_node_id == ancestors.c.id)
        .where(Cluster.is_official_cluster)
        .order_by(ancestors.c.level.desc())
    )
    return session.execute(statement).all()
def _can_moderate_any_cluster(session: Session, user_id: int, cluster_ids: list[int]) -> bool:
    """
    Returns True if the user has an admin subscription to at least one of the given clusters.
    """
    admin_subscription_exists = (
        select(true())
        .select_from(ClusterSubscription)
        .where(
            ClusterSubscription.role == ClusterRole.admin,
            ClusterSubscription.user_id == user_id,
            ClusterSubscription.cluster_id.in_(cluster_ids),
        )
        .exists()
    )
    return session.execute(select(admin_subscription_exists)).scalar_one()
def can_moderate_node(session: Session, user_id: int, node_id: int) -> bool:
    """
    Returns True if the user_id can moderate the given node (i.e., if they are admin of any community that is a parent of the node)
    """
    # the node itself (level 0) plus all of its ancestors
    node_and_parents = _get_node_parents_recursive_cte_subquery(node_id)

    admin_of_some_parent = (
        select(true())
        .select_from(ClusterSubscription)
        .join(Cluster, Cluster.id == ClusterSubscription.cluster_id)
        .where(
            ClusterSubscription.role == ClusterRole.admin,
            ClusterSubscription.user_id == user_id,
            Cluster.is_official_cluster,
            Cluster.parent_node_id == node_and_parents.c.id,
        )
        .exists()
    )
    return session.execute(select(admin_of_some_parent)).scalar_one()
def can_moderate_at(session: Session, user_id: int, shape: Geom) -> bool:
    """
    Returns True if the user_id can moderate a given geo-shape (i.e., if the shape is contained in any Node that the user is an admin of)
    """
    # only official clusters link a subscription to the Node whose geometry is tested
    official_cluster_node = and_(Cluster.is_official_cluster, Node.id == Cluster.parent_node_id)

    admin_of_containing_node = (
        select(true())
        .select_from(ClusterSubscription)
        .join(Cluster, Cluster.id == ClusterSubscription.cluster_id)
        .join(Node, official_cluster_node)
        .where(
            ClusterSubscription.role == ClusterRole.admin,
            ClusterSubscription.user_id == user_id,
            func.ST_Contains(Node.geom, shape),
        )
        .exists()
    )
    return session.execute(select(admin_of_containing_node)).scalar_one()
def is_user_in_node_geography(session: Session, user_id: int, node_id: int) -> bool:
    """
    Returns True if the user's location is geographically contained within the node's boundary.
    This is used to check if a user can leave a community - users cannot leave communities
    that contain their geographic location.
    """
    user_inside_node = (
        select(true())
        .select_from(User)
        .join(Node, func.ST_Contains(Node.geom, User.geom))
        .where(User.id == user_id, Node.id == node_id)
        .exists()
    )
    return session.execute(select(user_inside_node)).scalar_one()
def timezone_at_coordinate(session: Session, geom: WKBElement) -> str | None:
    """
    Returns the tzid of the timezone area that contains the given geometry, or None if no area contains it.
    """
    containing_area_tzid = select(TimezoneArea.tzid).where(func.ST_Contains(TimezoneArea.geom, geom))
    return session.execute(containing_area_tzid).scalar_one_or_none()