Coverage for app/backend/src/couchers/db.py: 77% (117 statements)
coverage.py v7.13.2, created at 2026-02-03 06:18 +0000

import functools
import inspect
import logging
import os
from collections.abc import Generator, Sequence
from contextlib import contextmanager
from os import getpid
from threading import get_ident

from alembic import command
from alembic.config import Config
from geoalchemy2 import WKBElement
from opentelemetry import trace
from sqlalchemy import Engine, Row, Subquery, create_engine, select, text, true
from sqlalchemy.orm.session import Session
from sqlalchemy.pool import QueuePool
from sqlalchemy.sql import and_, func, literal, or_

from couchers.config import config
from couchers.constants import SERVER_THREADS, WORKER_THREADS
from couchers.context import CouchersContext
from couchers.models import (
    Cluster,
    ClusterRole,
    ClusterSubscription,
    FriendRelationship,
    FriendStatus,
    Geom,
    Node,
    TimezoneArea,
    User,
)
from couchers.sql import where_users_column_visible

logger = logging.getLogger(__name__)

tracer = trace.get_tracer(__name__)


def apply_migrations() -> None:
    """
    Upgrades the database schema to the latest alembic revision ("head").
    """
    alembic_dir = os.path.dirname(__file__) + "/../.."
    cwd = os.getcwd()
    try:
        os.chdir(alembic_dir)
        alembic_cfg = Config("alembic.ini")
        # alembic messes up the logging config by default; this tells it not to when it's run at startup like this
        alembic_cfg.set_main_option("dont_mess_up_logging", "False")
        command.upgrade(alembic_cfg, "head")
    finally:
        os.chdir(cwd)


@functools.cache
def _get_base_engine() -> Engine:
    return create_engine(
        config["DATABASE_CONNECTION_STRING"],
        # checks that the connections in the pool are alive before using them, which avoids the "server closed the
        # connection unexpectedly" errors
        pool_pre_ping=True,
        # one connection per thread
        poolclass=QueuePool,
        # main threads + a few extra in case
        pool_size=SERVER_THREADS + WORKER_THREADS + 12,
    )


@contextmanager
def session_scope() -> Generator[Session]:
    """
    Provides a session wrapped in a transaction: commits when the block exits cleanly, rolls back if it raises.
    """
    with tracer.start_as_current_span("session_scope") as rollspan:
        with Session(_get_base_engine()) as session:
            session.begin()
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        frame = inspect.stack()[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar_one()
                    logger.debug(f"SScope: got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            except:
                session.rollback()
                raise
            finally:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"SScope: closed {backend_pid=}")
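
# A minimal usage sketch (hypothetical caller, not part of this module): the transaction commits
# when the block exits cleanly and rolls back if it raises:
#
#   with session_scope() as session:
#       user = session.execute(select(User).where(User.id == user_id)).scalar_one_or_none()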


@contextmanager
def worker_repeatable_read_session_scope() -> Generator[Session]:
    """
    This is a separate session scope that is isolated from the main one, since otherwise we'd end up nesting
    transactions, which causes two different connections to be used.

    This operates at the `REPEATABLE READ` isolation level so that we can do a `SELECT ... FOR UPDATE SKIP LOCKED` in
    the background worker, effectively using postgres as a queueing system.
    """
    with tracer.start_as_current_span("worker_session_scope") as rollspan:
        with Session(_get_base_engine().execution_options(isolation_level="REPEATABLE READ")) as session:
            session.begin()
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        frame = inspect.stack()[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar_one()
                    logger.debug(f"SScope (worker): got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            except:
                session.rollback()
                raise
            finally:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"SScope (worker): closed {backend_pid=}")
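
# A sketch of the queueing pattern the docstring describes (assuming a hypothetical `BackgroundJob`
# model): `SKIP LOCKED` lets each worker claim rows that no other worker currently holds a lock on,
# so several workers can drain the queue concurrently without double-processing a job:
#
#   with worker_repeatable_read_session_scope() as session:
#       job = (
#           session.execute(select(BackgroundJob).with_for_update(skip_locked=True).limit(1))
#           .scalars()
#           .first()
#       )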


def db_post_fork() -> None:
    """
    Fix post-fork issues with sqlalchemy
    """
    # see https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
    # dispose(close=False) discards the child's inherited pool without closing the underlying connections (which the
    # parent process is still using); the child then opens fresh connections of its own
    _get_base_engine().dispose(close=False)
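
# Usage sketch (hypothetical worker bootstrap, not part of this module): call this first thing in a
# freshly forked child process, before it issues any queries:
#
#   def worker_main() -> None:
#       db_post_fork()
#       ...  # start serving / processing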


def are_friends(session: Session, context: CouchersContext, other_user: int) -> bool:
    """
    Returns True if the user in `context` and `other_user` have an accepted friend relationship, with both ends of
    the relationship visible to the requesting user.
    """
    query = select(FriendRelationship)
    query = where_users_column_visible(query, context, FriendRelationship.from_user_id)
    query = where_users_column_visible(query, context, FriendRelationship.to_user_id)
    query = query.where(
        or_(
            and_(FriendRelationship.from_user_id == context.user_id, FriendRelationship.to_user_id == other_user),
            and_(FriendRelationship.from_user_id == other_user, FriendRelationship.to_user_id == context.user_id),
        )
    ).where(FriendRelationship.status == FriendStatus.accepted)
    return session.execute(query).scalar_one_or_none() is not None


def get_parent_node_at_location(session: Session, shape: WKBElement) -> Node | None:
    """
    Finds the smallest node containing the shape.

    Shape can be any PostGIS geo object, e.g., output from create_coordinate
    """
    # Find the lowest Node (in the Node tree) that contains the shape. By construction of nodes, the area of a
    # sub-node must always be less than its parent Node, so no need to actually traverse the tree!
    return (
        session.execute(
            select(Node).where(func.ST_Contains(Node.geom, shape)).order_by(func.ST_Area(Node.geom)).limit(1)
        )
        .scalars()
        .one_or_none()
    )
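
# Usage sketch, assuming a helper like `create_coordinate` (referenced in the docstring above) that
# builds a PostGIS point from a lat/lng pair:
#
#   node = get_parent_node_at_location(session, create_coordinate(lat, lng))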


def _get_node_parents_recursive_cte_subquery(node_id: int) -> Subquery:
    """
    Builds a recursive CTE that walks up the node tree from the given node, yielding each ancestor's id,
    parent_node_id, and distance ("level") from the starting node.
    """
    parents = (
        select(Node.id, Node.parent_node_id, literal(0).label("level"))
        .where(Node.id == node_id)
        .cte("parents", recursive=True)
    )

    return select(
        parents.union(
            select(Node.id, Node.parent_node_id, (parents.c.level + 1).label("level")).join(
                parents, Node.id == parents.c.parent_node_id
            )
        )
    ).subquery()
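
# Roughly the SQL the subquery above wraps (a sketch; assumes the Node model maps to a `nodes` table):
#
#   WITH RECURSIVE parents (id, parent_node_id, level) AS (
#       SELECT id, parent_node_id, 0 FROM nodes WHERE id = :node_id
#     UNION
#       SELECT nodes.id, nodes.parent_node_id, parents.level + 1
#       FROM nodes JOIN parents ON nodes.id = parents.parent_node_id
#   )
#   SELECT * FROM parents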


def get_node_parents_recursively(session: Session, node_id: int) -> Sequence[Row[tuple[int, int, int, Cluster]]]:
    """
    Returns the ancestry chain of the given node (including the node itself) together with each node's official
    cluster, ordered from the root of the tree down to the node.
    """
    subquery = _get_node_parents_recursive_cte_subquery(node_id)
    return session.execute(
        select(subquery, Cluster)
        .join(Cluster, Cluster.parent_node_id == subquery.c.id)
        .where(Cluster.is_official_cluster)
        .order_by(subquery.c.level.desc())
    ).all()


def _can_moderate_any_cluster(session: Session, user_id: int, cluster_ids: list[int]) -> bool:
    """
    Returns True if the user is an admin of at least one of the given clusters.
    """
    query = select(
        (
            select(true())
            .select_from(ClusterSubscription)
            .where(ClusterSubscription.role == ClusterRole.admin)
            .where(ClusterSubscription.user_id == user_id)
            .where(ClusterSubscription.cluster_id.in_(cluster_ids))
        ).exists()
    )
    return session.execute(query).scalar_one()
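
# The SELECT EXISTS pattern used here and in the functions below renders to SQL roughly like this
# sketch (assuming ClusterSubscription maps to a `cluster_subscriptions` table):
#
#   SELECT EXISTS (
#       SELECT true FROM cluster_subscriptions
#       WHERE role = 'admin' AND user_id = :user_id AND cluster_id IN (...)
#   )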


def can_moderate_node(session: Session, user_id: int, node_id: int) -> bool:
    """
    Returns True if the user_id can moderate the given node (i.e., if they are admin of any community that is a
    parent of the node)
    """
    subquery = _get_node_parents_recursive_cte_subquery(node_id)
    query = select(
        (
            select(true())
            .select_from(ClusterSubscription)
            .where(ClusterSubscription.role == ClusterRole.admin)
            .where(ClusterSubscription.user_id == user_id)
            .join(Cluster, Cluster.id == ClusterSubscription.cluster_id)
            .where(Cluster.is_official_cluster)
            .where(Cluster.parent_node_id == subquery.c.id)
        ).exists()
    )
    return session.execute(query).scalar_one()


def can_moderate_at(session: Session, user_id: int, shape: Geom) -> bool:
    """
    Returns True if the user_id can moderate a given geo-shape (i.e., if the shape is contained in any Node that the
    user is an admin of)
    """
    query = select(
        (
            select(true())
            .select_from(ClusterSubscription)
            .where(ClusterSubscription.role == ClusterRole.admin)
            .where(ClusterSubscription.user_id == user_id)
            .join(Cluster, Cluster.id == ClusterSubscription.cluster_id)
            .join(Node, and_(Cluster.is_official_cluster, Node.id == Cluster.parent_node_id))
            .where(func.ST_Contains(Node.geom, shape))
        ).exists()
    )
    return session.execute(query).scalar_one()


def is_user_in_node_geography(session: Session, user_id: int, node_id: int) -> bool:
    """
    Returns True if the user's location is geographically contained within the node's boundary.

    This is used to check whether a user can leave a community: users cannot leave communities that contain their
    geographic location.
    """
    query = select(
        (
            select(true())
            .select_from(User)
            .join(Node, func.ST_Contains(Node.geom, User.geom))
            .where(User.id == user_id)
            .where(Node.id == node_id)
        ).exists()
    )
    return session.execute(query).scalar_one()


def timezone_at_coordinate(session: Session, geom: WKBElement) -> str | None:
    """
    Returns the tzid (e.g. "Europe/Zurich") of the timezone area containing the given point, or None if there isn't
    one.
    """
    return session.execute(
        select(TimezoneArea.tzid).where(func.ST_Contains(TimezoneArea.geom, geom))
    ).scalar_one_or_none()