Coverage for src/couchers/db.py: 75% (104 statements)

import functools
import inspect
import logging
import os
from contextlib import contextmanager
from os import getpid
from threading import get_ident

from alembic import command
from alembic.config import Config
from opentelemetry import trace
from sqlalchemy import create_engine, text
from sqlalchemy.orm.session import Session
from sqlalchemy.pool import QueuePool
from sqlalchemy.sql import and_, func, literal, or_

from couchers.config import config
from couchers.constants import SERVER_THREADS, WORKER_THREADS
from couchers.models import (
    Cluster,
    ClusterRole,
    ClusterSubscription,
    FriendRelationship,
    FriendStatus,
    Node,
    TimezoneArea,
)
from couchers.sql import couchers_select as select

logger = logging.getLogger(__name__)

tracer = trace.get_tracer(__name__)


def apply_migrations():
    alembic_dir = os.path.dirname(__file__) + "/../.."
    cwd = os.getcwd()
    try:
        os.chdir(alembic_dir)
        alembic_cfg = Config("alembic.ini")
        # alembic messes up the logging config by default; this option tells it not to do that when migrations are
        # run at startup like this
        alembic_cfg.set_main_option("dont_mess_up_logging", "False")
        command.upgrade(alembic_cfg, "head")
    finally:
        os.chdir(cwd)
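# Note on "dont_mess_up_logging": the option is read on the alembic side. A minimal sketch of what the project's
# env.py presumably does with it (the exact env.py contents are an assumption, not part of this file):
#
#     from logging.config import fileConfig
#     if context.config.get_main_option("dont_mess_up_logging", "True") == "True":
#         fileConfig(context.config.config_file_name)
#
# so setting it to "False" above skips alembic's logging reconfiguration when migrations run at app startup.

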
@functools.cache
def _get_base_engine():
    return create_engine(
        config["DATABASE_CONNECTION_STRING"],
        # checks that the connections in the pool are alive before using them, which avoids the "server closed the
        # connection unexpectedly" errors
        pool_pre_ping=True,
        # one connection per thread
        poolclass=QueuePool,
        # main threads + a few extra just in case
        pool_size=SERVER_THREADS + WORKER_THREADS + 12,
    )


@contextmanager
def session_scope():
    with tracer.start_as_current_span("session_scope") as rollspan:
        with Session(_get_base_engine()) as session:
            session.begin()
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        frame = inspect.stack()[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar()
                    logger.debug(f"SScope: got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            except:
                session.rollback()
                raise
            finally:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"SScope: closed {backend_pid=}")
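# Illustrative usage only (this helper is not part of the original module): callers get one transaction per scope,
# with commit on success and rollback on exception handled by session_scope itself.
def _example_session_scope_usage():
    with session_scope() as session:
        # any ORM/Core queries go here; this trivial query just checks connectivity
        return session.execute(text("SELECT 1")).scalar_one()

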
@contextmanager
def worker_repeatable_read_session_scope():
    """
    This is a separate session scope that is isolated from the main one: sharing the main scope would nest
    transactions, so this causes two different connections to be used.

    This operates at the `REPEATABLE READ` isolation level so that we can do a `SELECT ... FOR UPDATE SKIP LOCKED` in
    the background worker, effectively using postgres as a queueing system.
    """
    with tracer.start_as_current_span("worker_session_scope") as rollspan:
        with Session(_get_base_engine().execution_options(isolation_level="REPEATABLE READ")) as session:
            session.begin()
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        frame = inspect.stack()[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar()
                    logger.debug(f"SScope (worker): got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            except:
                session.rollback()
                raise
            finally:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"SScope (worker): closed {backend_pid=}")
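# Illustrative sketch of the queueing pattern this scope enables (the `background_jobs` table and its columns are
# assumptions for the example, not necessarily the real schema): claim one pending row with
# `FOR UPDATE SKIP LOCKED` so that concurrent workers never pick up the same job twice.
def _example_claim_one_job():
    with worker_repeatable_read_session_scope() as session:
        return session.execute(
            text(
                """
                SELECT id FROM background_jobs
                WHERE state = 'pending'
                LIMIT 1
                FOR UPDATE SKIP LOCKED
                """
            )
        ).scalar_one_or_none()

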
def db_post_fork():
    """
    Fix post-fork issues with sqlalchemy
    """
    # see https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
    _get_base_engine().dispose(close=False)
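# Illustrative sketch of the intended call site (this process model is an assumption, not the app's actual entry
# point): each forked child should call db_post_fork() before doing any database work, so that it does not reuse
# pooled connections inherited from the parent.
def _example_fork_worker(run_worker):
    pid = os.fork()  # POSIX-only, for illustration
    if pid == 0:
        # child process: drop connections inherited from the parent without closing the parent's sockets
        db_post_fork()
        run_worker()
        os._exit(0)
    return pid

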
def are_friends(session, context, other_user):
    return (
        session.execute(
            select(FriendRelationship)
            .where_users_column_visible(context, FriendRelationship.from_user_id)
            .where_users_column_visible(context, FriendRelationship.to_user_id)
            .where(
                or_(
                    and_(
                        FriendRelationship.from_user_id == context.user_id, FriendRelationship.to_user_id == other_user
                    ),
                    and_(
                        FriendRelationship.from_user_id == other_user, FriendRelationship.to_user_id == context.user_id
                    ),
                )
            )
            .where(FriendRelationship.status == FriendStatus.accepted)
        ).scalar_one_or_none()
        is not None
    )


def get_parent_node_at_location(session, shape):
    """
    Finds the smallest node containing the shape.

    Shape can be any PostGIS geo object, e.g. output from create_coordinate
    """

    # Find the lowest Node (in the Node tree) that contains the shape. By construction of nodes, the area of a
    # sub-node must always be less than that of its parent Node, so there is no need to actually traverse the tree!
    return (
        session.execute(select(Node).where(func.ST_Contains(Node.geom, shape)).order_by(func.ST_Area(Node.geom)))
        .scalars()
        .first()
    )
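# Illustrative sketch (the `create_coordinate` helper and its import location are assumptions based on the
# docstring above): find the smallest community node containing a point given as lat/lng.
def _example_parent_node_at_point(session, lat, lng):
    from couchers.utils import create_coordinate  # assumed helper location, for illustration only

    return get_parent_node_at_location(session, create_coordinate(lat, lng))

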
def get_node_parents_recursively(session, node_id):
    """
    Gets the upwards hierarchy of parents, ordered by level, for a given node

    Returns SQLAlchemy rows of (node_id, parent_node_id, level, cluster)
    """
    parents = (
        select(Node.id, Node.parent_node_id, literal(0).label("level"))
        .where(Node.id == node_id)
        .cte("parents", recursive=True)
    )

    subquery = select(
        parents.union(
            select(Node.id, Node.parent_node_id, (parents.c.level + 1).label("level")).join(
                parents, Node.id == parents.c.parent_node_id
            )
        )
    ).subquery()

    return session.execute(
        select(subquery, Cluster)
        .join(Cluster, Cluster.parent_node_id == subquery.c.id)
        .where(Cluster.is_official_cluster)
        .order_by(subquery.c.level.desc())
    ).all()
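# Illustrative sketch: callers unpack the (node_id, parent_node_id, level, cluster) rows returned above, e.g. to
# list the official communities from the root down to the node itself (the `name` column is an assumption here).
def _example_community_hierarchy_names(session, node_id):
    return [cluster.name for _, _, _, cluster in get_node_parents_recursively(session, node_id)]

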
def _can_moderate_any_cluster(session, user_id, cluster_ids):
    return (
        session.execute(
            select(func.count())
            .select_from(ClusterSubscription)
            .where(ClusterSubscription.role == ClusterRole.admin)
            .where(ClusterSubscription.user_id == user_id)
            .where(ClusterSubscription.cluster_id.in_(cluster_ids))
        ).scalar_one()
        > 0
    )


def can_moderate_at(session, user_id, shape):
    """
    Returns True if the user can moderate the given geo-shape, i.e. if the shape is contained in a node whose official
    cluster the user is an admin of
    """
    cluster_ids = [
        cluster_id
        for (cluster_id,) in session.execute(
            select(Cluster.id)
            .join(Node, Node.id == Cluster.parent_node_id)
            .where(Cluster.is_official_cluster)
            .where(func.ST_Contains(Node.geom, shape))
        ).all()
    ]
    return _can_moderate_any_cluster(session, user_id, cluster_ids)


def can_moderate_node(session, user_id, node_id):
    """
    Returns True if the user can moderate the given node, i.e. if they are an admin of the official cluster of the
    node itself or of any of its parent nodes
    """
    return _can_moderate_any_cluster(
        session, user_id, [cluster.id for _, _, _, cluster in get_node_parents_recursively(session, node_id)]
    )
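# Illustrative sketch of using these checks as a permission gate (raising PermissionError is just an assumption for
# the example; real callers would likely convert a failed check into an API error instead):
def _example_require_node_moderator(session, user_id, node_id):
    if not can_moderate_node(session, user_id, node_id):
        raise PermissionError("user cannot moderate this node or any of its parent communities")

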
def timezone_at_coordinate(session, geom):
    # the query selects the tzid column directly, so the scalar is already the tzid string (or None if no
    # timezone area contains the point)
    return session.execute(
        select(TimezoneArea.tzid).where(func.ST_Contains(TimezoneArea.geom, geom))
    ).scalar_one_or_none()