Coverage for src/couchers/db.py: 92%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import functools
2import logging
3import os
4from contextlib import contextmanager
5from time import perf_counter_ns
7from alembic import command
8from alembic.config import Config
9from sqlalchemy import create_engine, event
10from sqlalchemy.engine import Engine
11from sqlalchemy.orm.session import Session
12from sqlalchemy.pool import NullPool, SingletonThreadPool
13from sqlalchemy.sql import and_, func, literal, or_
15from couchers import config
16from couchers.constants import SERVER_THREADS
17from couchers.models import (
18 Cluster,
19 ClusterRole,
20 ClusterSubscription,
21 FriendRelationship,
22 FriendStatus,
23 Node,
24 TimezoneArea,
25)
26from couchers.profiler import add_sql_statement
27from couchers.sql import couchers_select as select
28from couchers.utils import now
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
def apply_migrations():
    """
    Upgrade the database schema to the newest alembic revision ("head").

    ``alembic.ini`` is loaded from two directories above this file, so the process working directory is switched
    there while alembic runs. The previous working directory is always restored afterwards, even if the upgrade
    raises.
    """
    previous_cwd = os.getcwd()
    project_root = f"{os.path.dirname(__file__)}/../.."
    try:
        # NOTE: os.chdir changes the working directory of the whole process, not just this call
        os.chdir(project_root)
        migration_config = Config("alembic.ini")
        # alembic messes with the logging config by default; this option tells it not to when we are run at startup
        migration_config.set_main_option("dont_mess_up_logging", "False")
        command.upgrade(migration_config, "head")
    finally:
        os.chdir(previous_cwd)
@functools.lru_cache
def _get_base_engine():
    """
    Build the SQLAlchemy engine shared by the whole module.

    ``lru_cache`` means the engine is created on the first call and reused afterwards. In tests connections are not
    pooled (``NullPool``). Otherwise each thread gets its own connection (``SingletonThreadPool``), sized for the
    server threads plus headroom for background workers and the like.
    """
    in_test = config.config["IN_TEST"]
    if in_test:
        pooling = {"poolclass": NullPool}
    else:
        # one connection per thread: main threads + a couple for bg workers, etc
        pooling = {"poolclass": SingletonThreadPool, "pool_size": SERVER_THREADS + 12}

    engine = create_engine(
        config.config["DATABASE_CONNECTION_STRING"],
        # opt into SQLAlchemy 2.0 behaviour
        future=True,
        # check pooled connections are alive before use, avoiding "server closed the connection unexpectedly" errors
        pool_pre_ping=True,
        **pooling,
    )
    return engine
def get_engine(isolation_level=None):
    """
    Return the shared engine, optionally bound to a specific transaction isolation level.

    With no (or a falsy) ``isolation_level`` the cached base engine itself is returned. Otherwise a shallow copy
    of it is returned, made via ``execution_options``, that uses the requested isolation level.
    """
    base = _get_base_engine()
    if isolation_level:
        return base.execution_options(isolation_level=isolation_level)
    return base
@contextmanager
def session_scope(isolation_level=None):
    """
    Provide a transactional ``Session`` for the duration of a ``with`` block.

    The session is committed when the block completes normally. It is rolled back if anything is raised, whether
    by the block or by the commit itself, and the exception propagates. The session is always closed afterwards.

    :param isolation_level: optional isolation level, forwarded to ``get_engine``.
    """
    session = Session(get_engine(isolation_level=isolation_level), future=True)
    try:
        yield session
        session.commit()
    except BaseException:
        # Deliberately broad, including KeyboardInterrupt/SystemExit: always roll back before re-raising.
        session.rollback()
        raise
    finally:
        session.close()
@event.listens_for(Engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    """
    Record the start of a statement's execution, for the query profiler.

    Appends ``(statement, parameters, wall-clock start, perf_counter_ns start)`` to a per-connection list kept in
    ``conn.info["query_profiler_info"]``. The matching ``after_cursor_execute`` listener pops the last entry.
    """
    pending = conn.info.setdefault("query_profiler_info", [])
    pending.append((statement, parameters, now(), perf_counter_ns()))
@event.listens_for(Engine, "after_cursor_execute")
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    """
    Pop the most recent start entry for this connection and report the finished statement to the profiler.

    The statement and parameters passed to ``add_sql_statement`` are the ones stored when execution started. The
    start timestamps come from that stored entry, and the end timestamps are taken here.
    """
    # NOTE(review): if a statement raises, this hook presumably does not run and its start entry stays on the list.
    # Last-in-first-out pairing still holds, but the list could grow on a long-lived connection — confirm.
    stack = conn.info["query_profiler_info"]
    started_statement, started_parameters, started_at, started_ns = stack.pop()
    add_sql_statement(started_statement, started_parameters, started_at, started_ns, now(), perf_counter_ns())
def are_friends(session, context, other_user):
    """
    Return True if ``context.user_id`` and ``other_user`` have an accepted friend relationship.

    The relationship may have been created in either direction. Both user columns are filtered with
    ``where_users_column_visible`` (presumably excluding users not visible to ``context`` — confirm).
    """
    query = (
        select(FriendRelationship)
        .where_users_column_visible(context, FriendRelationship.from_user_id)
        .where_users_column_visible(context, FriendRelationship.to_user_id)
        .where(
            or_(
                and_(FriendRelationship.from_user_id == context.user_id, FriendRelationship.to_user_id == other_user),
                and_(FriendRelationship.from_user_id == other_user, FriendRelationship.to_user_id == context.user_id),
            )
        )
        .where(FriendRelationship.status == FriendStatus.accepted)
        # Only existence matters. Fetching at most one row also stops scalar_one_or_none() from raising
        # MultipleResultsFound if accepted relationships exist in both directions.
        .limit(1)
    )
    return session.execute(query).scalar_one_or_none() is not None
def get_parent_node_at_location(session, shape):
    """
    Return the smallest Node whose geometry contains ``shape``, or None if no Node contains it.

    ``shape`` can be any PostGIS geo object, e.g. output from ``create_coordinate``.
    """
    # Find the lowest Node (in the Node tree) that contains the shape. By construction of nodes, a sub-node's area is
    # always less than its parent Node's, so sorting by area is enough; there is no need to traverse the tree.
    containing_by_area = select(Node).where(func.ST_Contains(Node.geom, shape)).order_by(func.ST_Area(Node.geom))
    result = session.execute(containing_by_area)
    return result.scalars().first()
def get_node_parents_recursively(session, node_id):
    """
    Get the upward hierarchy of ``node_id`` (the node itself at level 0, then its parent at level 1, etc.).

    Only nodes that have an official cluster are returned. Rows are ordered by level, highest first, so the
    top-most ancestor comes first.

    Returns SQLAlchemy rows of (node_id, parent_node_id, level, cluster).
    """
    # Anchor of the recursive CTE: the starting node at level 0.
    ancestors = (
        select(Node.id, Node.parent_node_id, literal(0).label("level"))
        .where(Node.id == node_id)
        .cte("parents", recursive=True)
    )
    # Recursive step: the parent of every node found so far, one level further up.
    step_up = select(Node.id, Node.parent_node_id, (ancestors.c.level + 1).label("level")).join(
        ancestors, Node.id == ancestors.c.parent_node_id
    )
    hierarchy = select(ancestors.union(step_up)).subquery()

    with_clusters = (
        select(hierarchy, Cluster)
        .join(Cluster, Cluster.parent_node_id == hierarchy.c.id)
        .where(Cluster.is_official_cluster)
        .order_by(hierarchy.c.level.desc())
    )
    return session.execute(with_clusters).all()
171def _can_moderate_any_cluster(session, user_id, cluster_ids):
172 return (
173 session.execute(
174 select(func.count())
175 .select_from(ClusterSubscription)
176 .where(ClusterSubscription.role == ClusterRole.admin)
177 .where(ClusterSubscription.user_id == user_id)
178 .where(ClusterSubscription.cluster_id.in_(cluster_ids))
179 ).scalar_one()
180 > 0
181 )
def can_moderate_at(session, user_id, shape):
    """
    Return True if ``user_id`` can moderate the given geo-shape.

    That is the case when the user is an admin of the official cluster of any Node whose geometry contains
    ``shape``.
    """
    containing_clusters = (
        select(Cluster.id)
        .join(Node, Node.id == Cluster.parent_node_id)
        .where(Cluster.is_official_cluster)
        .where(func.ST_Contains(Node.geom, shape))
    )
    cluster_ids = session.execute(containing_clusters).scalars().all()
    return _can_moderate_any_cluster(session, user_id, cluster_ids)
def can_moderate_node(session, user_id, node_id):
    """
    Return True if ``user_id`` can moderate the given node.

    That is the case when the user is an admin of the official cluster of the node itself or of any of its
    ancestor nodes.
    """
    hierarchy = get_node_parents_recursively(session, node_id)
    official_cluster_ids = [cluster.id for _node, _parent, _level, cluster in hierarchy]
    return _can_moderate_any_cluster(session, user_id, official_cluster_ids)
def timezone_at_coordinate(session, geom):
    """
    Return the tzid of the TimezoneArea whose geometry contains ``geom``, or None if there is none.

    Raises ``MultipleResultsFound`` (via ``scalar_one_or_none``) if more than one area contains ``geom``.
    """
    # Only the tzid column is selected, so scalar_one_or_none() yields the tzid value itself (or None),
    # not a TimezoneArea object; there is no `.tzid` attribute to read from it.
    tzid = session.execute(
        select(TimezoneArea.tzid).where(func.ST_Contains(TimezoneArea.geom, geom))
    ).scalar_one_or_none()
    # A falsy tzid (e.g. an empty string) is reported as None.
    return tzid if tzid else None