Coverage for app / backend / src / tests / test_db.py: 92%

177 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-19 14:14 +0000

1import difflib 

2import os 

3import re 

4import subprocess 

5from pathlib import Path 

6from typing import Any 

7 

8import pytest 

9from google.protobuf import empty_pb2 

10from sqlalchemy import select 

11from sqlalchemy.sql import func 

12 

13from couchers.config import config 

14from couchers.db import _get_base_engine, apply_migrations, get_parent_node_at_location, session_scope 

15from couchers.jobs.handlers import DatabaseInconsistencyError, check_database_consistency 

16from couchers.models import User 

17from couchers.utils import ( 

18 is_valid_email, 

19 is_valid_name, 

20 is_valid_user_id, 

21 is_valid_username, 

22 parse_date, 

23) 

24from tests.fixtures.db import ( 

25 create_schema_from_models, 

26 drop_database, 

27 generate_user, 

28 pg_dump_is_available, 

29 populate_testing_resources, 

30) 

31from tests.test_communities import create_1d_point, get_community_id, testing_communities # noqa 

32 

33 

def test_is_valid_user_id() -> None:
    """User IDs must be digit-only strings without a leading zero."""
    # canonical numeric ID
    assert is_valid_user_id("10")
    # non-digit characters and leading zeros are both rejected
    for bad_id in ("1a", "01"):
        assert not is_valid_user_id(bad_id)

38 

39 

def test_is_valid_email() -> None:
    """Spot-checks the email validator against accepted and rejected addresses."""
    accepted = [
        "a@b.cc",
        "te.st+email.valid@a.org.au.xx.yy",
        "invalid@yahoo.co.uk",
        "user+tag@example.com",
        "first.last@example.com",
    ]
    rejected = [
        "invalid@.yahoo.co.uk",
        "test email@couchers.org",
        ".testemail@couchers.org",
        "testemail@couchersorg",
        "b@xxb....blabla",
        # dot immediately before @ (the original bug)
        "user.@example.com",
        # consecutive dots in local part
        "user..name@example.com",
    ]
    for address in accepted:
        assert is_valid_email(address)
    for address in rejected:
        assert not is_valid_email(address)

55 

56 

def test_is_valid_username() -> None:
    """Checks username validation rules as exercised by the cases below."""
    # valid: lowercase, at least two chars, may contain internal underscores/digits
    for good_name in ("user", "us", "us_er", "us_er1"):
        assert is_valid_username(good_name)
    # invalid: trailing underscore, too short, leading digit, uppercase
    for bad_name in ("us_", "u", "1us", "User"):
        assert not is_valid_username(bad_name)

66 

67 

def test_is_valid_name() -> None:
    """Names accept any non-empty, non-whitespace-only content, including digits and CJK."""
    assert is_valid_name("a")
    assert is_valid_name("a b")
    assert is_valid_name("1")
    assert is_valid_name("老子")
    assert not is_valid_name(" ")
    assert not is_valid_name("")
    # NOTE(review): this looks identical to the whitespace-only assert two lines
    # up; it may have been intended to test a non-ASCII space (e.g. U+00A0) —
    # confirm against the original source before deduplicating
    assert not is_valid_name(" ")

76 

77 

def test_parse_date() -> None:
    """parse_date accepts strict zero-padded ISO dates and rejects everything else."""
    for iso_date in ("2020-01-01", "1900-01-01", "2099-01-01"):
        assert parse_date(iso_date) is not None
    malformed = (
        "2019-02-29",  # Feb 29 in a non-leap year
        "2019-22-01",  # month out of range
        "2020-1-01",  # month not zero-padded
        "20-01-01",  # two-digit year
        "01-01-2020",  # day-first ordering
        "2020/01/01",  # wrong separator
    )
    for bad_date in malformed:
        assert not parse_date(bad_date)

88 

89 

def test_get_parent_node_at_location(testing_communities):
    """The most deeply nested community containing a point should be returned."""
    with session_scope() as session:
        # Look up every community id; the trailing comment on each line gives
        # the 1D coordinate span that community covers
        w_id = get_community_id(session, "Global")  # 0 to 100
        c1_id = get_community_id(session, "Country 1")  # 0 to 50
        c1r1_id = get_community_id(session, "Country 1, Region 1")  # 0 to 10
        c1r1c1_id = get_community_id(session, "Country 1, Region 1, City 1")  # 0 to 5
        c1r1c2_id = get_community_id(session, "Country 1, Region 1, City 2")  # 7 to 10
        c1r2_id = get_community_id(session, "Country 1, Region 2")  # 20 to 25
        c1r2c1_id = get_community_id(session, "Country 1, Region 2, City 1")  # 21 to 23
        c2_id = get_community_id(session, "Country 2")  # 52 to 100
        c2r1_id = get_community_id(session, "Country 2, Region 1")  # 52 to 71
        c2r1c1_id = get_community_id(session, "Country 2, Region 1, City 1")  # 53 to 70

        expected_parents = [
            (1, c1r1c1_id),
            (3, c1r1c1_id),
            (6, c1r1_id),  # between the two cities, so falls back to the region
            (8, c1r1c2_id),
            (15, c1_id),  # outside any region, so falls back to the country
            (51, w_id),  # in the gap between countries, so falls back to Global
        ]
        for coordinate, community_id in expected_parents:
            parent = get_parent_node_at_location(session, create_1d_point(coordinate))
            assert parent.id == community_id  # type: ignore[union-attr]

109 

110 

def pg_dump() -> str:
    """Dump the current database schema (schema-only, via `pg_dump -s`) as text."""
    completed = subprocess.run(
        ["pg_dump", "-s", config["DATABASE_CONNECTION_STRING"]],
        stdout=subprocess.PIPE,
        encoding="ascii",
        check=True,
    )
    return completed.stdout

115 

116 

def sort_pg_dump_output(output: str) -> str:
    """Sort the tables, functions and indices dumped by pg_dump alphabetically.

    Also sorts all parenthesised lists (e.g. column definitions) alphabetically,
    so that two dumps of equivalent schemas compare equal line-for-line.
    """
    # Work on a single line: temporarily swap newlines for a sentinel character
    # so regexes don't have to deal with multi-line matching.
    flattened = output.replace("\n", "§")

    # A parameter/column list is enclosed in parentheses, with each entry
    # ending in a comma at the end of its line.
    def _sort_entries(match: "re.Match[str]") -> str:
        entries = sorted(match.group(1).split(",§"))
        return " (§" + ",§".join(entries) + "§);"

    flattened = re.sub(r" \(§(.*?)§\);", _sort_entries, flattened)

    # Every dumped object's header (table, function, index, ...) starts with
    # "-- "; sorting on that prefix orders the objects without caring what
    # kind of object each one is.
    sections = sorted(flattened.split("§-- "))

    # Restore real newlines.
    return "§-- ".join(sections).replace("§", "\n")

137 

138 

def test_sort_pg_dump_output() -> None:
    """A parenthesised list comes back in alphabetical order."""
    unsorted_dump = " (\nb,\nc,\na\n);\n"
    expected = " (\na,\nb,\nc\n);\n"
    assert sort_pg_dump_output(unsorted_dump) == expected

141 

142 

def strip_leading_whitespace(lines: list[str]) -> list[str]:
    """Return a copy of *lines* with leading whitespace removed from each entry."""
    return list(map(str.lstrip, lines))

145 

146 

@pytest.fixture
def restore_db_after_migration_test(db):
    """Fixture that repairs the test database after a migration test destroys it.

    Yields immediately so the test body runs first; the cleanup in the
    ``finally`` block runs afterwards even if the test fails.
    """
    try:
        yield
    finally:
        # Dispose the engine's connection pool since we dropped/recreated PostGIS extension,
        # which invalidates cached operator OIDs in existing connections
        engine = _get_base_engine()
        engine.dispose()

        # Restore test resources since we destroyed the database
        # This is needed because setup_testdb is session-scoped and won't run again
        with engine.connect() as conn:
            populate_testing_resources(conn)
            conn.commit()

162 

163 

@pytest.mark.skipif(not pg_dump_is_available(), reason="Can't run migration tests without pg_dump")
def test_migrations(db, testconfig: dict[str, Any], restore_db_after_migration_test) -> None:
    """
    Compares the database schema built up from migrations with the
    schema built by models.py. Both scenarios are started from an
    empty database and dumped with pg_dump. Any unexplainable
    differences in the output are reported in unified diff format and
    fail the test.

    Note: this takes about 2 minutes in CI, because the real timezone_areas.sql file
    is used, and it's big. Locally, timezone_areas.sql-fake is used.
    """
    drop_database()
    # rebuild it with alembic migrations
    apply_migrations()

    with_migrations = pg_dump()

    drop_database()
    # create everything from the current models, not incrementally
    # through migrations
    create_schema_from_models()

    from_scratch = pg_dump()

    # Save the raw schemas to files for CI artifacts
    schema_output_dir = os.environ.get("TEST_SCHEMA_OUTPUT_DIR")
    if schema_output_dir:
        output_path = Path(schema_output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        (output_path / "schema_from_migrations.sql").write_text(with_migrations)
        (output_path / "schema_from_models.sql").write_text(from_scratch)

    def message(s: str) -> list[str]:
        # Normalize a pg_dump schema so the two dumps compare line-for-line.
        s = sort_pg_dump_output(s)

        # filter out alembic tables
        s = "\n-- ".join(x for x in s.split("\n-- ") if not x.startswith("Name: alembic_"))

        # filter out \restrict and \unrestrict lines (Postgres 16+)
        s = "\n".join(
            line for line in s.splitlines() if not line.startswith("\\restrict") and not line.startswith("\\unrestrict")
        )

        return strip_leading_whitespace(s.splitlines())

    diff = "\n".join(
        difflib.unified_diff(message(with_migrations), message(from_scratch), fromfile="migrations", tofile="model")
    )
    # print the full diff so CI logs show exactly what differs on failure
    print(diff)
    success = diff == ""
    assert success

216 

217 

def test_slugify(db):
    """Exercises the slugify() SQL function through a live database session."""
    # (input, expected slug) pairs; inputs that unaccent cannot reduce to any
    # usable ascii fall back to the literal slug "slug"
    cases = [
        ("this is a test", "this-is-a-test"),
        ("this is ä test", "this-is-a-test"),
        # nothing here gets converted to ascii by unaccent
        ("Создай группу своего города", "slug"),
        ("Detta är ett test!", "detta-ar-ett-test"),
        ("@#(*$&!@#", "slug"),
        (
            "This has a lot ‒ at least relatively speaking ‒ of punctuation! :)",
            "this-has-a-lot-at-least-relatively-speaking-of-punctuation",
        ),
        ("Multiple - #@! - non-ascii chars", "multiple-non-ascii-chars"),
        ("123", "123"),
        # long input gets truncated without leaving a trailing dash
        (
            "A sentence that is over 64 chars long and where the last thing would be replaced by a dash",
            "a-sentence-that-is-over-64-chars-long-and-where-the-last-thing",
        ),
    ]
    with session_scope() as session:
        for text, expected_slug in cases:
            assert session.execute(func.slugify(text)).scalar_one() == expected_slug

244 

245 

def test_database_consistency_check(db, testconfig: dict[str, Any]) -> None:
    """The consistency check passes on a healthy database and fails once broken."""
    # Each generated user automatically gets a profile gallery, so the
    # database starts out consistent
    for _ in range(3):
        generate_user()

    # A consistent database must not raise
    check_database_consistency(empty_pb2.Empty())

    # Break the invariant: detach one live user's profile gallery
    with session_scope() as session:
        victim = session.execute(select(User).where(User.deleted_at.is_(None)).limit(1)).scalar_one()
        victim.profile_gallery_id = None

    # The broken invariant must now be detected
    with pytest.raises(DatabaseInconsistencyError):
        check_database_consistency(empty_pb2.Empty())

264 

265 

def test_migration_ordinals() -> None:
    """
    Validates that all migration files use ordinal revision IDs and form a
    linear chain. Each migration NNNN must have:
    - revision = "NNNN"
    - down_revision = "NNNN-1" (or None for 0001)
    - filename starting with NNNN_
    """
    versions_dir = Path(__file__).parent.parent / "couchers" / "migrations" / "versions"

    # only files named like 0001_*.py count as migrations (skips env.py etc.)
    migration_files = sorted(f for f in versions_dir.glob("*.py") if re.match(r"^\d{4}_", f.name))
    assert len(migration_files) > 0, f"No migration files found in {versions_dir}"

    # collect all problems and report them together at the end, so one run
    # surfaces every broken migration rather than just the first
    errors = []
    prev_ordinal = None

    for path in migration_files:
        filename_match = re.match(r"^(\d{4})_", path.name)
        assert filename_match, f"Migration filename does not start with ordinal: {path.name}"
        file_ordinal = filename_match.group(1)

        content = path.read_text()

        # alembic convention: module-level revision / down_revision assignments
        rev_match = re.search(r'^revision\s*=\s*"([^"]+)"', content, re.MULTILINE)
        down_match = re.search(r"^down_revision\s*=\s*(None|\"([^\"]+)\")", content, re.MULTILINE)

        if not rev_match:
            errors.append(f"{path.name}: missing 'revision' variable")
            continue
        if not down_match:
            errors.append(f"{path.name}: missing 'down_revision' variable")
            continue

        revision = rev_match.group(1)
        down_revision = down_match.group(2)  # None if down_revision = None

        if revision != file_ordinal:
            errors.append(f'{path.name}: revision = "{revision}" does not match filename ordinal "{file_ordinal}"')

        if file_ordinal == "0001":
            if down_revision is not None:
                errors.append(f'{path.name}: first migration must have down_revision = None, got "{down_revision}"')
        else:
            # every later migration must point back at the previous ordinal
            expected_down = f"{int(file_ordinal) - 1:04d}"
            if down_revision != expected_down:
                errors.append(f'{path.name}: down_revision = "{down_revision}" but expected "{expected_down}"')

        # Check for gaps in the sequence
        expected_ordinal = f"{int(prev_ordinal) + 1:04d}" if prev_ordinal else "0001"
        if file_ordinal != expected_ordinal:
            errors.append(f"{path.name}: expected ordinal {expected_ordinal}, got {file_ordinal} (gap in sequence)")

        prev_ordinal = file_ordinal

    assert not errors, "Migration ordinal errors:\n" + "\n".join(f" - {e}" for e in errors)