Coverage for src/couchers/db.py: 78%
116 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-11-21 04:21 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-11-21 04:21 +0000
1import functools
2import inspect
3import logging
4import os
5from contextlib import contextmanager
6from os import getpid
7from threading import get_ident
8from time import perf_counter_ns
10from alembic import command
11from alembic.config import Config
12from opentelemetry import trace
13from sqlalchemy import create_engine, event, text
14from sqlalchemy.engine import Engine
15from sqlalchemy.orm.session import Session
16from sqlalchemy.pool import QueuePool
17from sqlalchemy.sql import and_, func, literal, or_
19from couchers.config import config
20from couchers.constants import SERVER_THREADS, WORKER_THREADS
21from couchers.models import (
22 Cluster,
23 ClusterRole,
24 ClusterSubscription,
25 FriendRelationship,
26 FriendStatus,
27 Node,
28 TimezoneArea,
29)
30from couchers.profiler import add_sql_statement
31from couchers.sql import couchers_select as select
32from couchers.utils import now
34logger = logging.getLogger(__name__)
36tracer = trace.get_tracer(__name__)
def apply_migrations():
    """
    Upgrade the database schema to the alembic "head" revision.

    `alembic.ini` is opened by relative path, so the working directory is temporarily switched to the directory two
    levels above this module's directory; the previous working directory is always restored, even on failure.
    """
    previous_cwd = os.getcwd()
    migrations_root = os.path.dirname(__file__) + "/../.."
    try:
        os.chdir(migrations_root)
        alembic_config = Config("alembic.ini")
        # alembic screws up logging config by default; this option presumably tells the alembic env not to touch
        # logging when migrations are run at startup like this (TODO confirm against env.py)
        alembic_config.set_main_option("dont_mess_up_logging", "False")
        command.upgrade(alembic_config, "head")
    finally:
        os.chdir(previous_cwd)
@functools.cache
def _get_base_engine():
    """
    Build the shared SQLAlchemy engine.

    Memoized with functools.cache: the engine is created on the first call and the same object is returned afterwards.
    """
    # one connection per thread: main server + worker threads, plus a few extra in case
    connection_pool_size = SERVER_THREADS + WORKER_THREADS + 12
    return create_engine(
        config["DATABASE_CONNECTION_STRING"],
        poolclass=QueuePool,
        pool_size=connection_pool_size,
        # check that pooled connections are alive before handing them out, which avoids the "server closed the
        # connection unexpectedly" errors
        pool_pre_ping=True,
    )
@contextmanager
def session_scope():
    """
    Yield a SQLAlchemy Session with an open transaction on the shared engine.

    The transaction is committed when the `with` body exits normally. If the body (or the commit) raises, the
    transaction is rolled back and the exception is re-raised. The whole scope runs inside a "session_scope" tracing
    span. When DEBUG logging is enabled, the Postgres backend pid and the caller's file:line are logged and attached
    to the span.
    """
    with tracer.start_as_current_span("session_scope") as rollspan:
        with Session(_get_base_engine()) as session:
            session.begin()
            # Initialised up front so the `finally` log line can never hit an UnboundLocalError, which would mask
            # the real exception. That can happen if the pid query fails, or if DEBUG gets enabled mid-scope.
            backend_pid = None
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        # [0] is this generator, [1] is contextlib's __enter__, [2] is the caller's `with` line;
                        # context=0 avoids reading source lines for every frame on the stack
                        frame = inspect.stack(context=0)[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar()
                    logger.debug(f"SScope: got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            except BaseException:
                # deliberately broad: any exit path that isn't a clean commit must roll back before propagating
                session.rollback()
                raise
            finally:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"SScope: closed {backend_pid=}")
@contextmanager
def worker_repeatable_read_session_scope():
    """
    Like session_scope, but the session's connection runs at the `REPEATABLE READ` isolation level.

    This is a separate session scope that is isolated from the main one, since otherwise we end up nesting
    transactions; this causes two different connections to be used.

    `REPEATABLE READ` is used so that we can do a `SELECT ... FOR UPDATE SKIP LOCKED` in the background worker,
    effectively using postgres as a queueing system.

    The transaction is committed when the `with` body exits normally, and rolled back then re-raised if it raises.
    """
    with tracer.start_as_current_span("worker_session_scope") as rollspan:
        with Session(_get_base_engine().execution_options(isolation_level="REPEATABLE READ")) as session:
            session.begin()
            # Initialised up front so the `finally` log line can never hit an UnboundLocalError, which would mask
            # the real exception. That can happen if the pid query fails, or if DEBUG gets enabled mid-scope.
            backend_pid = None
            try:
                if logger.isEnabledFor(logging.DEBUG):
                    try:
                        # [0] is this generator, [1] is contextlib's __enter__, [2] is the caller's `with` line;
                        # context=0 avoids reading source lines for every frame on the stack
                        frame = inspect.stack(context=0)[2]
                        filename_line = f"{frame.filename}:{frame.lineno}"
                    except Exception:
                        filename_line = "{unknown file}"
                    backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar()
                    logger.debug(f"SScope (worker): got {backend_pid=} at {filename_line}")
                    rollspan.set_attribute("db.backend_pid", backend_pid)
                    rollspan.set_attribute("db.filename_line", filename_line)
                    rollspan.set_attribute("rpc.thread", get_ident())
                    rollspan.set_attribute("rpc.pid", getpid())

                yield session
                session.commit()
            except BaseException:
                # deliberately broad: any exit path that isn't a clean commit must roll back before propagating
                session.rollback()
                raise
            finally:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"SScope (worker): closed {backend_pid=}")
def db_post_fork():
    """
    Make the shared engine safe to use in a freshly forked child process.

    See https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
    """
    engine = _get_base_engine()
    # close=False: start a fresh pool without closing the connections inherited from the parent process
    engine.dispose(close=False)
@event.listens_for(Engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    """
    Push the statement, its parameters, and its start times onto conn.info for after_cursor_execute to pop.

    The start times are the wall-clock now() and the high-resolution perf_counter_ns().
    """
    pending_queries = conn.info.setdefault("query_profiler_info", [])
    pending_queries.append((statement, parameters, now(), perf_counter_ns()))
@event.listens_for(Engine, "after_cursor_execute")
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    """
    Pop the most recent entry pushed by before_cursor_execute and report it, with end times, to the profiler.
    """
    # NOTE(review): SQLAlchemy does not fire after_cursor_execute when the statement raises, so the matching
    # entry presumably stays on this list in that case; consider a handle_error listener that pops it — confirm.
    pending_queries = conn.info["query_profiler_info"]
    statement, parameters, start, start_ns = pending_queries.pop()
    end = now()
    end_ns = perf_counter_ns()
    add_sql_statement(statement, parameters, start, start_ns, end, end_ns)
def are_friends(session, context, other_user):
    """
    Returns True if an accepted FriendRelationship exists in either direction between context.user_id and other_user.

    Both the from- and to-user columns are filtered through where_users_column_visible for the given context.
    """
    me = context.user_id
    i_added_them = and_(FriendRelationship.from_user_id == me, FriendRelationship.to_user_id == other_user)
    they_added_me = and_(FriendRelationship.from_user_id == other_user, FriendRelationship.to_user_id == me)

    accepted_relationship = (
        select(FriendRelationship)
        .where_users_column_visible(context, FriendRelationship.from_user_id)
        .where_users_column_visible(context, FriendRelationship.to_user_id)
        .where(or_(i_added_them, they_added_me))
        .where(FriendRelationship.status == FriendStatus.accepted)
    )
    match = session.execute(accepted_relationship).scalar_one_or_none()
    return match is not None
def get_parent_node_at_location(session, shape):
    """
    Finds the smallest node containing the shape, or None if no node contains it.

    Shape can be any PostGIS geo object, e.g. output from create_coordinate
    """
    # Find the lowest Node (in the Node tree) that contains the shape. By construction of nodes, the area of a
    # sub-node must always be less than its parent Node, so sorting containing nodes by area ascending and taking the
    # first one gives the deepest match -- no need to actually traverse the tree!
    containing_nodes_smallest_first = (
        select(Node).where(func.ST_Contains(Node.geom, shape)).order_by(func.ST_Area(Node.geom))
    )
    return session.execute(containing_nodes_smallest_first).scalars().first()
def get_node_parents_recursively(session, node_id):
    """
    Gets the upwards hierarchy of parents for a given node, joined to each node's official cluster.

    Walks parent_node_id links with a recursive CTE: the starting node has level 0 and each parent step adds 1.
    Rows are ordered by level descending, i.e. the top-most ancestor first and the starting node last.

    Returns SQLAlchemy rows of (node_id, parent_node_id, level, cluster)
    """
    # anchor member: the starting node itself at level 0
    ancestors = (
        select(Node.id, Node.parent_node_id, literal(0).label("level"))
        .where(Node.id == node_id)
        .cte("parents", recursive=True)
    )
    # recursive member: the parent of every node found so far, one level higher
    parent_step = select(Node.id, Node.parent_node_id, (ancestors.c.level + 1).label("level")).join(
        ancestors, Node.id == ancestors.c.parent_node_id
    )
    hierarchy = select(ancestors.union(parent_step)).subquery()

    with_clusters = (
        select(hierarchy, Cluster)
        .join(Cluster, Cluster.parent_node_id == hierarchy.c.id)
        .where(Cluster.is_official_cluster)
        .order_by(hierarchy.c.level.desc())
    )
    return session.execute(with_clusters).all()
217def _can_moderate_any_cluster(session, user_id, cluster_ids):
218 return (
219 session.execute(
220 select(func.count())
221 .select_from(ClusterSubscription)
222 .where(ClusterSubscription.role == ClusterRole.admin)
223 .where(ClusterSubscription.user_id == user_id)
224 .where(ClusterSubscription.cluster_id.in_(cluster_ids))
225 ).scalar_one()
226 > 0
227 )
def can_moderate_at(session, user_id, shape):
    """
    Returns True if the user_id can moderate a given geo-shape.

    That is, if the user is an admin of an official cluster whose parent Node's geometry contains the shape.
    """
    clusters_containing_shape = (
        select(Cluster.id)
        .join(Node, Node.id == Cluster.parent_node_id)
        .where(Cluster.is_official_cluster)
        .where(func.ST_Contains(Node.geom, shape))
    )
    cluster_ids = session.execute(clusters_containing_shape).scalars().all()
    return _can_moderate_any_cluster(session, user_id, cluster_ids)
def can_moderate_node(session, user_id, node_id):
    """
    Returns True if the user_id can moderate the given node.

    That is, if they are an admin of the official cluster of the node itself or of any of its ancestor nodes.
    """
    ancestor_rows = get_node_parents_recursively(session, node_id)
    # each row is (node_id, parent_node_id, level, cluster); only the cluster is needed
    ancestor_cluster_ids = [cluster.id for _node_id, _parent_id, _level, cluster in ancestor_rows]
    return _can_moderate_any_cluster(session, user_id, ancestor_cluster_ids)
def timezone_at_coordinate(session, geom):
    """
    Returns the tzid of the TimezoneArea whose geometry contains geom, or None if no area contains it.

    Raises sqlalchemy's MultipleResultsFound if more than one area contains geom (scalar_one_or_none semantics).
    """
    # The query selects the tzid column itself, so the scalar result IS the tzid string.
    # Accessing `.tzid` on it, as before, raised AttributeError whenever a timezone was found.
    tzid = session.execute(
        select(TimezoneArea.tzid).where(func.ST_Contains(TimezoneArea.geom, geom))
    ).scalar_one_or_none()
    if tzid:
        return tzid
    return None