Coverage for app/backend/src/tests/test_db.py: 93%

185 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-21 09:29 +0000

1import difflib 

2import os 

3import re 

4import subprocess 

5from pathlib import Path 

6from typing import Any 

7 

8import pytest 

9from google.protobuf import empty_pb2 

10from sqlalchemy import select 

11from sqlalchemy.sql import func 

12 

13from couchers.config import config 

14from couchers.db import _get_base_engine, apply_migrations, get_parent_node_at_location, session_scope 

15from couchers.jobs.handlers import DatabaseInconsistencyError, check_database_consistency 

16from couchers.models import User 

17from couchers.utils import ( 

18 is_valid_email, 

19 is_valid_name, 

20 is_valid_user_id, 

21 is_valid_username, 

22 parse_date, 

23) 

24from tests.fixtures.db import ( 

25 create_schema_from_models, 

26 drop_database, 

27 generate_user, 

28 pg_dump_is_available, 

29 populate_testing_resources, 

30) 

31from tests.test_communities import create_1d_point, get_community_id, testing_communities # noqa 

32 

33 

def test_is_valid_user_id() -> None:
    """User ids must be plain digit strings; letters and a leading zero are rejected."""
    assert is_valid_user_id("10")
    for rejected in ("1a", "01"):
        assert not is_valid_user_id(rejected), rejected

38 

39 

def test_is_valid_email() -> None:
    """is_valid_email accepts common address shapes and rejects malformed ones."""
    accepted = (
        "a@b.cc",
        "te.st+email.valid@a.org.au.xx.yy",
        "invalid@yahoo.co.uk",
        "user+tag@example.com",
        "first.last@example.com",
    )
    rejected = (
        "invalid@.yahoo.co.uk",
        "test email@couchers.org",
        ".testemail@couchers.org",
        "testemail@couchersorg",
        "b@xxb....blabla",
        # dot immediately before @ (the original bug)
        "user.@example.com",
        # consecutive dots in local part
        "user..name@example.com",
    )
    for address in accepted:
        assert is_valid_email(address), address
    for address in rejected:
        assert not is_valid_email(address), address

55 

56 

def test_is_valid_username() -> None:
    """Pins username rules: at least two chars, lowercase, no leading digit, no trailing underscore."""
    for username in ("user", "us", "us_er", "us_er1"):
        assert is_valid_username(username), username
    for username in ("us_", "u", "1us", "User"):
        assert not is_valid_username(username), username

66 

67 

def test_is_valid_name() -> None:
    """Pins display-name rules: letters, inner whitespace, apostrophes and hyphens only."""
    valid_names = ("ab", "a b", "O'Connor", "Jean-Luc", "老子")
    invalid_names = (
        # too short
        "a",
        # only whitespace
        # NOTE(review): three of these entries are identical (" "); they may
        # have been intended as different whitespace runs — confirm.
        " ",
        "",
        " ",
        " ",
        # leading/trailing whitespace
        " ab",
        "ab ",
        " ab ",
        # contains characters outside of letters/whitespace/'/-
        "1",
        # too long
        "a" * 101,
    )
    for name in valid_names:
        assert is_valid_name(name), name
    for name in invalid_names:
        assert not is_valid_name(name), repr(name)

91 

92 

def test_parse_date() -> None:
    """parse_date accepts zero-padded YYYY-MM-DD dates and rejects everything else."""
    for text in ("2020-01-01", "1900-01-01", "2099-01-01"):
        assert parse_date(text) is not None, text
    # 2019 is not a leap year, month 22 does not exist, the rest use other layouts
    for text in ("2019-02-29", "2019-22-01", "2020-1-01", "20-01-01", "01-01-2020", "2020/01/01"):
        assert not parse_date(text), text

103 

104 

def test_get_parent_node_at_location(testing_communities):
    """get_parent_node_at_location returns the innermost community node containing a point.

    The trailing comments give each community's 1-D extent in the
    testing_communities fixture.
    """
    with session_scope() as session:
        community_ids = {
            name: get_community_id(session, name)
            for name in (
                "Global",  # 0 to 100
                "Country 1",  # 0 to 50
                "Country 1, Region 1",  # 0 to 10
                "Country 1, Region 1, City 1",  # 0 to 5
                "Country 1, Region 1, City 2",  # 7 to 10
                "Country 1, Region 2",  # 20 to 25
                "Country 1, Region 2, City 1",  # 21 to 23
                "Country 2",  # 52 to 100
                "Country 2, Region 1",  # 52 to 71
                "Country 2, Region 1, City 1",  # 53 to 70
            )
        }
        # NOTE(review): only Global and the Country 1 / Region 1 subtree are
        # probed below; the Region 2 and Country 2 ids are looked up but not
        # asserted on — consider adding points for them.
        expected_parent_at = [
            (1, "Country 1, Region 1, City 1"),
            (3, "Country 1, Region 1, City 1"),
            (6, "Country 1, Region 1"),
            (8, "Country 1, Region 1, City 2"),
            (15, "Country 1"),
            (51, "Global"),
        ]
        for x, community_name in expected_parent_at:
            node = get_parent_node_at_location(session, create_1d_point(x))
            # Explicit None check (instead of `# type: ignore`) so a missing node
            # fails with a clear message rather than an AttributeError.
            assert node is not None, f"no parent node found at {x}"
            assert node.id == community_ids[community_name], f"wrong parent node at {x}: expected {community_name}"

124 

125 

def pg_dump() -> str:
    """Return the schema-only dump (``pg_dump -s``) of the configured database.

    Raises:
        subprocess.CalledProcessError: if pg_dump exits with a non-zero status.
    """
    # Decode as UTF-8 rather than ASCII: schema text (comments, function
    # bodies, defaults) may contain non-ASCII characters, which made the old
    # ASCII decode raise UnicodeDecodeError. Pure-ASCII output decodes the same.
    # NOTE(review): assumes the database/client encoding is UTF8 — confirm.
    completed = subprocess.run(
        ["pg_dump", "-s", config.DATABASE_CONNECTION_STRING],
        stdout=subprocess.PIPE,
        encoding="utf-8",
        check=True,
    )
    return completed.stdout

130 

131 

def sort_pg_dump_output(output: str) -> str:
    """Canonicalize pg_dump text so that ordering differences disappear.

    Every list written as " (" + newline ... newline + ");" has its entries
    (one per line, separated by "," + newline) sorted alphabetically. The
    whole text is then split at each newline followed by "-- " (the object
    headers for tables, functions, indices, ...), and those sections are
    sorted alphabetically as well.
    """
    # Flatten to a single line, with "§" temporarily standing in for every
    # newline, so the regex and the splits below can ignore line structure.
    flat = output.replace("\n", "§")

    def _sorted_list(match: re.Match[str]) -> str:
        # Each entry ends with a comma last on its line; the final one has none.
        entries = sorted(match.group(1).split(",§"))
        return " (§" + ",§".join(entries) + "§);"

    flat = re.sub(r" \(§(.*?)§\);", _sorted_list, flat)

    # The kind of object a section describes doesn't matter for ordering.
    sections = flat.split("§-- ")
    sections.sort()
    flat = "§-- ".join(sections)

    # Turn the stand-in back into real newlines.
    return flat.replace("§", "\n")

152 

153 

def test_sort_pg_dump_output() -> None:
    """sort_pg_dump_output sorts parenthesised lists and the dump's "-- " sections."""
    # Entries of a " (" ... ");" list are sorted; commas follow the new order.
    assert sort_pg_dump_output(" (\nb,\nc,\na\n);\n") == " (\na,\nb,\nc\n);\n"
    # Sections that start after a newline with "-- " are sorted as whole chunks.
    assert sort_pg_dump_output("x\n-- b\n-- a") == "a\n-- b\n-- x"
    # Both normalizations applied to one CREATE TABLE statement.
    assert (
        sort_pg_dump_output("-- T\nCREATE TABLE t (\nz int,\na int\n);\n")
        == "-- T\nCREATE TABLE t (\na int,\nz int\n);\n"
    )

156 

157 

def strip_leading_whitespace(lines: list[str]) -> list[str]:
    """Return a new list holding each of *lines* with its leading whitespace removed."""
    return list(map(str.lstrip, lines))

160 

161 

@pytest.fixture
def restore_db_after_migration_test(db):
    """Put the test database back in a usable state after a test that destroyed it.

    The teardown sits in a ``finally`` so it runs however the test exits.
    """
    try:
        yield
    finally:
        base_engine = _get_base_engine()
        # The PostGIS extension was dropped/recreated, which invalidates the
        # operator OIDs cached by pooled connections: discard the whole pool.
        base_engine.dispose()

        # setup_testdb is session-scoped and won't run again, so reload the
        # test resources here on a fresh connection.
        with base_engine.connect() as connection:
            populate_testing_resources(connection)
            connection.commit()

177 

178 

@pytest.mark.skipif(not pg_dump_is_available(), reason="Can't run migration tests without pg_dump")
def test_migrations(db, testconfig: dict[str, Any], restore_db_after_migration_test) -> None:
    """
    Compares the database schema built up from migrations with the
    schema built by models.py. Both scenarios are started from an
    empty database and dumped with pg_dump. Any unexplainable
    differences in the output are reported in unified diff format and
    fail the test.

    Note: this takes about 2 minutes in CI, because the real timezone_areas.sql file
    is used, and it's big. Locally, timezone_areas.sql-fake is used.
    """
    # `db` and `testconfig` are requested only for their fixture setup;
    # `restore_db_after_migration_test` repopulates the database afterwards.
    drop_database()
    # rebuild it with alembic migrations
    apply_migrations()
    with_migrations = pg_dump()

    drop_database()
    # create everything from the current models, not incrementally
    # through migrations
    create_schema_from_models()
    from_scratch = pg_dump()

    # Save the raw schemas to files for CI artifacts
    schema_output_dir = os.environ.get("TEST_SCHEMA_OUTPUT_DIR")
    if schema_output_dir:
        output_path = Path(schema_output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        # Explicit UTF-8 so writing the artifacts never depends on the runner's locale.
        (output_path / "schema_from_migrations.sql").write_text(with_migrations, encoding="utf-8")
        (output_path / "schema_from_models.sql").write_text(from_scratch, encoding="utf-8")

    def normalize(dump: str) -> list[str]:
        """Reduce a pg_dump to comparable lines: sorted, without alembic tables, unindented."""
        dump = sort_pg_dump_output(dump)

        # filter out alembic's own tables
        dump = "\n-- ".join(section for section in dump.split("\n-- ") if not section.startswith("Name: alembic_"))

        # filter out \restrict and \unrestrict lines (Postgres 16+)
        dump = "\n".join(line for line in dump.splitlines() if not line.startswith(("\\restrict", "\\unrestrict")))

        return strip_leading_whitespace(dump.splitlines())

    # The normalized lines carry no trailing newline, so lineterm="" keeps the
    # diff's ---/+++/@@ header lines consistent with the content lines.
    diff = "\n".join(
        difflib.unified_diff(
            normalize(with_migrations),
            normalize(from_scratch),
            fromfile="migrations",
            tofile="model",
            lineterm="",
        )
    )
    print(diff)
    assert not diff, f"Schema built from migrations differs from schema built from models:\n{diff}"

231 

232 

def test_slugify(db):
    """The database's slugify() function lowercases text, strips accents and joins words with dashes."""
    cases = [
        ("this is a test", "this-is-a-test"),
        ("this is ä test", "this-is-a-test"),
        # nothing here gets converted to ASCII by unaccent, so the slug would be
        # empty and slugify falls back to "slug"
        ("Создай группу своего города", "slug"),
        ("Detta är ett test!", "detta-ar-ett-test"),
        # punctuation only: also falls back to "slug"
        ("@#(*$&!@#", "slug"),
        (
            "This has a lot ‒ at least relatively speaking ‒ of punctuation! :)",
            "this-has-a-lot-at-least-relatively-speaking-of-punctuation",
        ),
        ("Multiple - #@! - non-ascii chars", "multiple-non-ascii-chars"),
        ("123", "123"),
        # long input is shortened without leaving a trailing dash
        (
            "A sentence that is over 64 chars long and where the last thing would be replaced by a dash",
            "a-sentence-that-is-over-64-chars-long-and-where-the-last-thing",
        ),
    ]
    with session_scope() as session:
        for text, expected in cases:
            actual = session.execute(func.slugify(text)).scalar_one()
            assert actual == expected, f"slugify({text!r}) returned {actual!r}, expected {expected!r}"

259 

260 

def test_database_consistency_check(db, testconfig: dict[str, Any]) -> None:
    """The database consistency check should pass with valid user/gallery setup"""
    # Generating a user also creates its profile gallery.
    for _ in range(3):
        generate_user()

    # Healthy database: the check must not raise.
    check_database_consistency(empty_pb2.Empty())

    # Break the invariant by detaching one non-deleted user's profile gallery.
    with session_scope() as session:
        live_user_query = select(User).where(User.deleted_at.is_(None)).limit(1)
        broken_user = session.execute(live_user_query).scalar_one()
        broken_user.profile_gallery_id = None

    # The check must now detect the inconsistency.
    with pytest.raises(DatabaseInconsistencyError):
        check_database_consistency(empty_pb2.Empty())

279 

280 

281def test_migration_ordinals() -> None: 

282 """ 

283 Validates that all migration files use ordinal revision IDs and form a 

284 linear chain. Each migration NNNN must have: 

285 - revision = "NNNN" 

286 - down_revision = "NNNN-1" (or None for 0001) 

287 - filename starting with NNNN_ 

288 """ 

289 versions_dir = Path(__file__).parent.parent / "couchers" / "migrations" / "versions" 

290 

291 migration_files = sorted(f for f in versions_dir.glob("*.py") if re.match(r"^\d{4}_", f.name)) 

292 assert len(migration_files) > 0, f"No migration files found in {versions_dir}" 

293 

294 errors = [] 

295 prev_ordinal = None 

296 

297 for path in migration_files: 

298 filename_match = re.match(r"^(\d{4})_", path.name) 

299 assert filename_match, f"Migration filename does not start with ordinal: {path.name}" 

300 file_ordinal = filename_match.group(1) 

301 

302 content = path.read_text() 

303 

304 rev_match = re.search(r'^revision\s*=\s*"([^"]+)"', content, re.MULTILINE) 

305 down_match = re.search(r"^down_revision\s*=\s*(None|\"([^\"]+)\")", content, re.MULTILINE) 

306 

307 if not rev_match: 307 ↛ 308line 307 didn't jump to line 308 because the condition on line 307 was never true

308 errors.append(f"{path.name}: missing 'revision' variable") 

309 continue 

310 if not down_match: 310 ↛ 311line 310 didn't jump to line 311 because the condition on line 310 was never true

311 errors.append(f"{path.name}: missing 'down_revision' variable") 

312 continue 

313 

314 revision = rev_match.group(1) 

315 down_revision = down_match.group(2) # None if down_revision = None 

316 

317 if revision != file_ordinal: 317 ↛ 318line 317 didn't jump to line 318 because the condition on line 317 was never true

318 errors.append(f'{path.name}: revision = "{revision}" does not match filename ordinal "{file_ordinal}"') 

319 

320 if file_ordinal == "0001": 

321 if down_revision is not None: 321 ↛ 322line 321 didn't jump to line 322 because the condition on line 321 was never true

322 errors.append(f'{path.name}: first migration must have down_revision = None, got "{down_revision}"') 

323 else: 

324 expected_down = f"{int(file_ordinal) - 1:04d}" 

325 if down_revision != expected_down: 325 ↛ 326line 325 didn't jump to line 326 because the condition on line 325 was never true

326 errors.append(f'{path.name}: down_revision = "{down_revision}" but expected "{expected_down}"') 

327 

328 # Check for gaps in the sequence 

329 expected_ordinal = f"{int(prev_ordinal) + 1:04d}" if prev_ordinal else "0001" 

330 if file_ordinal != expected_ordinal: 330 ↛ 331line 330 didn't jump to line 331 because the condition on line 330 was never true

331 errors.append(f"{path.name}: expected ordinal {expected_ordinal}, got {file_ordinal} (gap in sequence)") 

332 

333 prev_ordinal = file_ordinal 

334 

335 assert not errors, "Migration ordinal errors:\n" + "\n".join(f" - {e}" for e in errors)