Coverage for app / backend / src / couchers / db.py: 77%
119 statements
« prev ^ index » next — coverage.py v7.13.5, created at 2026-05-05 09:44 +0000
1import functools
2import inspect
3import logging
4import os
5from collections.abc import Generator, Sequence
6from contextlib import contextmanager
7from os import getpid
8from threading import get_ident
10from alembic import command
11from alembic.config import Config
12from geoalchemy2 import WKBElement
13from opentelemetry import trace
14from sqlalchemy import Engine, Row, Subquery, create_engine, select, text, true
15from sqlalchemy.dialects import registry
16from sqlalchemy.orm.session import Session
17from sqlalchemy.pool import QueuePool
18from sqlalchemy.sql import and_, func, literal, or_
20from couchers.config import config
21from couchers.constants import SERVER_THREADS, WORKER_THREADS
22from couchers.context import CouchersContext
23from couchers.models import (
24 Cluster,
25 ClusterRole,
26 ClusterSubscription,
27 FriendRelationship,
28 FriendStatus,
29 Geom,
30 Node,
31 TimezoneArea,
32 User,
33)
34from couchers.sql import where_users_column_visible
# Register psycopg (psycopg3) as the default driver for postgresql:// URLs
# This must happen before any engine is created
registry.register("postgresql", "sqlalchemy.dialects.postgresql.psycopg", "PGDialect_psycopg")
# module-level logger and OpenTelemetry tracer, both named after this module
logger = logging.getLogger(__name__)
tracer = trace.get_tracer(__name__)
def apply_migrations() -> None:
    """Run all pending alembic migrations up to "head".

    Alembic expects to be invoked from the directory containing alembic.ini,
    so we temporarily chdir two levels up from this module and restore the
    previous working directory afterwards.
    """
    previous_cwd = os.getcwd()
    os.chdir(os.path.dirname(__file__) + "/../..")
    try:
        alembic_cfg = Config("alembic.ini")
        # alembic screws up logging config by default, this tells it not to screw it up if being run at startup like this
        alembic_cfg.set_main_option("dont_mess_up_logging", "False")
        command.upgrade(alembic_cfg, "head")
    finally:
        os.chdir(previous_cwd)
@functools.cache
def _get_base_engine() -> Engine:
    """Create (once, via functools.cache) and return the process-wide engine.

    The pool is sized for one connection per server/worker thread, with some
    slack for anything else that grabs a session.
    """
    pool_capacity = SERVER_THREADS + WORKER_THREADS + 12
    return create_engine(
        config["DATABASE_CONNECTION_STRING"],
        # verify pooled connections are alive before handing them out, which avoids
        # the "server closed the connection unexpectedly" errors
        pool_pre_ping=True,
        # one connection per thread
        poolclass=QueuePool,
        pool_size=pool_capacity,
    )
@contextmanager
def session_scope() -> Generator[Session]:
    """Context manager yielding a Session wrapped in a single transaction.

    Commits on clean exit, rolls back (and re-raises) on any exception. Also
    attaches tracing attributes and, at DEBUG log level, logs the postgres
    backend pid serving this session together with the caller's location.
    """
    with tracer.start_as_current_span("session_scope") as rollspan:
        with Session(_get_base_engine()) as session:
            session.begin()
            # initialized up-front so the debug log in the finally clause can't
            # raise NameError if the debug block below fails before assignment
            # (or if DEBUG logging gets enabled mid-scope)
            backend_pid = None
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        # stack()[2]: skip this frame and the contextmanager
                        # machinery to find the caller of session_scope()
                        frame = inspect.stack()[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar_one()
                    logger.debug(f"SScope: got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())
                yield session
                session.commit()
            except:
                # deliberately bare: roll back on *anything*, including
                # KeyboardInterrupt/GeneratorExit, then re-raise
                session.rollback()
                raise
            finally:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"SScope: closed {backend_pid=}")
@contextmanager
def worker_repeatable_read_session_scope() -> Generator[Session]:
    """
    This is a separate session scope that is isolated from the main one since otherwise we end up nesting transactions,
    this causes two different connections to be used

    This operates in a `REPEATABLE READ` isolation level so that we can do a `SELECT ... FOR UPDATE SKIP LOCKED` in the
    background worker, effectively using postgres as a queueing system.
    """
    with tracer.start_as_current_span("worker_session_scope") as rollspan:
        with Session(_get_base_engine().execution_options(isolation_level="REPEATABLE READ")) as session:
            session.begin()
            # initialized up-front so the debug log in the finally clause can't
            # raise NameError if the debug block below fails before assignment
            # (or if DEBUG logging gets enabled mid-scope)
            backend_pid = None
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        # stack()[2]: skip this frame and the contextmanager
                        # machinery to find the actual caller
                        frame = inspect.stack()[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar_one()
                    logger.debug(f"SScope (worker): got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())
                yield session
                session.commit()
            except:
                # deliberately bare: roll back on *anything*, then re-raise
                session.rollback()
                raise
            finally:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"SScope (worker): closed {backend_pid=}")
def db_post_fork() -> None:
    """Reset the cached engine's connection pool after an os.fork().

    A forked child must not reuse pooled connections inherited from its parent;
    see https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
    """
    # close=False discards the child's references to the pooled connections
    # without closing the underlying sockets still in use by the parent
    _get_base_engine().dispose(close=False)
def are_friends(session: Session, context: CouchersContext, other_user: int) -> bool:
    """Return True if context.user_id and other_user are accepted friends.

    Both users in the relationship must be visible to the requesting user
    (via where_users_column_visible) for the friendship to count.
    """
    requester_initiated = and_(
        FriendRelationship.from_user_id == context.user_id,
        FriendRelationship.to_user_id == other_user,
    )
    other_initiated = and_(
        FriendRelationship.from_user_id == other_user,
        FriendRelationship.to_user_id == context.user_id,
    )
    query = select(FriendRelationship)
    for column in (FriendRelationship.from_user_id, FriendRelationship.to_user_id):
        query = where_users_column_visible(query, context, column)
    query = query.where(or_(requester_initiated, other_initiated)).where(
        FriendRelationship.status == FriendStatus.accepted
    )
    return session.execute(query).scalar_one_or_none() is not None
def get_parent_node_at_location(session: Session, shape: WKBElement) -> Node | None:
    """Return the smallest Node whose geometry contains the shape, or None.

    Shape can be any PostGIS geo object, e.g., output from create_coordinate.
    By construction of nodes, a sub-node's area is always smaller than its
    parent's, so the lowest containing node in the tree is simply the
    containing node with minimal area — no tree traversal needed.
    """
    query = (
        select(Node)
        .where(func.ST_Contains(Node.geom, shape))
        .order_by(func.ST_Area(Node.geom))
        .limit(1)
    )
    return session.execute(query).scalars().one_or_none()
def _get_node_parents_recursive_cte_subquery(node_id: int) -> Subquery:
    """Build a subquery of (id, parent_node_id, level) rows for node_id and all
    of its ancestors via a recursive CTE.

    Level 0 is the node itself; level increases by one per step towards the root.
    """
    # anchor: the node itself
    ancestry = (
        select(Node.id, Node.parent_node_id, literal(0).label("level"))
        .where(Node.id == node_id)
        .cte("parents", recursive=True)
    )
    # recursive step: each iteration adds the parent of the previous row
    step_up = select(Node.id, Node.parent_node_id, (ancestry.c.level + 1).label("level")).join(
        ancestry, Node.id == ancestry.c.parent_node_id
    )
    return select(ancestry.union(step_up)).subquery()
def get_node_parents_recursively(session: Session, node_id: int) -> Sequence[Row[tuple[int, int, int, Cluster]]]:
    """Return (id, parent_node_id, level, Cluster) rows for the node and each
    ancestor's official cluster, ordered from the root down (deepest level first)."""
    ancestry = _get_node_parents_recursive_cte_subquery(node_id)
    query = (
        select(ancestry, Cluster)
        .join(Cluster, Cluster.parent_node_id == ancestry.c.id)
        .where(Cluster.is_official_cluster)
        .order_by(ancestry.c.level.desc())
    )
    return session.execute(query).all()
def _can_moderate_any_cluster(session: Session, user_id: int, cluster_ids: list[int]) -> bool:
    """Return True if user_id is an admin of at least one cluster in cluster_ids."""
    admin_membership = (
        select(true())
        .select_from(ClusterSubscription)
        .where(ClusterSubscription.role == ClusterRole.admin)
        .where(ClusterSubscription.user_id == user_id)
        .where(ClusterSubscription.cluster_id.in_(cluster_ids))
    )
    # EXISTS(...) collapses the whole thing to a single boolean on the server
    return session.execute(select(admin_membership.exists())).scalar_one()
def can_moderate_node(session: Session, user_id: int, node_id: int) -> bool:
    """
    Returns True if the user_id can moderate the given node (i.e., if they are admin of any community that is a parent of the node)
    """
    ancestry = _get_node_parents_recursive_cte_subquery(node_id)
    admin_of_ancestor = (
        select(true())
        .select_from(ClusterSubscription)
        .where(ClusterSubscription.role == ClusterRole.admin)
        .where(ClusterSubscription.user_id == user_id)
        .join(Cluster, Cluster.id == ClusterSubscription.cluster_id)
        .where(Cluster.is_official_cluster)
        .where(Cluster.parent_node_id == ancestry.c.id)
    )
    return session.execute(select(admin_of_ancestor.exists())).scalar_one()
def can_moderate_at(session: Session, user_id: int, shape: Geom) -> bool:
    """
    Returns True if the user_id can moderate a given geo-shape (i.e., if the shape is contained in any Node that the user is an admin of)
    """
    admin_of_containing_node = (
        select(true())
        .select_from(ClusterSubscription)
        .where(ClusterSubscription.role == ClusterRole.admin)
        .where(ClusterSubscription.user_id == user_id)
        .join(Cluster, Cluster.id == ClusterSubscription.cluster_id)
        .join(Node, and_(Cluster.is_official_cluster, Node.id == Cluster.parent_node_id))
        .where(func.ST_Contains(Node.geom, shape))
    )
    return session.execute(select(admin_of_containing_node.exists())).scalar_one()
def is_user_in_node_geography(session: Session, user_id: int, node_id: int) -> bool:
    """
    Returns True if the user's location is geographically contained within the node's boundary.
    This is used to check if a user can leave a community - users cannot leave communities
    that contain their geographic location.
    """
    user_inside_node = (
        select(true())
        .select_from(User)
        .join(Node, func.ST_Contains(Node.geom, User.geom))
        .where(User.id == user_id)
        .where(Node.id == node_id)
    )
    return session.execute(select(user_inside_node.exists())).scalar_one()
def timezone_at_coordinate(session: Session, geom: WKBElement) -> str | None:
    """Return the tzid of the timezone area containing geom, or None if no
    timezone area contains it."""
    query = select(TimezoneArea.tzid).where(func.ST_Contains(TimezoneArea.geom, geom))
    return session.execute(query).scalar_one_or_none()