Coverage for src/couchers/resources.py: 94%
66 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-20 11:53 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-20 11:53 +0000
1import functools
2import json
3import logging
4from pathlib import Path
5from typing import Any, cast
7from sqlalchemy.orm import Session
8from sqlalchemy.sql import delete, text
10from couchers.config import config
11from couchers.db import session_scope
12from couchers.models import Language, Region, TimezoneArea
13from couchers.sql import couchers_select as select
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Directory holding the static resource files (terms of service, icons, badges, region/language lists, timezone SQL).
# It is expressed relative to this source file, two directories up, then "resources".
resources_folder = Path(__file__).parent.joinpath("..", "..", "resources")
@functools.cache
def get_terms_of_service() -> str:
    """
    Get the latest terms of service.

    Reads ``terms_of_service.md`` from the resources folder and returns its full text.
    The file is decoded as UTF-8 explicitly, so the result does not depend on the host
    locale's default encoding. The result is cached for the life of the process.
    """
    return (resources_folder / "terms_of_service.md").read_text(encoding="utf-8")
@functools.cache
def get_icon(name: str) -> str:
    """
    Get an icon SVG by name.

    Reads ``icons/<name>`` from the resources folder as UTF-8 text. Results are cached
    per ``name`` for the life of the process.

    Raises:
        FileNotFoundError: if no icon file with that name exists.
    """
    # NOTE(review): ``name`` is joined into the path unchecked; confirm callers never
    # pass user-controlled values (e.g. containing "..").
    return (resources_folder / "icons" / name).read_text(encoding="utf-8")
@functools.cache
def get_region_dict() -> dict[str, str]:
    """
    Return all regions stored in the database as a ``{code: name}`` mapping.

    The codes are the alpha3 values synced in by ``copy_resources_to_database``.
    The database is queried once; later calls return the cached dict.
    """
    with session_scope() as session:
        rows = session.execute(select(Region)).scalars().all()
        return {row.code: row.name for row in rows}
def region_is_allowed(code: str) -> bool:
    """
    Return True if ``code`` is one of the region codes known to the database.
    """
    allowed_regions = get_region_dict()
    return code in allowed_regions
@functools.cache
def get_language_dict() -> dict[str, str]:
    """
    Return all languages stored in the database as a ``{code: name}`` mapping.

    The database is queried once; later calls return the cached dict.
    """
    with session_scope() as session:
        rows = session.execute(select(Language)).scalars().all()
        return {row.code: row.name for row in rows}
@functools.cache
def get_badge_data() -> dict[str, Any]:
    """
    Load and return the parsed contents of ``badges.json``.

    The document is read once (as UTF-8) and cached. Callers in this module read its
    ``"badges"`` key (a list of badge objects, each with an ``"id"``) and its
    ``"static_badges"`` key. The same dict object is returned on every call, so
    callers should not mutate it.
    """
    with open(resources_folder / "badges.json", "r", encoding="utf-8") as f:
        data = json.load(f)
    return cast(dict[str, Any], data)
@functools.cache
def get_badge_dict() -> dict[str, dict[str, Any]]:
    """
    Return the profile badges from ``badges.json`` keyed by their ``id``: ``{id: badge}``.
    """
    badges = get_badge_data()["badges"]
    return {entry["id"]: entry for entry in badges}
@functools.cache
def get_static_badge_dict() -> dict[str, list[int]]:
    """
    Return the static badges from ``badges.json`` as ``{badge_id: [user_id, ...]}``.
    """
    return cast(dict[str, list[int]], get_badge_data()["static_badges"])
def language_is_allowed(code: str) -> bool:
    """
    Return True if ``code`` is one of the language codes known to the database.
    """
    allowed_languages = get_language_dict()
    return code in allowed_languages
def _read_code_name_pairs(filename: str, code_key: str) -> list[tuple[str, str]]:
    """Read a JSON list from the resources folder as ``(entry[code_key], entry["name"])`` pairs."""
    with open(resources_folder / filename, "r", encoding="utf-8") as f:
        return [(entry[code_key], entry["name"]) for entry in json.load(f)]


def _read_timezone_areas_sql() -> str:
    """Return the timezone-areas SQL, falling back to the fake file only when running in dev."""
    timezone_areas_file = resources_folder / "timezone_areas.sql"

    if not timezone_areas_file.exists():
        if not config["DEV"]:
            raise RuntimeError("Missing timezone_areas.sql and not running in dev")

        timezone_areas_file = resources_folder / "timezone_areas.sql-fake"
        logger.info("Using fake timezone areas")

    return timezone_areas_file.read_text(encoding="utf-8")


def copy_resources_to_database(session: Session) -> None:
    """
    Syncs the source-of-truth data from files into the database. Call this at the end of a migration.

    Foreign key constraints that refer to resource tables need to be set to DEFERRABLE.

    We sync as follows:

    1. Read every resource file (regions.json, languages.json, timezone_areas.sql) up front
    2. Defer all DEFERRABLE constraints until the end of the transaction
    3. Delete every row from the Region, Language and TimezoneArea tables
    4. Re-insert everything from the files

    Deleting and recreating guarantees the data is fully in sync.

    NOTE(review): no explicit table lock is taken here; this relies on the caller's
    (migration) transaction for isolation — confirm that is sufficient.

    Raises:
        RuntimeError: if timezone_areas.sql is missing and config["DEV"] is falsy.
    """
    # Load all inputs before touching the database so a missing or malformed file
    # fails before any rows are deleted.
    regions = _read_code_name_pairs("regions.json", "alpha3")
    languages = _read_code_name_pairs("languages.json", "code")
    tz_sql = _read_timezone_areas_sql()

    # set all constraints marked as DEFERRABLE to be checked at the end of this transaction, not immediately
    session.execute(text("SET CONSTRAINTS ALL DEFERRED"))

    session.execute(delete(Region))
    session.add_all(Region(code=code, name=name) for code, name in regions)

    session.execute(delete(Language))
    session.add_all(Language(code=code, name=name) for code, name in languages)

    session.execute(delete(TimezoneArea))
    session.execute(text(tz_sql))