Coverage for src/couchers/db.py: 77% (116 statements)
import functools
import inspect
import logging
import os
from collections.abc import Generator, Sequence
from contextlib import contextmanager
from os import getpid
from threading import get_ident
from typing import cast

from alembic import command
from alembic.config import Config
from geoalchemy2 import WKBElement
from opentelemetry import trace
from sqlalchemy import Engine, Row, Subquery, create_engine, text
from sqlalchemy.orm.session import Session
from sqlalchemy.pool import QueuePool
from sqlalchemy.sql import and_, func, literal, or_

from couchers.config import config
from couchers.constants import SERVER_THREADS, WORKER_THREADS
from couchers.context import CouchersContext
from couchers.models import (
    Cluster,
    ClusterRole,
    ClusterSubscription,
    FriendRelationship,
    FriendStatus,
    Node,
    TimezoneArea,
    User,
)
from couchers.sql import couchers_select as select

logger = logging.getLogger(__name__)

tracer = trace.get_tracer(__name__)


def apply_migrations() -> None:
    alembic_dir = os.path.dirname(__file__) + "/../.."
    cwd = os.getcwd()
    try:
        os.chdir(alembic_dir)
        alembic_cfg = Config("alembic.ini")
        # alembic replaces the logging config by default; this tells it not to when migrations are run at startup
        # like this
        alembic_cfg.set_main_option("dont_mess_up_logging", "False")
        command.upgrade(alembic_cfg, "head")
    finally:
        os.chdir(cwd)
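
# Note (not part of the original module): the chdir dance above exists because Config("alembic.ini") is resolved
# against the current working directory, and the ini file lives at the repository root, two levels up from this
# file. apply_migrations() is intended to run once at process startup.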


@functools.cache
def _get_base_engine() -> Engine:
    return create_engine(
        config["DATABASE_CONNECTION_STRING"],
        # check that connections in the pool are alive before using them, which avoids "server closed the
        # connection unexpectedly" errors
        pool_pre_ping=True,
        # one connection per thread
        poolclass=QueuePool,
        # main threads plus a few spare connections just in case
        pool_size=SERVER_THREADS + WORKER_THREADS + 12,
    )


@contextmanager
def session_scope() -> Generator[Session]:
    with tracer.start_as_current_span("session_scope") as rollspan:
        with Session(_get_base_engine()) as session:
            session.begin()
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        frame = inspect.stack()[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar_one()
                    logger.debug(f"SScope: got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            except:
                session.rollback()
                raise
            finally:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"SScope: closed {backend_pid=}")
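
# Illustrative usage sketch (not part of the original module): how calling code typically consumes session_scope().
# The scope commits when the block exits cleanly and rolls back if it raises. `_example_count_nodes` is hypothetical.
def _example_count_nodes() -> int:
    with session_scope() as session:
        return cast(int, session.execute(select(func.count()).select_from(Node)).scalar_one())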


@contextmanager
def worker_repeatable_read_session_scope() -> Generator[Session]:
    """
    This is a separate session scope, isolated from the main one: otherwise we end up nesting transactions. Using a
    separate scope causes two different connections to be used.

    This operates at the `REPEATABLE READ` isolation level so that we can do a `SELECT ... FOR UPDATE SKIP LOCKED` in
    the background worker, effectively using postgres as a queueing system.
    """
    with tracer.start_as_current_span("worker_session_scope") as rollspan:
        with Session(_get_base_engine().execution_options(isolation_level="REPEATABLE READ")) as session:
            session.begin()
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        frame = inspect.stack()[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar_one()
                    logger.debug(f"SScope (worker): got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            except:
                session.rollback()
                raise
            finally:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"SScope (worker): closed {backend_pid=}")
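
# Illustrative sketch (not part of the original module) of the queueing pattern the docstring refers to. `Job` and
# `Job.is_ready` are hypothetical placeholders; `with_for_update(skip_locked=True)` is what emits
# `SELECT ... FOR UPDATE SKIP LOCKED`:
#
#     with worker_repeatable_read_session_scope() as session:
#         job = session.execute(
#             select(Job).where(Job.is_ready).with_for_update(skip_locked=True).limit(1)
#         ).scalar_one_or_none()
#         # ...process `job`; the locked row is skipped by other workers until this scope commits or rolls back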


def db_post_fork() -> None:
    """
    Fix post-fork issues with sqlalchemy
    """
    # see https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
    _get_base_engine().dispose(close=False)
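
# Illustrative sketch (not part of the original module): db_post_fork() is meant to be called once in each child
# process right after forking, so the child discards connections pooled in the parent instead of reusing their
# sockets. `_example_child_main` is hypothetical.
def _example_child_main() -> None:
    db_post_fork()
    with session_scope() as session:
        session.execute(text("SELECT 1"))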


def are_friends(session: Session, context: CouchersContext, other_user: int) -> bool:
    return (
        session.execute(
            select(FriendRelationship)
            .where_users_column_visible(context, FriendRelationship.from_user_id)
            .where_users_column_visible(context, FriendRelationship.to_user_id)
            .where(
                or_(
                    and_(
                        FriendRelationship.from_user_id == context.user_id, FriendRelationship.to_user_id == other_user
                    ),
                    and_(
                        FriendRelationship.from_user_id == other_user, FriendRelationship.to_user_id == context.user_id
                    ),
                )
            )
            .where(FriendRelationship.status == FriendStatus.accepted)
        ).scalar_one_or_none()
        is not None
    )


def get_parent_node_at_location(session: Session, shape: WKBElement) -> Node | None:
    """
    Finds the smallest node containing the shape.

    Shape can be any PostGIS geo object, e.g. output from create_coordinate
    """

    # Find the lowest Node (in the Node tree) that contains the shape. By construction of nodes, the area of a sub-node
    # is always less than that of its parent Node, so there's no need to actually traverse the tree!
    return (
        session.execute(
            select(Node).where(func.ST_Contains(Node.geom, shape)).order_by(func.ST_Area(Node.geom)).limit(1)
        )
        .scalars()
        .one_or_none()
    )
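
# Illustrative usage sketch (not part of the original module): looking up the community node for a point. The
# geometry is built with shapely + geoalchemy2 directly here (assuming shapely is available); per the docstring
# above, create_coordinate is the usual helper in this codebase. `_example_node_id_at` is hypothetical.
def _example_node_id_at(lat: float, lng: float) -> int | None:
    from geoalchemy2.shape import from_shape
    from shapely.geometry import Point

    point = from_shape(Point(lng, lat), srid=4326)  # PostGIS points are (lng, lat); SRID 4326 is WGS 84
    with session_scope() as session:
        node = get_parent_node_at_location(session, point)
        return node.id if node else None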


def _get_node_parents_recursive_cte_subquery(node_id: int) -> Subquery:
    parents = (
        select(Node.id, Node.parent_node_id, literal(0).label("level"))
        .where(Node.id == node_id)
        .cte("parents", recursive=True)
    )

    return select(
        parents.union(
            select(Node.id, Node.parent_node_id, (parents.c.level + 1).label("level")).join(
                parents, Node.id == parents.c.parent_node_id
            )
        )
    ).subquery()
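
# Roughly the SQL the recursive CTE above expands to (illustrative; table and alias names are assumptions from the
# model definitions, not copied from real query output):
#
#     WITH RECURSIVE parents (id, parent_node_id, level) AS (
#         SELECT id, parent_node_id, 0 AS level FROM nodes WHERE id = :node_id
#         UNION
#         SELECT nodes.id, nodes.parent_node_id, parents.level + 1
#         FROM nodes JOIN parents ON nodes.id = parents.parent_node_id
#     )
#     SELECT * FROM parents
#
# i.e. it walks from the given node up to the root, labelling each ancestor with its distance ("level") from the
# starting node; get_node_parents_recursively below orders by level descending so the root comes first.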


def get_node_parents_recursively(session: Session, node_id: int) -> Sequence[Row[tuple[int, int, int, Cluster]]]:
    subquery = _get_node_parents_recursive_cte_subquery(node_id)
    return session.execute(
        select(subquery, Cluster)
        .join(Cluster, Cluster.parent_node_id == subquery.c.id)
        .where(Cluster.is_official_cluster)
        .order_by(subquery.c.level.desc())
    ).all()


def _can_moderate_any_cluster(session: Session, user_id: int, cluster_ids: list[int]) -> bool:
    query = select(
        (
            select(True)
            .select_from(ClusterSubscription)
            .where(ClusterSubscription.role == ClusterRole.admin)
            .where(ClusterSubscription.user_id == user_id)
            .where(ClusterSubscription.cluster_id.in_(cluster_ids))
        ).exists()
    )
    return cast(bool, session.execute(query).scalar_one())


def can_moderate_node(session: Session, user_id: int, node_id: int) -> bool:
    """
    Returns True if the user_id can moderate the given node (i.e., if they are an admin of any community that is a
    parent of the node)
    """
    subquery = _get_node_parents_recursive_cte_subquery(node_id)
    query = select(
        (
            select(True)
            .select_from(ClusterSubscription)
            .where(ClusterSubscription.role == ClusterRole.admin)
            .where(ClusterSubscription.user_id == user_id)
            .join(Cluster, Cluster.id == ClusterSubscription.cluster_id)
            .where(Cluster.is_official_cluster)
            .where(Cluster.parent_node_id == subquery.c.id)
        ).exists()
    )
    return cast(bool, session.execute(query).scalar_one())


def can_moderate_at(session: Session, user_id: int, shape: WKBElement) -> bool:
    """
    Returns True if the user_id can moderate a given geo-shape (i.e., if the shape is contained in any Node that the
    user is an admin of)
    """
    query = select(
        (
            select(True)
            .select_from(ClusterSubscription)
            .where(ClusterSubscription.role == ClusterRole.admin)
            .where(ClusterSubscription.user_id == user_id)
            .join(Cluster, Cluster.id == ClusterSubscription.cluster_id)
            .join(Node, and_(Cluster.is_official_cluster, Node.id == Cluster.parent_node_id))
            .where(func.ST_Contains(Node.geom, shape))
        ).exists()
    )
    return cast(bool, session.execute(query).scalar_one())


def is_user_in_node_geography(session: Session, user_id: int, node_id: int) -> bool:
    """
    Returns True if the user's location is geographically contained within the node's boundary.

    This is used to check whether a user can leave a community: users cannot leave communities that contain their
    geographic location.
    """
    query = select(
        (
            select(True)
            .select_from(User)
            .join(Node, func.ST_Contains(Node.geom, User.geom))
            .where(User.id == user_id)
            .where(Node.id == node_id)
        ).exists()
    )
    return cast(bool, session.execute(query).scalar_one())


def timezone_at_coordinate(session: Session, geom: WKBElement) -> str | None:
    # select(TimezoneArea.tzid) yields the tzid string itself (or None), so return it directly rather than via an
    # attribute lookup
    tzid = session.execute(
        select(TimezoneArea.tzid).where(func.ST_Contains(TimezoneArea.geom, geom))
    ).scalar_one_or_none()
    return cast(str | None, tzid)