Coverage for app/backend/src/couchers/db.py: 77%

117 statements  

coverage.py v7.13.2, created at 2026-02-03 06:18 +0000

  1  import functools
  2  import inspect
  3  import logging
  4  import os
  5  from collections.abc import Generator, Sequence
  6  from contextlib import contextmanager
  7  from os import getpid
  8  from threading import get_ident
  9
 10  from alembic import command
 11  from alembic.config import Config
 12  from geoalchemy2 import WKBElement
 13  from opentelemetry import trace
 14  from sqlalchemy import Engine, Row, Subquery, create_engine, select, text, true
 15  from sqlalchemy.orm.session import Session
 16  from sqlalchemy.pool import QueuePool
 17  from sqlalchemy.sql import and_, func, literal, or_
 18
 19  from couchers.config import config
 20  from couchers.constants import SERVER_THREADS, WORKER_THREADS
 21  from couchers.context import CouchersContext
 22  from couchers.models import (
 23      Cluster,
 24      ClusterRole,
 25      ClusterSubscription,
 26      FriendRelationship,
 27      FriendStatus,
 28      Geom,
 29      Node,
 30      TimezoneArea,
 31      User,
 32  )
 33  from couchers.sql import where_users_column_visible
 34
 35  logger = logging.getLogger(__name__)
 36
 37  tracer = trace.get_tracer(__name__)
 38
 39
 40  def apply_migrations() -> None:
 41      alembic_dir = os.path.dirname(__file__) + "/../.."
 42      cwd = os.getcwd()
 43      try:
 44          os.chdir(alembic_dir)
 45          alembic_cfg = Config("alembic.ini")
 46          # alembic messes up the logging config by default; this option tells it not to when run at startup like this
 47          alembic_cfg.set_main_option("dont_mess_up_logging", "False")
 48          command.upgrade(alembic_cfg, "head")
 49      finally:
 50          os.chdir(cwd)
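
     # Example (editor's sketch, not part of db.py): apply_migrations() is typically run once at process
     # startup, before any sessions are opened; this startup function and its sanity check are hypothetical.
     def _example_startup() -> None:
         apply_migrations()  # bring the schema up to the alembic "head" revision before serving
         with session_scope() as session:
             session.execute(text("SELECT 1"))  # check that the database is reachable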

 51
 52
 53  @functools.cache
 54  def _get_base_engine() -> Engine:
 55      return create_engine(
 56          config["DATABASE_CONNECTION_STRING"],
 57          # checks that the connections in the pool are alive before using them, which avoids the "server closed the
 58          # connection unexpectedly" errors
 59          pool_pre_ping=True,
 60          # one connection per thread
 61          poolclass=QueuePool,
 62          # main threads + a few extra in case
 63          pool_size=SERVER_THREADS + WORKER_THREADS + 12,
 64      )
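
     # Example (editor's sketch, not part of db.py): thanks to @functools.cache the engine and its QueuePool
     # are created once per process. When debugging connection exhaustion, the pool can be inspected with the
     # standard SQLAlchemy Pool.status() method; the log wording here is illustrative.
     def _example_log_pool_status() -> None:
         logger.debug(f"connection pool status: {_get_base_engine().pool.status()}")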

 65
 66
 67  @contextmanager
 68  def session_scope() -> Generator[Session]:
 69      with tracer.start_as_current_span("session_scope") as rollspan:
 70          with Session(_get_base_engine()) as session:
 71              session.begin()
 72              try:
 73                  if logger.isEnabledFor(logging.DEBUG):  # coverage: 73 ↛ 74, line 73 didn't jump to line 74 because the condition on line 73 was never true
 74                      try:
 75                          frame = inspect.stack()[2]
 76                          filename_line = f"{frame.filename}:{frame.lineno}"
 77                      except Exception as e:
 78                          filename_line = "{unknown file}"
 79                      backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar_one()
 80                      logger.debug(f"SScope: got {backend_pid=} at {filename_line}")
 81                      rollspan.set_attribute("db.backend_pid", backend_pid)
 82                      rollspan.set_attribute("db.filename_line", filename_line)
 83                      rollspan.set_attribute("rpc.thread", get_ident())
 84                      rollspan.set_attribute("rpc.pid", getpid())
 85
 86                  yield session
 87                  session.commit()
 88              except:
 89                  session.rollback()
 90                  raise
 91              finally:
 92                  if logger.isEnabledFor(logging.DEBUG):  # coverage: 92 ↛ 93, line 92 didn't jump to line 93 because the condition on line 92 was never true
 93                      logger.debug(f"SScope: closed {backend_pid=}")
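
     # Example (editor's sketch, not part of db.py): typical use of session_scope(). The commit happens
     # automatically on a clean exit and a rollback on any exception; the query itself is illustrative.
     def _example_count_users() -> int:
         with session_scope() as session:
             return session.execute(select(func.count()).select_from(User)).scalar_one()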

 94
 95
 96  @contextmanager
 97  def worker_repeatable_read_session_scope() -> Generator[Session]:
 98      """
 99      This is a separate session scope that is isolated from the main one, since otherwise we would end up nesting transactions;
100      keeping them separate causes two different connections to be used.
101
102      This operates at the `REPEATABLE READ` isolation level so that we can do a `SELECT ... FOR UPDATE SKIP LOCKED` in the
103      background worker, effectively using postgres as a queueing system.
104      """
105      with tracer.start_as_current_span("worker_session_scope") as rollspan:
106          with Session(_get_base_engine().execution_options(isolation_level="REPEATABLE READ")) as session:
107              session.begin()
108              try:
109                  if logger.isEnabledFor(logging.DEBUG):  # coverage: 109 ↛ 110, line 109 didn't jump to line 110 because the condition on line 109 was never true
110                      try:
111                          frame = inspect.stack()[2]
112                          filename_line = f"{frame.filename}:{frame.lineno}"
113                      except Exception as e:
114                          filename_line = "{unknown file}"
115                      backend_pid = session.execute(text("SELECT pg_backend_pid();")).scalar_one()
116                      logger.debug(f"SScope (worker): got {backend_pid=} at {filename_line}")
117                      rollspan.set_attribute("db.backend_pid", backend_pid)
118                      rollspan.set_attribute("db.filename_line", filename_line)
119                      rollspan.set_attribute("rpc.thread", get_ident())
120                      rollspan.set_attribute("rpc.pid", getpid())
121
122                  yield session
123                  session.commit()
124              except:
125                  session.rollback()
126                  raise
127              finally:
128                  if logger.isEnabledFor(logging.DEBUG):  # coverage: 128 ↛ 129, line 128 didn't jump to line 129 because the condition on line 128 was never true
129                      logger.debug(f"SScope (worker): closed {backend_pid=}")
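
     # Example (editor's sketch, not part of db.py): the queueing pattern the docstring above refers to.
     # `BackgroundJob` and its import path are assumptions (the model is defined elsewhere); the point is
     # combining this scope with `with_for_update(skip_locked=True)` so concurrent workers never claim the
     # same row.
     def _example_claim_one_job() -> None:
         from couchers.models import BackgroundJob  # assumed model, not imported by db.py

         with worker_repeatable_read_session_scope() as session:
             job = session.execute(
                 select(BackgroundJob).limit(1).with_for_update(skip_locked=True)
             ).scalar_one_or_none()
             if job is not None:
                 ...  # process the job; the row stays locked until this transaction commits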

130
131
132  def db_post_fork() -> None:
133      """
134      Fix post-fork issues with sqlalchemy
135      """
136      # see https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
137      _get_base_engine().dispose(close=False)
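
     # Example (editor's sketch, not part of db.py): db_post_fork() belongs in the child process right after a
     # fork, so the child discards the connections it inherited from the parent's pool without closing the
     # parent's sockets (that is what dispose(close=False) does). This fork flow is purely illustrative.
     def _example_fork_worker() -> None:
         if os.fork() == 0:
             db_post_fork()
             with session_scope() as session:
                 session.execute(text("SELECT 1"))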

138
139
140  def are_friends(session: Session, context: CouchersContext, other_user: int) -> bool:
141      query = select(FriendRelationship)
142      query = where_users_column_visible(query, context, FriendRelationship.from_user_id)
143      query = where_users_column_visible(query, context, FriendRelationship.to_user_id)
144      query = query.where(
145          or_(
146              and_(FriendRelationship.from_user_id == context.user_id, FriendRelationship.to_user_id == other_user),
147              and_(FriendRelationship.from_user_id == other_user, FriendRelationship.to_user_id == context.user_id),
148          )
149      ).where(FriendRelationship.status == FriendStatus.accepted)
150      return session.execute(query).scalar_one_or_none() is not None
151
152
153  def get_parent_node_at_location(session: Session, shape: WKBElement) -> Node | None:
154      """
155      Finds the smallest node containing the shape.
156
157      Shape can be any PostGIS geo object, e.g., output from create_coordinate
158      """
159
160      # Find the lowest Node (in the Node tree) that contains the shape. By construction of nodes, the area of a sub-node
161      # must always be less than that of its parent Node, so there is no need to actually traverse the tree!
162      return (
163          session.execute(
164              select(Node).where(func.ST_Contains(Node.geom, shape)).order_by(func.ST_Area(Node.geom)).limit(1)
165          )
166          .scalars()
167          .one_or_none()
168      )
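
     # Example (editor's sketch, not part of db.py): looking up the smallest community containing a point.
     # The import path of create_coordinate is an assumption based on the docstring above.
     def _example_parent_node_for_point(session: Session, lat: float, lng: float) -> Node | None:
         from couchers.utils import create_coordinate  # assumed location of the helper

         return get_parent_node_at_location(session, create_coordinate(lat, lng))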

169
170
171  def _get_node_parents_recursive_cte_subquery(node_id: int) -> Subquery:
172      parents = (
173          select(Node.id, Node.parent_node_id, literal(0).label("level"))
174          .where(Node.id == node_id)
175          .cte("parents", recursive=True)
176      )
177
178      return select(
179          parents.union(
180              select(Node.id, Node.parent_node_id, (parents.c.level + 1).label("level")).join(
181                  parents, Node.id == parents.c.parent_node_id
182              )
183          )
184      ).subquery()
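
     # Example (editor's sketch, not part of db.py): the recursive CTE above walks from the given node up to
     # the root, labelling each ancestor with its distance ("level"). Assuming the Node table is named "nodes",
     # it compiles to SQL roughly like:
     #
     #     WITH RECURSIVE parents(id, parent_node_id, level) AS (
     #         SELECT id, parent_node_id, 0 FROM nodes WHERE id = :node_id
     #       UNION
     #         SELECT nodes.id, nodes.parent_node_id, parents.level + 1
     #         FROM nodes JOIN parents ON nodes.id = parents.parent_node_id
     #     )
     #     SELECT id, parent_node_id, level FROM parents
     #
     # The exact SQL can be checked with standard SQLAlchemy compilation:
     #     from sqlalchemy.dialects import postgresql
     #     print(select(_get_node_parents_recursive_cte_subquery(1)).compile(dialect=postgresql.dialect()))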

185
186
187  def get_node_parents_recursively(session: Session, node_id: int) -> Sequence[Row[tuple[int, int, int, Cluster]]]:
188      subquery = _get_node_parents_recursive_cte_subquery(node_id)
189      return session.execute(
190          select(subquery, Cluster)
191          .join(Cluster, Cluster.parent_node_id == subquery.c.id)
192          .where(Cluster.is_official_cluster)
193          .order_by(subquery.c.level.desc())
194      ).all()
195
196
197  def _can_moderate_any_cluster(session: Session, user_id: int, cluster_ids: list[int]) -> bool:
198      query = select(
199          (
200              select(true())
201              .select_from(ClusterSubscription)
202              .where(ClusterSubscription.role == ClusterRole.admin)
203              .where(ClusterSubscription.user_id == user_id)
204              .where(ClusterSubscription.cluster_id.in_(cluster_ids))
205          ).exists()
206      )
207      return session.execute(query).scalar_one()
208
209
210  def can_moderate_node(session: Session, user_id: int, node_id: int) -> bool:
211      """
212      Returns True if the user can moderate the given node (i.e., if they are an admin of any community that is a parent of the node)
213      """
214      subquery = _get_node_parents_recursive_cte_subquery(node_id)
215      query = select(
216          (
217              select(true())
218              .select_from(ClusterSubscription)
219              .where(ClusterSubscription.role == ClusterRole.admin)
220              .where(ClusterSubscription.user_id == user_id)
221              .join(Cluster, Cluster.id == ClusterSubscription.cluster_id)
222              .where(Cluster.is_official_cluster)
223              .where(Cluster.parent_node_id == subquery.c.id)
224          ).exists()
225      )
226      return session.execute(query).scalar_one()
227
228
229  def can_moderate_at(session: Session, user_id: int, shape: Geom) -> bool:
230      """
231      Returns True if the user can moderate a given geo-shape (i.e., if the shape is contained in any Node that the user is an admin of)
232      """
233      query = select(
234          (
235              select(true())
236              .select_from(ClusterSubscription)
237              .where(ClusterSubscription.role == ClusterRole.admin)
238              .where(ClusterSubscription.user_id == user_id)
239              .join(Cluster, Cluster.id == ClusterSubscription.cluster_id)
240              .join(Node, and_(Cluster.is_official_cluster, Node.id == Cluster.parent_node_id))
241              .where(func.ST_Contains(Node.geom, shape))
242          ).exists()
243      )
244      return session.execute(query).scalar_one()
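
     # Example (editor's sketch, not part of db.py): how the two moderation helpers above are typically used to
     # gate an action; the surrounding permission check and the exception raised are hypothetical.
     def _example_require_moderator(session: Session, user_id: int, node_id: int, shape: Geom) -> None:
         if not can_moderate_node(session, user_id, node_id):
             raise PermissionError("must be an admin of this community or of one of its parents")
         if not can_moderate_at(session, user_id, shape):
             raise PermissionError("must be an admin of a community containing this location")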

245
246
247  def is_user_in_node_geography(session: Session, user_id: int, node_id: int) -> bool:
248      """
249      Returns True if the user's location is geographically contained within the node's boundary.
250      This is used to check whether a user can leave a community: users cannot leave communities
251      that contain their geographic location.
252      """
253      query = select(
254          (
255              select(true())
256              .select_from(User)
257              .join(Node, func.ST_Contains(Node.geom, User.geom))
258              .where(User.id == user_id)
259              .where(Node.id == node_id)
260          ).exists()
261      )
262      return session.execute(query).scalar_one()
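
     # Example (editor's sketch, not part of db.py): the leave-community check the docstring above describes;
     # the exception type used here is hypothetical.
     def _example_leave_community(session: Session, user_id: int, node_id: int) -> None:
         if is_user_in_node_geography(session, user_id, node_id):
             raise ValueError("you cannot leave a community that contains your location")
         ...  # otherwise remove the membership (elided)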

263
264
265  def timezone_at_coordinate(session: Session, geom: WKBElement) -> str | None:
266      tzid = session.execute(
267          select(TimezoneArea.tzid).where(func.ST_Contains(TimezoneArea.geom, geom))
268      ).scalar_one_or_none()
269      return tzid