Coverage for src/couchers/utils.py: 93%

179 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-20 11:53 +0000

1import http.cookies 

2import re 

3import typing 

4from collections.abc import Sequence 

5from datetime import date, datetime, timedelta 

6from email.utils import formatdate 

7from types import SimpleNamespace 

8from typing import Any 

9from zoneinfo import ZoneInfo 

10 

11import pytz 

12from geoalchemy2 import WKBElement, WKTElement 

13from geoalchemy2.shape import from_shape, to_shape 

14from google.protobuf.duration_pb2 import Duration 

15from google.protobuf.timestamp_pb2 import Timestamp 

16from shapely.geometry import Point, Polygon, shape 

17from sqlalchemy import Function, cast 

18from sqlalchemy.sql import func 

19from sqlalchemy.types import DateTime 

20 

21from couchers.config import config 

22from couchers.constants import EMAIL_REGEX, PREFERRED_LANGUAGE_COOKIE_EXPIRY 

23from couchers.crypto import decrypt_page_token, encrypt_page_token 

24 

# Shared UTC tzinfo object; the helpers below pass it wherever a timezone-aware UTC datetime is needed
utc = pytz.UTC

26 

27 

28# When a user logs in, they can basically input one of three things: user id, username, or email 

29# These are three non-intersecting sets 

30# * user_ids are numeric representations in base 10 

31# * usernames are alphanumeric + underscores, at least 2 chars long, and don't start with a number, and don't start or end with underscore 

32# * emails are just whatever stack overflow says emails are ;) 

33 

34 

def is_valid_user_id(field: str) -> bool:
    """
    Checks if the whole string is a base 10 integer (ASCII digits only) not starting with 0

    The entire string has to match. We use `re.fullmatch` instead of `re.match` with a trailing `$`, because `$`
    also matches right before a final newline, which would let a user id followed by a newline slip through.
    """
    return re.fullmatch(r"[1-9][0-9]*", field) is not None

40 

41 

def is_valid_username(field: str) -> bool:
    """
    Checks if it's an alphanumeric + underscore, lowercase string, at least
    two characters long, and starts with a letter, ends with alphanumeric

    The entire string has to match. We use `re.fullmatch` instead of `re.match` with a trailing `$`, because `$`
    also matches right before a final newline, which would let a username followed by a newline slip through.
    """
    return re.fullmatch(r"[a-z][0-9a-z_]*[a-z0-9]", field) is not None

48 

49 

def is_valid_name(field: str) -> bool:
    """
    Checks if it has at least one non-whitespace character

    The non-whitespace character may appear anywhere in the string, so we search the whole string rather than
    only matching at its start (which would wrongly reject names that merely begin with whitespace).
    """
    return re.search(r"\S", field) is not None

55 

56 

def is_valid_email(field: str) -> bool:
    """
    Checks if the string matches EMAIL_REGEX, with the match anchored at the start of the string
    """
    matched = re.match(EMAIL_REGEX, field)
    return matched is not None

59 

60 

def Timestamp_from_datetime(dt: datetime) -> Timestamp:
    """
    Builds a new protobuf Timestamp message populated from the given datetime
    """
    timestamp = Timestamp()
    timestamp.FromDatetime(dt)
    return timestamp

65 

66 

def Duration_from_timedelta(dt: timedelta) -> Duration:
    """
    Builds a new protobuf Duration message populated from the given timedelta
    """
    duration = Duration()
    duration.FromTimedelta(dt)
    return duration

71 

72 

73def parse_date(date_str: str) -> date | None: 

74 """ 

75 Parses a date-only string in the format "YYYY-MM-DD" returning None if it fails 

76 """ 

77 try: 

78 return date.fromisoformat(date_str) 

79 except ValueError: 

80 return None 

81 

82 

def date_to_api(date_obj: date) -> str:
    """
    Formats a date as its ISO string ("YYYY-MM-DD"), the inverse of parse_date
    """
    iso_text = date_obj.isoformat()
    return iso_text

85 

86 

def to_aware_datetime(ts: Timestamp) -> datetime:
    """
    Converts a protobuf Timestamp message into a datetime that carries the UTC timezone
    """
    aware = ts.ToDatetime(tzinfo=utc)
    return aware

92 

93 

def now() -> datetime:
    """
    Current time as a timezone-aware datetime in UTC
    """
    return datetime.now(tz=utc)

96 

97 

def minimum_allowed_birthdate() -> date:
    """
    Most recent birthdate allowed to register (must be 18 years minimum)

    Uses exact calendar arithmetic: the result is today's (UTC) date with the year moved back by 18. Subtracting
    a fixed number of days (e.g. 365.25 * 18) is off by one day whenever the 18-year span contains five leap
    days, which would admit someone one day short of their 18th birthday.

    This works on leap days too: if today is Feb 29, the cutoff year is never a leap year, so we fall back to
    Feb 28 (somebody born on Mar 1 of the cutoff year is not yet 18).
    """
    current = today()
    cutoff_year = current.year - 18
    try:
        return current.replace(year=cutoff_year)
    except ValueError:
        # only reachable for Feb 29, since that's the only day that doesn't exist in every year
        return current.replace(year=cutoff_year, day=28)

105 

106 

def today() -> date:
    """
    The current calendar date according to UTC
    """
    current = now()
    return current.date()

112 

113 

def now_in_timezone(tz: str) -> datetime:
    """
    Current time as a timezone-aware datetime in the given timezone

    tz should be tzdata identifier, e.g. America/New_York
    """
    zone = pytz.timezone(tz)
    return datetime.now(zone)

119 

120 

def today_in_timezone(tz: str) -> date:
    """
    The current calendar date in the given timezone

    tz should be tzdata identifier, e.g. America/New_York
    """
    local_now = now_in_timezone(tz)
    return local_now.date()

126 

127 

128# Note: be very careful with ordering of lat/lng! 

129# In a lot of cases they come as (lng, lat), but us humans tend to use them from GPS as (lat, lng)... 

130# When entering as EPSG4326, we also need it in (lng, lat) 

131 

132 

def wrap_coordinate(lat: int, lng: int) -> tuple[int, int]:
    """
    Wraps (lat, lng) point in the EPSG4326 format

    A point that already satisfies -90 <= lat <= 90 and -180 <= lng <= 180 is returned untouched. Otherwise the
    longitude is wrapped around into that range and the latitude is reflected back at the poles. Only
    comparisons, negation, addition/subtraction and modulo are used, so fractional degrees work as well.
    """

    def _reduce(value: int, period: int = 360) -> int:
        # strip whole turns from values further than one full period away from zero
        if value > period:
            value %= period
        if value < -period:
            value %= -period
        return value

    def _shift(value: int, limit: int, period: int) -> int:
        # move a value lying beyond +/- limit by one period (longitude wrap-around)
        if value > limit:
            value = value - period
        if value < -limit:
            value = value + period
        return value

    def _reflect(value: int, limit: int, offset: int) -> int:
        # mirror a value lying beyond +/- limit back towards zero (latitude going over a pole)
        if value > limit:
            value = -value + offset
        if value < -limit:
            value = -value - offset
        return value

    out_of_range = lng < -180 or lng > 180 or lat < -90 or lat > 90
    if not out_of_range:
        return lat, lng

    lng = _shift(_reduce(lng), 180, 360)
    # first fold back anything beyond +/- 180, then anything still beyond +/- 90
    lat = _reflect(_reflect(_reduce(lat), 180, 180), 90, 180)

    # normalize the antimeridian to its positive representation
    if lng == -180:
        lng = 180
    if lng == -360:
        lng = 0

    return lat, lng

171 

172 

def create_coordinate(lat: int, lng: int) -> WKBElement:
    """
    Creates a WKB point element with SRID 4326 (normal GPS-coordinates) from a (lat, lng) pair

    The pair is passed through wrap_coordinate first, and the point is built in (lng, lat) order.
    """
    wrapped_lat, wrapped_lng = wrap_coordinate(lat, lng)
    point = Point(wrapped_lng, wrapped_lat)
    return from_shape(point, srid=4326)

179 

180 

def create_polygon_lat_lng(points: list[list[int]]) -> WKBElement:
    """
    Creates a WKB polygon element with SRID 4326 from a list of (lat, lng) pairs

    Each pair is swapped into (lng, lat) order before the polygon is built.
    """
    ring = [(lng, lat) for lat, lng in points]
    return from_shape(Polygon(ring), srid=4326)

186 

187 

def create_polygon_lng_lat(points: list[list[int]]) -> WKBElement:
    """
    Creates a WKB polygon element with SRID 4326 from a list of (lng, lat) pairs, used in the order given
    """
    polygon = Polygon(points)
    return from_shape(polygon, srid=4326)

193 

194 

def geojson_to_geom(geojson: dict[str, Any]) -> WKBElement:
    """
    Converts a GeoJSON geometry dict into PostGIS geom data with SRID 4326
    """
    geometry = shape(geojson)
    return from_shape(geometry, srid=4326)

200 

201 

def to_multi(polygon: WKBElement) -> Function[Any]:
    """
    Wraps the geometry in a SQL ST_Multi(...) function call expression
    """
    multi_expr = func.ST_Multi(polygon)
    return multi_expr

204 

205 

def get_coordinates(geom: WKBElement | WKTElement | None) -> tuple[int, int] | None:
    """
    Extracts the (lat, lng) pair from a point geom, or returns None if the input is not truthy
    """
    if not geom:
        return None

    point = to_shape(geom)
    # the shape stores (x, y) = (lng, lat), so swap to get (lat, lng)
    return (point.y, point.x)

216 

217 

218def http_date(dt: datetime | None = None) -> str: 

219 """ 

220 Format the datetime for HTTP cookies 

221 """ 

222 if not dt: 

223 dt = now() 

224 return formatdate(dt.timestamp(), usegmt=True) 

225 

226 

def _create_tasty_cookie(name: str, value: Any, expiry: datetime, httponly: bool) -> str:
    """
    Builds one cookie and returns it rendered as a string

    The value is stringified with str(). Domain and the samesite/secure policy are taken from config.
    """
    text = str(value)
    morsel: http.cookies.Morsel[str] = http.cookies.Morsel()
    morsel.set(name, text, text)

    # after this time the browser stops sending the cookie
    morsel["expires"] = http_date(expiry)
    # restrict to our domain, note if there's no domain, it won't include subdomains
    morsel["domain"] = config["COOKIE_DOMAIN"]
    # root path so every API request gets the cookie, otherwise defaults to something like /org.couchers.auth/
    morsel["path"] = "/"

    if not config["DEV"]:
        # production: send on all requests, which requires Secure, so the cookie is HTTPS-only
        morsel["samesite"] = "None"
        morsel["secure"] = True
    else:
        # development: send only on requests from first-party domains
        morsel["samesite"] = "Strict"

    # when set, the cookie is not accessible from javascript
    morsel["httponly"] = httponly

    return morsel.OutputString()

248 

249 

def create_session_cookies(token: str, user_id: int, expiry: datetime) -> list[str]:
    """
    Builds the pair of cookies that make up a session, both with the same expiry.

    The first is the secure session token (couchers-sesh), which is hidden from javascript. The second is the user
    id (couchers-user-id), which the javascript frontend can read so that it knows when it's logged in/out.
    """
    session_cookie = _create_tasty_cookie("couchers-sesh", token, expiry, httponly=True)
    user_id_cookie = _create_tasty_cookie("couchers-user-id", user_id, expiry, httponly=False)
    return [session_cookie, user_id_cookie]

260 

261 

def create_lang_cookie(lang: str) -> list[str]:
    """
    Builds the NEXT_LOCALE language cookie (readable from javascript), returned as a one-element list

    The cookie expires PREFERRED_LANGUAGE_COOKIE_EXPIRY from now.
    """
    expires_at = now() + PREFERRED_LANGUAGE_COOKIE_EXPIRY
    lang_cookie = _create_tasty_cookie("NEXT_LOCALE", lang, expiry=expires_at, httponly=False)
    return [lang_cookie]

266 

267 

268def parse_session_cookie(headers: dict[str, str | bytes]) -> str | None: 

269 """ 

270 Returns our session cookie value (aka token) or None 

271 """ 

272 if "cookie" not in headers: 

273 return None 

274 

275 cookie_str = typing.cast(str, headers["cookie"]) 

276 

277 # parse the cookie 

278 cookie = http.cookies.SimpleCookie(cookie_str).get("couchers-sesh") 

279 

280 if not cookie: 

281 return None 

282 

283 return cookie.value 

284 

285 

286def parse_user_id_cookie(headers: dict[str, str | bytes]) -> str | None: 

287 """ 

288 Returns our session cookie value (aka token) or None 

289 """ 

290 if "cookie" not in headers: 

291 return None 

292 

293 cookie_str = typing.cast(str, headers["cookie"]) 

294 

295 # parse the cookie 

296 cookie = http.cookies.SimpleCookie(cookie_str).get("couchers-user-id") 

297 

298 if not cookie: 

299 return None 

300 

301 return cookie.value 

302 

303 

304def parse_ui_lang_cookie(headers: dict[str, str | bytes]) -> str | None: 

305 """ 

306 Returns language cookie or None 

307 """ 

308 if "cookie" not in headers: 

309 return None 

310 

311 cookie_str = typing.cast(str, headers["cookie"]) 

312 

313 # else parse the cookie & return its value 

314 cookie = http.cookies.SimpleCookie(cookie_str).get("NEXT_LOCALE") 

315 

316 if not cookie: 

317 return None 

318 

319 return cookie.value 

320 

321 

322def parse_api_key(headers: dict[str, str | bytes]) -> str | None: 

323 """ 

324 Returns a bearer token (API key) from the `authorization` header, or None if invalid/not present 

325 """ 

326 if "authorization" not in headers: 

327 return None 

328 

329 authorization = headers["authorization"] 

330 if isinstance(authorization, bytes): 

331 authorization = authorization.decode("utf-8") 

332 

333 if not authorization.startswith("Bearer "): 

334 return None 

335 

336 return authorization[7:] 

337 

338 

339def remove_duplicates_retain_order[T](list_: Sequence[T]) -> list[T]: 

340 out = [] 

341 for item in list_: 

342 if item not in out: 

343 out.append(item) 

344 return out 

345 

346 

def date_in_timezone(date_: date, timezone: str) -> Function[Any]:
    """
    Builds a SQL expression giving the timezone-aware timestamp for the start of a naive postgres date (postgres
    doesn't have tzd dates) in the given timezone. The date is cast to a timestamp without timezone and handed to
    the SQL timezone() function. E.g. if postgres is in 'America/New_York',

    SET SESSION TIME ZONE 'America/New_York';

    CREATE TABLE tz_trouble (to_date date, timezone text);

    INSERT INTO tz_trouble(to_date, timezone) VALUES
    ('2021-03-10'::date, 'Australia/Sydney'),
    ('2021-03-20'::date, 'Europe/Berlin'),
    ('2021-04-15'::date, 'America/New_York');

    SELECT timezone(timezone, to_date::timestamp) FROM tz_trouble;

    The result is:

    timezone
    ------------------------
    2021-03-09 08:00:00-05
    2021-03-19 19:00:00-04
    2021-04-15 00:00:00-04
    """
    naive_timestamp = cast(date_, DateTime(timezone=False))
    return func.timezone(timezone, naive_timestamp)

372 

373 

def millis_from_dt(dt: datetime) -> int:
    """
    Converts a datetime to milliseconds since the unix epoch, rounded to the nearest integer
    """
    seconds = dt.timestamp()
    return round(1000 * seconds)

376 

377 

def dt_from_millis(millis: int) -> datetime:
    """
    Converts milliseconds since the unix epoch to a timezone-aware datetime in UTC
    """
    seconds = millis / 1000
    return datetime.fromtimestamp(seconds, tz=utc)

380 

381 

def dt_to_page_token(dt: datetime) -> str:
    """
    Encodes a datetime as an encrypted page token holding its unix timestamp in whole microseconds

    Python has datetime resolution equal to 1 micro, as does postgres

    We pray to deities that this never changes
    """
    # sanity check that microseconds really are the finest resolution of a datetime
    assert datetime.resolution == timedelta(microseconds=1)
    micros = round(1_000_000 * dt.timestamp())
    return encrypt_page_token(str(micros))

390 

391 

def dt_from_page_token(page_token: str) -> datetime:
    """
    Decrypts a page token holding a unix timestamp in microseconds and returns it as a datetime in UTC
    """
    # the token stores an integer number of microseconds, matching datetime resolution
    micros = int(decrypt_page_token(page_token))
    return datetime.fromtimestamp(micros / 1_000_000, tz=utc)

395 

396 

def last_active_coarsen(dt: datetime) -> datetime:
    """
    Rounds a "last active" time down to the start of its hour, the accuracy we use for last active times

    E.g. 27th June 2021, 16:53 UTC becomes 27th June 2021, 16:00 UTC.
    """
    start_of_hour = dt.replace(minute=0, second=0, microsecond=0)
    return start_of_hour

402 

403 

def get_tz_as_text(tz_name: str) -> str:
    """
    Describes a timezone as "<abbreviation>/UTC<offset>", using the abbreviation and offset in effect right now
    """
    local_now = datetime.now(tz=ZoneInfo(tz_name))
    return local_now.strftime("%Z/UTC%z")

406 

407 

def make_logged_out_context() -> SimpleNamespace:
    """
    Creates a bare context object whose only attribute is user_id, set to 0
    """
    context = SimpleNamespace(user_id=0)
    return context