Coverage for src/couchers/db.py: 94%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import functools
2import logging
3import os
4from contextlib import contextmanager
5from time import perf_counter_ns
7from alembic import command
8from alembic.config import Config
9from sqlalchemy import create_engine, event
10from sqlalchemy.engine import Engine
11from sqlalchemy.orm.session import Session
12from sqlalchemy.pool import NullPool, SingletonThreadPool
13from sqlalchemy.sql import and_, func, literal, or_
15from couchers import config
16from couchers.constants import SERVER_THREADS
17from couchers.models import (
18 Cluster,
19 ClusterRole,
20 ClusterSubscription,
21 FriendRelationship,
22 FriendStatus,
23 Node,
24 TimezoneArea,
25)
26from couchers.profiler import add_sql_statement
27from couchers.sql import couchers_select as select
28from couchers.utils import now
# logger named after this module, so records can be filtered/configured per module
logger = logging.getLogger(name=__name__)
def apply_migrations():
    """
    Upgrade the database schema to the latest alembic revision ("head").

    ``Config("alembic.ini")`` is a relative path, so it is resolved against the current working directory. We
    therefore switch to the directory two levels above this module (presumably the project root holding
    alembic.ini), run the upgrade, and always restore the previous working directory afterwards.

    NOTE(review): os.chdir changes the working directory of the whole process; confirm this only runs while no
    other threads depend on the cwd (e.g. at startup).
    """
    previous_cwd = os.getcwd()
    project_root = os.path.dirname(__file__) + "/../.."
    try:
        os.chdir(project_root)
        alembic_config = Config("alembic.ini")
        # by default alembic reconfigures logging; this option tells it to leave the app's logging setup alone when
        # migrations are run at startup like this
        alembic_config.set_main_option("dont_mess_up_logging", "False")
        command.upgrade(alembic_config, "head")
    finally:
        os.chdir(previous_cwd)
@functools.lru_cache
def _get_base_engine():
    """
    Build the shared SQLAlchemy engine.

    The result is memoised by ``lru_cache``, so the engine is created once and reused until
    ``clear_base_engine_cache()`` empties the cache.

    Pooling depends on ``config.config["IN_TEST"]``:
      - in tests, ``NullPool`` is used, so no connections are pooled;
      - otherwise ``SingletonThreadPool`` keeps one connection per thread, sized for the server threads plus a
        dozen extra (background workers, etc).
    """
    in_test = config.config["IN_TEST"]
    pool_kwargs = (
        dict(poolclass=NullPool)
        if in_test
        else dict(poolclass=SingletonThreadPool, pool_size=SERVER_THREADS + 12)
    )
    return create_engine(
        config.config["DATABASE_CONNECTION_STRING"],
        # opt in to SQLAlchemy 2.0-style behaviour
        future=True,
        # check each pooled connection is alive before handing it out; this avoids the
        # "server closed the connection unexpectedly" errors
        pool_pre_ping=True,
        **pool_kwargs,
    )
66 )
69def clear_base_engine_cache():
70 """
71 This needs to be done when the public schema is dropped.
72 """
73 _get_base_engine.cache_clear()
def get_engine(isolation_level=None):
    """
    Return the shared engine, optionally with a specific transaction isolation level.

    Args:
        isolation_level: isolation level name. If falsy, the cached base engine itself is returned.

    Returns:
        Either the base engine, or a shallow copy of it (via ``execution_options``) carrying the requested
        isolation level.
    """
    base_engine = _get_base_engine()
    if isolation_level:
        # execution_options() returns a shallow copy; the base engine itself is left unchanged
        return base_engine.execution_options(isolation_level=isolation_level)
    return base_engine
@contextmanager
def session_scope(isolation_level=None):
    """
    Provide a transactional ORM Session as a context manager.

    The session is committed when the ``with`` body finishes normally. It is rolled back if anything escapes the
    body, including a failure inside ``commit()`` itself. It is always closed.

    Args:
        isolation_level: optional transaction isolation level forwarded to ``get_engine()``; falsy means the
            engine default.

    Yields:
        A ``Session`` bound to the (possibly isolation-adjusted) engine.
    """
    session = Session(get_engine(isolation_level=isolation_level), future=True)
    try:
        yield session
        session.commit()
    except BaseException:
        # Deliberately BaseException, not Exception: KeyboardInterrupt, GeneratorExit and friends must also roll the
        # transaction back before propagating. This is what a bare `except:` catches, spelled out explicitly.
        session.rollback()
        raise
    finally:
        session.close()
@event.listens_for(Engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    """
    Profiling hook registered for every Engine, run before each cursor execution.

    Pushes ``(statement, parameters, now(), perf_counter_ns())`` onto a per-connection list stored in
    ``conn.info["query_profiler_info"]``. ``after_cursor_execute`` pops that entry to record the timing.
    """
    pending = conn.info.setdefault("query_profiler_info", [])
    started_at = now()
    started_ns = perf_counter_ns()
    pending.append((statement, parameters, started_at, started_ns))
@event.listens_for(Engine, "after_cursor_execute")
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    """
    Profiling hook registered for every Engine, run after each cursor execution.

    Pops the most recent entry pushed by ``before_cursor_execute`` from ``conn.info["query_profiler_info"]``. It
    then reports the statement, its parameters and the start/end timestamps to ``add_sql_statement``.
    """
    pending = conn.info["query_profiler_info"]
    # the newest entry belongs to this execution; its statement/parameters replace the hook arguments
    statement, parameters, start, start_ns = pending.pop()
    finished_at = now()
    finished_ns = perf_counter_ns()
    # NOTE(review): if the cursor execution raises, this hook presumably does not fire and the pushed entry remains
    # in conn.info for the life of the connection. Confirm whether a handle_error cleanup is wanted.
    add_sql_statement(statement, parameters, start, start_ns, finished_at, finished_ns)
def are_friends(session, context, other_user):
    """
    Return True if the requesting user and ``other_user`` have an accepted friend relationship.

    The relationship may have been created by either side. Both user columns must also pass the
    ``where_users_column_visible`` filter for ``context``.

    Args:
        session: database session used to run the query.
        context: request context; ``context.user_id`` is one side of the relationship.
        other_user: user id of the other side.

    Returns:
        bool: whether at least one matching accepted relationship exists.
    """
    relationship = session.execute(
        select(FriendRelationship)
        .where_users_column_visible(context, FriendRelationship.from_user_id)
        .where_users_column_visible(context, FriendRelationship.to_user_id)
        .where(
            or_(
                and_(FriendRelationship.from_user_id == context.user_id, FriendRelationship.to_user_id == other_user),
                and_(FriendRelationship.from_user_id == other_user, FriendRelationship.to_user_id == context.user_id),
            )
        )
        .where(FriendRelationship.status == FriendStatus.accepted)
        # Only existence matters. Without a limit, scalar_one_or_none() raises MultipleResultsFound when more than
        # one accepted relationship exists between the pair (e.g. one in each direction).
        .limit(1)
    ).scalar_one_or_none()
    return relationship is not None
def get_parent_node_at_location(session, shape):
    """
    Return the smallest Node whose geometry contains ``shape``, or None if no node contains it.

    ``shape`` can be any PostGIS geo object, e.g. the output of create_coordinate.
    """
    # Find the lowest Node (in the Node tree) that contains the shape. By construction a sub-node always has a
    # smaller area than its parent, so ordering the containing nodes by area gives the deepest one first, and no
    # tree traversal is needed.
    containing_nodes = (
        select(Node)
        .where(func.ST_Contains(Node.geom, shape))
        .order_by(func.ST_Area(Node.geom))
    )
    return session.execute(containing_nodes).scalars().first()
def get_node_parents_recursively(session, node_id):
    """
    Walk up the node tree from ``node_id`` and return the nodes on that path (including the starting node) that have
    an official cluster.

    Each result row is ``(node_id, parent_node_id, level, cluster)``:
      - ``level`` is the number of steps up from ``node_id`` (0 for the node itself);
      - ``cluster`` is the official Cluster attached to that node.

    Rows are ordered by ``level`` descending, i.e. the most distant ancestor first. Nodes without an official
    cluster are dropped by the join on Cluster and the ``is_official_cluster`` filter.
    """
    # anchor of the recursive CTE: the starting node at level 0
    seed = select(Node.id, Node.parent_node_id, literal(0).label("level")).where(Node.id == node_id)
    ancestors = seed.cte("parents", recursive=True)

    # recursive term: the parent of every node found so far, one level further up
    step_up = select(Node.id, Node.parent_node_id, (ancestors.c.level + 1).label("level")).join(
        ancestors, Node.id == ancestors.c.parent_node_id
    )
    lineage = select(ancestors.union(step_up)).subquery()

    query = (
        select(lineage, Cluster)
        .join(Cluster, Cluster.parent_node_id == lineage.c.id)
        .where(Cluster.is_official_cluster)
        .order_by(lineage.c.level.desc())
    )
    return session.execute(query).all()
178def _can_moderate_any_cluster(session, user_id, cluster_ids):
179 return (
180 session.execute(
181 select(func.count())
182 .select_from(ClusterSubscription)
183 .where(ClusterSubscription.role == ClusterRole.admin)
184 .where(ClusterSubscription.user_id == user_id)
185 .where(ClusterSubscription.cluster_id.in_(cluster_ids))
186 ).scalar_one()
187 > 0
188 )
def can_moderate_at(session, user_id, shape):
    """
    Return True if ``user_id`` can moderate the given geo-shape.

    A user can moderate it when they are an admin of an official cluster whose parent Node's geometry contains
    ``shape``.
    """
    containing_clusters = (
        select(Cluster.id)
        .join(Node, Node.id == Cluster.parent_node_id)
        .where(Cluster.is_official_cluster)
        .where(func.ST_Contains(Node.geom, shape))
    )
    cluster_ids = session.execute(containing_clusters).scalars().all()
    return _can_moderate_any_cluster(session, user_id, cluster_ids)
def can_moderate_node(session, user_id, node_id):
    """
    Return True if ``user_id`` can moderate the given node.

    That is the case when they are an admin of the official cluster of ``node_id`` itself or of any of its
    ancestor nodes.
    """
    ancestry = get_node_parents_recursively(session, node_id)
    # each row ends with the Cluster entity; the node id/parent/level columns before it are not needed here
    cluster_ids = [cluster.id for *_, cluster in ancestry]
    return _can_moderate_any_cluster(session, user_id, cluster_ids)
def timezone_at_coordinate(session, geom):
    """
    Return the tzid of the TimezoneArea whose geometry contains ``geom``, or None if no area contains it.

    Raises:
        MultipleResultsFound (from ``scalar_one_or_none``) if more than one timezone area contains ``geom``.
    """
    # The query selects the tzid column itself, so the scalar result is already the tzid value, not a TimezoneArea
    # row. Accessing `.tzid` on it would raise AttributeError.
    tzid = session.execute(
        select(TimezoneArea.tzid).where(func.ST_Contains(TimezoneArea.geom, geom))
    ).scalar_one_or_none()
    # keep the "falsy result means no timezone" contract
    return tzid if tzid else None