Coverage for app / backend / src / tests / test_db.py: 99%

139 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-03 06:18 +0000

1import difflib 

2import os 

3import re 

4import subprocess 

5from pathlib import Path 

6from typing import Any 

7 

8import pytest 

9from google.protobuf import empty_pb2 

10from sqlalchemy import select 

11from sqlalchemy.sql import func 

12 

13from couchers.config import config 

14from couchers.db import _get_base_engine, apply_migrations, get_parent_node_at_location, session_scope 

15from couchers.jobs.handlers import DatabaseInconsistencyError, check_database_consistency 

16from couchers.models import User 

17from couchers.utils import ( 

18 is_valid_email, 

19 is_valid_name, 

20 is_valid_user_id, 

21 is_valid_username, 

22 parse_date, 

23) 

24from tests.fixtures.db import ( 

25 create_schema_from_models, 

26 drop_database, 

27 generate_user, 

28 pg_dump_is_available, 

29 populate_testing_resources, 

30) 

31from tests.test_communities import create_1d_point, get_community_id, testing_communities # noqa 

32 

33 

def test_is_valid_user_id() -> None:
    """Plain decimal IDs are valid; non-digit characters and leading zeros are not."""
    assert is_valid_user_id("10")
    for bad_id in ("1a", "01"):
        assert not is_valid_user_id(bad_id), bad_id

38 

39 

def test_is_valid_email() -> None:
    """Check is_valid_email against a set of accepted and rejected addresses."""
    accepted = (
        "a@b.cc",
        "te.st+email.valid@a.org.au.xx.yy",
        "invalid@yahoo.co.uk",
    )
    rejected = (
        "invalid@.yahoo.co.uk",
        "test email@couchers.org",
        ".testemail@couchers.org",
        "testemail@couchersorg",
        "b@xxb....blabla",
    )
    for address in accepted:
        assert is_valid_email(address), address
    for address in rejected:
        assert not is_valid_email(address), address

49 

50 

def test_is_valid_username() -> None:
    """Usernames: lowercase, at least two characters, no leading digit, no trailing underscore."""
    accepted = ("user", "us", "us_er", "us_er1")
    rejected = ("us_", "u", "1us", "User")
    for username in accepted:
        assert is_valid_username(username), username
    for username in rejected:
        assert not is_valid_username(username), username

60 

61 

def test_is_valid_name() -> None:
    """Check is_valid_name on accepted and rejected display names.

    Single characters, digits, inner spaces and non-Latin scripts are accepted.
    Empty and whitespace-only names are rejected.
    """
    for name in ("a", "a b", "1", "老子"):
        assert is_valid_name(name), name
    # The single-space case was previously asserted twice; cover a multi-space
    # whitespace-only name as the second case instead.
    for name in ("", " ", "   "):
        assert not is_valid_name(name), repr(name)

70 

71 

def test_parse_date() -> None:
    """parse_date accepts zero-padded ISO YYYY-MM-DD dates and rejects everything else."""
    for good in ("2020-01-01", "1900-01-01", "2099-01-01"):
        assert parse_date(good) is not None, good

    # A non-existent leap day, an out-of-range month, unpadded/short fields,
    # day-first ordering and a wrong separator must all be rejected.
    for bad in ("2019-02-29", "2019-22-01", "2020-1-01", "20-01-01", "01-01-2020", "2020/01/01"):
        assert not parse_date(bad), bad

82 

83 

def test_get_parent_node_at_location(testing_communities) -> None:
    """get_parent_node_at_location returns the innermost community containing a point.

    The ``testing_communities`` fixture lays communities out along one axis. The
    trailing comments below give each community's extent on that axis.
    """
    with session_scope() as session:
        w_id = get_community_id(session, "Global")  # 0 to 100
        c1_id = get_community_id(session, "Country 1")  # 0 to 50
        c1r1_id = get_community_id(session, "Country 1, Region 1")  # 0 to 10
        c1r1c1_id = get_community_id(session, "Country 1, Region 1, City 1")  # 0 to 5
        c1r1c2_id = get_community_id(session, "Country 1, Region 1, City 2")  # 7 to 10
        c1r2_id = get_community_id(session, "Country 1, Region 2")  # 20 to 25
        c1r2c1_id = get_community_id(session, "Country 1, Region 2, City 1")  # 21 to 23
        c2_id = get_community_id(session, "Country 2")  # 52 to 100
        c2r1_id = get_community_id(session, "Country 2, Region 1")  # 52 to 71  # noqa: F841
        c2r1c1_id = get_community_id(session, "Country 2, Region 1, City 1")  # 53 to 70

        def parent_id(x: int) -> int:
            # The lookup can return None; fail with a clear message instead of an
            # AttributeError (and without needing `type: ignore` on every call).
            node = get_parent_node_at_location(session, create_1d_point(x))
            assert node is not None, f"no community contains point {x}"
            return node.id

        assert parent_id(1) == c1r1c1_id
        assert parent_id(3) == c1r1c1_id
        assert parent_id(6) == c1r1_id
        assert parent_id(8) == c1r1c2_id
        assert parent_id(15) == c1_id
        assert parent_id(51) == w_id

        # Interior points (per the ranges above) for the remaining communities.
        assert parent_id(22) == c1r2c1_id
        assert parent_id(24) == c1r2_id
        assert parent_id(60) == c2r1c1_id
        assert parent_id(80) == c2_id

103 

104 

def pg_dump() -> str:
    """Return the schema-only (``-s``) pg_dump output for the configured database.

    Raises:
        subprocess.CalledProcessError: if pg_dump exits with a non-zero status.
    """
    # Decode as UTF-8 rather than ASCII. The schema dump can contain non-ASCII text
    # (e.g. in function bodies or comments), which the ASCII codec rejects with
    # UnicodeDecodeError. UTF-8 is a superset of ASCII, so pure-ASCII dumps decode
    # identically. Assumes the dump is produced in the UTF8 client encoding — TODO confirm.
    return subprocess.run(
        ["pg_dump", "-s", config["DATABASE_CONNECTION_STRING"]],
        stdout=subprocess.PIPE,
        encoding="utf-8",
        check=True,
    ).stdout

109 

110 

def sort_pg_dump_output(output: str) -> str:
    """Normalize pg_dump output so that ordering differences disappear from a diff.

    First, the entries of every multi-line parenthesized list (opened by " (" at
    the end of a line and closed by ");" on its own line, with entries separated
    by a trailing comma) are sorted alphabetically. Then the object blocks
    (tables, functions, indices, ...), whose headers start with "-- " at the
    beginning of a line, are sorted alphabetically.

    Args:
        output: Raw text produced by pg_dump.

    Returns:
        The normalized text.
    """
    # Match across real newlines with re.DOTALL instead of swapping "\n" for a
    # placeholder character. A placeholder such as "§" would be turned into a
    # newline wherever it legitimately appeared in the dump.
    s = re.sub(
        r" \(\n(.*?)\n\);",
        lambda m: " (\n" + ",\n".join(sorted(m.group(1).split(",\n"))) + "\n);",
        output,
        flags=re.DOTALL,
    )

    # The header of every object (table, function, index, etc.) seems to start
    # with two dashes and a space on a new line. The object kind doesn't matter here.
    return "\n-- ".join(sorted(s.split("\n-- ")))

131 

132 

def test_sort_pg_dump_output() -> None:
    """sort_pg_dump_output sorts multi-line lists and the "-- " object blocks."""
    # Entries of a single multi-line list are sorted.
    assert sort_pg_dump_output(" (\nb,\nc,\na\n);\n") == " (\na,\nb,\nc\n);\n"
    # Each list is sorted independently.
    assert sort_pg_dump_output(" (\nb,\na\n);\n (\nd,\nc\n);") == " (\na,\nb\n);\n (\nc,\nd\n);"
    # Object blocks introduced by "\n-- " are sorted as whole blocks.
    assert sort_pg_dump_output("\n-- b\nx\n-- a\ny") == "\n-- a\ny\n-- b\nx"

135 

136 

def strip_leading_whitespace(lines: list[str]) -> list[str]:
    """Return a new list holding each line with its leading whitespace removed."""
    return list(map(str.lstrip, lines))

139 

140 

@pytest.fixture
def restore_db_after_migration_test(db):
    """Let a destructive migration test run, then rebuild the shared test state.

    Teardown runs even if the test fails: it drops the engine's pooled
    connections and repopulates the testing resources on a fresh connection.
    """
    try:
        yield
    finally:
        base_engine = _get_base_engine()
        # The test dropped and recreated the PostGIS extension, which invalidates
        # operator OIDs cached by pooled connections, so discard the whole pool.
        base_engine.dispose()

        # setup_testdb is session-scoped and will not run again, so the test
        # resources destroyed along with the database are restored here.
        with base_engine.connect() as connection:
            populate_testing_resources(connection)
            connection.commit()

156 

157 

@pytest.mark.skipif(not pg_dump_is_available(), reason="Can't run migration tests without pg_dump")
def test_migrations(db, testconfig: dict[str, Any], restore_db_after_migration_test) -> None:
    """
    Compares the database schema built up from migrations with the
    schema built by models.py. Both scenarios are started from an
    empty database and dumped with pg_dump. Any unexplainable
    differences in the output are reported in unified diff format and
    fail the test.

    Note: this takes about 2 minutes in CI, because the real timezone_areas.sql file
    is used, and it's big. Locally, timezone_areas.sql-fake is used.
    """
    drop_database()
    # rebuild it with alembic migrations
    apply_migrations()

    with_migrations = pg_dump()

    drop_database()
    # create everything from the current models, not incrementally
    # through migrations
    create_schema_from_models()

    from_scratch = pg_dump()

    # Save the raw schemas to files for CI artifacts. The encoding is explicit
    # because write_text otherwise uses the platform's locale encoding.
    schema_output_dir = os.environ.get("TEST_SCHEMA_OUTPUT_DIR")
    if schema_output_dir:
        output_path = Path(schema_output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        (output_path / "schema_from_migrations.sql").write_text(with_migrations, encoding="utf-8")
        (output_path / "schema_from_models.sql").write_text(from_scratch, encoding="utf-8")

    def message(s: str) -> list[str]:
        """Normalize one pg_dump output into a list of comparable lines."""
        s = sort_pg_dump_output(s)

        # filter out alembic tables
        s = "\n-- ".join(x for x in s.split("\n-- ") if not x.startswith("Name: alembic_"))

        # filter out \restrict and \unrestrict lines (Postgres 16+)
        s = "\n".join(line for line in s.splitlines() if not line.startswith(("\\restrict", "\\unrestrict")))

        return strip_leading_whitespace(s.splitlines())

    # The compared lines carry no trailing newline, so use lineterm="" to stop
    # the ---/+++/@@ header lines from gaining an extra "\n" when joined.
    diff = "\n".join(
        difflib.unified_diff(
            message(with_migrations),
            message(from_scratch),
            fromfile="migrations",
            tofile="model",
            lineterm="",
        )
    )
    # unified_diff yields nothing when the inputs are equal. Put the diff in the
    # assertion message so it appears directly in the failure report.
    assert diff == "", "Schema built from migrations differs from schema built from models:\n" + diff

210 

211 

def test_slugify(db):
    """The database-side slugify function produces lowercase, dash-separated slugs."""
    cases = [
        ("this is a test", "this-is-a-test"),
        ("this is ä test", "this-is-a-test"),
        # unaccent converts none of these characters to ASCII, so nothing usable
        # is left and the expected result is the "slug" placeholder
        ("Создай группу своего города", "slug"),
        ("Detta är ett test!", "detta-ar-ett-test"),
        ("@#(*$&!@#", "slug"),
        (
            "This has a lot ‒ at least relatively speaking ‒ of punctuation! :)",
            "this-has-a-lot-at-least-relatively-speaking-of-punctuation",
        ),
        ("Multiple - #@! - non-ascii chars", "multiple-non-ascii-chars"),
        ("123", "123"),
        (
            "A sentence that is over 64 chars long and where the last thing would be replaced by a dash",
            "a-sentence-that-is-over-64-chars-long-and-where-the-last-thing",
        ),
    ]
    with session_scope() as session:
        for text, expected in cases:
            assert session.execute(func.slugify(text)).scalar_one() == expected, text

237 ) 

238 

239 

def test_database_consistency_check(db, testconfig: dict[str, Any]) -> None:
    """check_database_consistency passes on valid data and flags a missing profile gallery."""
    # Generating a user also creates that user's profile gallery.
    for _ in range(3):
        generate_user()

    # With consistent data the check must complete without raising.
    check_database_consistency(empty_pb2.Empty())

    # Break an invariant: detach the profile gallery from one non-deleted user.
    with session_scope() as session:
        live_user = session.execute(select(User).where(User.is_deleted == False).limit(1)).scalar_one()  # noqa: E712
        live_user.profile_gallery_id = None

    # The check must now detect the inconsistency.
    with pytest.raises(DatabaseInconsistencyError):
        check_database_consistency(empty_pb2.Empty())