Coverage for src/couchers/db.py: 75%
106 statements
Generated by coverage.py v7.6.10, created at 2025-07-09 00:05 +0000
1import functools
2import inspect
3import logging
4import os
5from contextlib import contextmanager
6from os import getpid
7from threading import get_ident
9from alembic import command
10from alembic.config import Config
11from opentelemetry import trace
12from sqlalchemy import create_engine, text
13from sqlalchemy.orm.session import Session
14from sqlalchemy.pool import QueuePool
15from sqlalchemy.sql import and_, func, literal, or_
17from couchers.config import config
18from couchers.constants import SERVER_THREADS, WORKER_THREADS
19from couchers.models import (
20 Cluster,
21 ClusterRole,
22 ClusterSubscription,
23 FriendRelationship,
24 FriendStatus,
25 Node,
26 TimezoneArea,
27)
28from couchers.sql import couchers_select as select
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# OpenTelemetry tracer for this module; the session scopes in this file open their spans on it.
tracer = trace.get_tracer(__name__)
def apply_migrations():
    """
    Upgrade the database schema to the newest alembic revision ("head").

    `alembic.ini` is looked up relative to a directory two levels above this module, so the process working
    directory is switched there for the duration of the upgrade and always restored afterwards, even on failure.
    """
    migrations_root = os.path.join(os.path.dirname(__file__), "..", "..")
    previous_cwd = os.getcwd()
    try:
        os.chdir(migrations_root)
        alembic_config = Config("alembic.ini")
        # Left to its defaults, alembic's setup clobbers the application's logging configuration; this option
        # (presumably checked by the project's alembic env script) asks it not to when run at startup like this.
        alembic_config.set_main_option("dont_mess_up_logging", "False")
        command.upgrade(alembic_config, "head")
    finally:
        os.chdir(previous_cwd)
@functools.cache
def _get_base_engine():
    """
    Create the SQLAlchemy engine used by every session scope in this module.

    Memoised with `functools.cache`, so all callers share a single engine (and therefore a single connection pool).
    """
    connection_string = config["DATABASE_CONNECTION_STRING"]
    # Enough pooled connections for the server and worker threads, plus 12 spare just in case.
    pool_capacity = SERVER_THREADS + WORKER_THREADS + 12
    return create_engine(
        connection_string,
        # Verify each pooled connection is alive before handing it out; avoids "server closed the connection
        # unexpectedly" errors on stale connections.
        pool_pre_ping=True,
        poolclass=QueuePool,
        pool_size=pool_capacity,
    )
@contextmanager
def session_scope():
    """
    Context manager yielding a SQLAlchemy `Session` that runs inside a single transaction.

    The transaction is committed when the `with` body completes normally. If the body raises anything, the
    transaction is rolled back and the exception is re-raised. The whole scope runs inside an OpenTelemetry span
    named "session_scope". When DEBUG logging is enabled, the postgres backend pid and the caller's file:line are
    logged and attached to that span.

    Usage:
        with session_scope() as session:
            ...
    """
    with tracer.start_as_current_span("session_scope") as rollspan:
        with Session(_get_base_engine()) as session:
            session.begin()
            # Read the debug flag once so the "got"/"closed" log lines always come in pairs. Pre-bind backend_pid so
            # the `finally` below can never raise UnboundLocalError (which would mask the original exception) when
            # the pg_backend_pid() query itself fails.
            debug = logger.isEnabledFor(logging.DEBUG)
            backend_pid = None
            try:
                if debug:
                    try:
                        # [0] is this generator, [1] is contextlib's __enter__, [2] is whoever entered the scope.
                        # context=0 avoids reading source lines for every frame on the stack.
                        frame = inspect.stack(context=0)[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar()
                    logger.debug(f"SScope: got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            except BaseException:
                # Deliberately broad (also covers KeyboardInterrupt etc.): never leave the transaction open.
                session.rollback()
                raise
            finally:
                if debug:
                    logger.debug(f"SScope: closed {backend_pid=}")
@contextmanager
def worker_repeatable_read_session_scope():
    """
    This is a separate session scope, isolated from the main one, since otherwise we would end up nesting
    transactions; using it alongside `session_scope` means two different connections are used.

    This operates in a `REPEATABLE READ` isolation level so that we can do a `SELECT ... FOR UPDATE SKIP LOCKED` in the
    background worker, effectively using postgres as a queueing system.

    The transaction is committed when the `with` body completes normally. If the body raises anything, the
    transaction is rolled back and the exception is re-raised. The scope runs inside an OpenTelemetry span named
    "worker_session_scope".
    """
    with tracer.start_as_current_span("worker_session_scope") as rollspan:
        # execution_options() returns an engine copy that shares the base engine's pool but applies the isolation
        # level to the connections it hands out.
        with Session(_get_base_engine().execution_options(isolation_level="REPEATABLE READ")) as session:
            session.begin()
            # Read the debug flag once so the "got"/"closed" log lines always come in pairs. Pre-bind backend_pid so
            # the `finally` below can never raise UnboundLocalError (which would mask the original exception) when
            # the pg_backend_pid() query itself fails.
            debug = logger.isEnabledFor(logging.DEBUG)
            backend_pid = None
            try:
                if debug:
                    try:
                        # [0] is this generator, [1] is contextlib's __enter__, [2] is whoever entered the scope.
                        # context=0 avoids reading source lines for every frame on the stack.
                        frame = inspect.stack(context=0)[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar()
                    logger.debug(f"SScope (worker): got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            except BaseException:
                # Deliberately broad (also covers KeyboardInterrupt etc.): never leave the transaction open.
                session.rollback()
                raise
            finally:
                if debug:
                    logger.debug(f"SScope (worker): closed {backend_pid=}")
def db_post_fork():
    """
    Make the shared SQLAlchemy engine safe to use in a freshly forked child process.

    `dispose(close=False)` swaps in a new connection pool without closing the connections inherited from the
    parent, following the SQLAlchemy guidance at the URL below.
    """
    # https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
    engine = _get_base_engine()
    engine.dispose(close=False)
def are_friends(session, context, other_user):
    """
    Return True if `context.user_id` and `other_user` share an accepted FriendRelationship, in either direction.

    `other_user` is compared against the user-id columns, so it is a user id. Both endpoints of the relationship
    are filtered through `where_users_column_visible(context, ...)` (presumably hiding users the viewer may not
    see). Because `scalar_one_or_none()` is used, more than one matching row would raise rather than return.
    """
    viewer_id = context.user_id
    either_direction = or_(
        and_(FriendRelationship.from_user_id == viewer_id, FriendRelationship.to_user_id == other_user),
        and_(FriendRelationship.from_user_id == other_user, FriendRelationship.to_user_id == viewer_id),
    )
    query = (
        select(FriendRelationship)
        .where_users_column_visible(context, FriendRelationship.from_user_id)
        .where_users_column_visible(context, FriendRelationship.to_user_id)
        .where(either_direction)
        .where(FriendRelationship.status == FriendStatus.accepted)
    )
    relationship = session.execute(query).scalar_one_or_none()
    return relationship is not None
def get_parent_node_at_location(session, shape):
    """
    Return the smallest Node whose geometry contains `shape`, or None if no Node contains it.

    `shape` can be any PostGIS geo object, e.g. the output of create_coordinate.
    """
    # A sub-node's area is always smaller than its parent's, so the containing node with the smallest area is the
    # deepest one in the tree; no explicit tree walk is required.
    smallest_container = (
        select(Node)
        .where(func.ST_Contains(Node.geom, shape))
        .order_by(func.ST_Area(Node.geom))
        .limit(1)
    )
    return session.execute(smallest_container).scalar_one_or_none()
def _get_node_parents_recursive_cte_subquery(session, node_id):
    """
    Build (without executing) a subquery listing `node_id` and all of its ancestor nodes.

    Columns are `id`, `parent_node_id` and `level`. Level is 0 for `node_id` itself and increases by one per step
    up the `parent_node_id` chain. `session` is not used here; it is kept for call compatibility.
    """
    seed = select(Node.id, Node.parent_node_id, literal(0).label("level")).where(Node.id == node_id)
    ancestors = seed.cte("parents", recursive=True)

    # Recursive step: fetch the parent of every row found so far, one level further up.
    climb = select(Node.id, Node.parent_node_id, (ancestors.c.level + 1).label("level")).join(
        ancestors, Node.id == ancestors.c.parent_node_id
    )

    return select(ancestors.union(climb)).subquery()
def get_node_parents_recursively(session, node_id):
    """
    Return the official clusters attached to `node_id` and to each of its ancestor nodes.

    Each row is (node id, parent_node_id, level, Cluster). Level 0 is `node_id` itself and larger levels are further
    up the tree. Rows are ordered by level descending, i.e. from the top-most ancestor down to the node.
    """
    ancestry = _get_node_parents_recursive_cte_subquery(session, node_id)
    query = (
        select(ancestry, Cluster)
        .join(Cluster, Cluster.parent_node_id == ancestry.c.id)
        .where(Cluster.is_official_cluster)
        .order_by(ancestry.c.level.desc())
    )
    return session.execute(query).all()
def _can_moderate_any_cluster(session, user_id, cluster_ids):
    """Return True if `user_id` has an admin ClusterSubscription on at least one of `cluster_ids`."""
    admin_subscription = (
        select(True)
        .select_from(ClusterSubscription)
        .where(
            ClusterSubscription.role == ClusterRole.admin,
            ClusterSubscription.user_id == user_id,
            ClusterSubscription.cluster_id.in_(cluster_ids),
        )
    )
    return session.execute(select(admin_subscription.exists())).scalar_one()
def can_moderate_node(session, user_id, node_id):
    """
    Returns True if the user_id can moderate the given node (i.e., if they are admin of any community that is a parent of the node)

    "Parent" includes the node itself: the recursive CTE from `_get_node_parents_recursive_cte_subquery` starts at
    `node_id` and walks up `parent_node_id`. Only official clusters attached to those nodes are considered.
    """
    ancestors = _get_node_parents_recursive_cte_subquery(session, node_id)
    admin_of_ancestor = (
        select(True)
        .select_from(ClusterSubscription)
        .where(ClusterSubscription.role == ClusterRole.admin)
        .where(ClusterSubscription.user_id == user_id)
        .join(Cluster, Cluster.id == ClusterSubscription.cluster_id)
        .where(Cluster.is_official_cluster)
        .where(Cluster.parent_node_id == ancestors.c.id)
    )
    # Single EXISTS query; the previous unreachable fallback (list clusters, then check them) has been removed.
    return session.execute(select(admin_of_ancestor.exists())).scalar_one()
def can_moderate_at(session, user_id, shape):
    """
    Return True if `user_id` is an admin of an official cluster whose parent Node's geometry contains `shape`,
    i.e. the user may moderate content located at that geo-shape.
    """
    official_cluster_node = and_(Cluster.is_official_cluster, Node.id == Cluster.parent_node_id)
    admin_over_shape = (
        select(True)
        .select_from(ClusterSubscription)
        .where(ClusterSubscription.role == ClusterRole.admin)
        .where(ClusterSubscription.user_id == user_id)
        .join(Cluster, Cluster.id == ClusterSubscription.cluster_id)
        .join(Node, official_cluster_node)
        .where(func.ST_Contains(Node.geom, shape))
    )
    return session.execute(select(admin_over_shape.exists())).scalar_one()
def timezone_at_coordinate(session, geom):
    """
    Return the `tzid` of the TimezoneArea whose geometry contains `geom`, or None if there is none.

    The tzid is presumably an IANA zone name such as "Europe/Berlin"; confirm against the TimezoneArea data.
    `scalar_one_or_none()` raises if more than one area contains the point.
    """
    tzid = session.execute(
        select(TimezoneArea.tzid).where(func.ST_Contains(TimezoneArea.geom, geom))
    ).scalar_one_or_none()
    # Selecting a single column means scalar_one_or_none() already yields the tzid value itself, not a row object,
    # so it must not be dereferenced again with `.tzid` (that raised AttributeError whenever a zone was found).
    return tzid or None