Coverage for app/backend/src/app.py: 0%

105 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-21 09:29 +0000

1import logging 

2import signal 

3import sys 

4import threading 

5from multiprocessing import Process 

6from os import environ 

7from tempfile import TemporaryDirectory 

8from types import TracebackType 

9 

# These lines need to run at the top of the file, before any child processes are spawned (and,
# presumably, before the imports below that may read PROMETHEUS_MULTIPROC_DIR — hence the E402 noqa).
# Children inherit the directory through the PROMETHEUS_MULTIPROC_DIR environment variable.
# The TemporaryDirectory is removed when the `prometheus_multiproc_dir` object is finalized; since it
# is held in a module-level name, that is at the end of the program.
# Only the main process may do this: processes created via multiprocessing re-run this module as
# "__mp_main__" and must reuse the parent's directory rather than create their own.
if __name__ == "__main__":
    prometheus_multiproc_dir = TemporaryDirectory()
    environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name

16 

17# ruff: noqa: E402 

18 

19import sentry_sdk 

20from sentry_sdk.integrations import excepthook 

21from sqlalchemy.sql import text 

22 

23from couchers.config import config 

24from couchers.constants import API_BASE_PORT, API_WORKER_COUNT, GRACEFUL_SHUTDOWN_TIMEOUT, MEDIA_PORT 

25from couchers.db import apply_migrations, db_post_fork, session_scope 

26from couchers.experimentation import setup_experimentation 

27from couchers.i18n.locales import get_main_i18next 

28from couchers.jobs.worker import start_jobs_scheduler, start_jobs_worker 

29from couchers.metrics import create_prometheus_server 

30from couchers.profiling import setup_profiling 

31from couchers.server import create_main_server, create_media_server 

32from couchers.supervisor import supervise 

33from couchers.tracing import setup_tracing 

34from dummy_data import add_dummy_data 

35 

# Runs at import time, so it happens in the main process and in every multiprocessing child
# (which re-imports this module as "__mp_main__").
config.check()

# Prefix each record with pid and thread id so output from the many worker processes can be told apart.
_LOG_FORMAT = "[%(process)5d:%(thread)20d] %(asctime)s: %(name)s:%(lineno)d: %(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
logger = logging.getLogger(__name__)

42 

43 

def _run_api_server(port: int) -> None:
    """Entry point of an API worker process: serve the main server on `port` until SIGTERM/SIGINT.

    After a shutdown signal, in-flight work is drained for up to GRACEFUL_SHUTDOWN_TIMEOUT seconds
    via ``server.stop(...)``. Any exception is reported through sentry_sdk (and flushed) and then
    re-raised, so the process still dies with the original error.
    """
    try:
        db_post_fork()
        setup_experimentation()
        setup_tracing()
        setup_profiling(role="api", instance=f"api-{port}")

        # Install the handlers *before* the server starts. A SIGTERM arriving between start()
        # and handler installation would otherwise hit the default disposition and kill the
        # worker outright, skipping the graceful drain below. A signal that lands during server
        # construction just sets the event, so wait() returns immediately and we drain and stop.
        terminate = threading.Event()
        signal.signal(signal.SIGTERM, lambda *_: terminate.set())
        signal.signal(signal.SIGINT, lambda *_: terminate.set())

        server = create_main_server(port=port, start_resource_sampler=True)
        server.start()
        logger.info(f"API worker serving on {port}")

        terminate.wait()

        logger.info(f"API worker on {port} draining (up to {GRACEFUL_SHUTDOWN_TIMEOUT}s)")
        server.stop(GRACEFUL_SHUTDOWN_TIMEOUT).wait()
    except Exception:
        # multiprocessing would only print this to stderr; send the traceback to Sentry (and flush, since
        # the process is about to die and the parent will restart the container) before re-raising
        sentry_sdk.capture_exception()
        sentry_sdk.flush()
        raise

68 

69 

def start_api_worker(port: int) -> Process:
    """Launch a child process that runs `_run_api_server(port)` and return it, already started."""
    api_process = Process(args=(port,), target=_run_api_server)
    api_process.start()
    return api_process

74 

75 

76def log_unhandled_exception( 

77 exc_type: type[BaseException], 

78 exc_value: BaseException, 

79 exc_traceback: TracebackType | None, 

80) -> None: 

81 """Make sure that any unhandled exceptions will write to the logs""" 

82 if issubclass(exc_type, KeyboardInterrupt): 

83 # call the default excepthook saved at __excepthook__ 

84 sys.__excepthook__(exc_type, exc_value, exc_traceback) 

85 return 

86 logger.critical("Unhandled exception", exc_info=(exc_type, exc_value, exc_traceback)) 

87 

88 

def common_init() -> None:
    """Initialisation shared by the main process and every multiprocessing child.

    Installs the logging excepthook, initialises Sentry when ``config.SENTRY_ENABLED`` is set,
    and verifies the database answers a trivial query.

    Raises:
        Exception: if ``SELECT 42`` does not return exactly one row ``(42,)``.
    """
    sys.excepthook = log_unhandled_exception

    if config.SENTRY_ENABLED:
        # Sentry collects exception tracebacks. The global excepthook integration is turned off
        # because it also picks up gRPC errors that were already handled (e.g. grpc.StatusCode.NOT_FOUND).
        sentry_sdk.init(
            config.SENTRY_URL,
            traces_sample_rate=0.0,
            environment=config.COOKIE_DOMAIN,
            release=config.VERSION,
            disabled_integrations=[excepthook.ExcepthookIntegration()],
        )

    logger.info("Checking DB connection")
    with session_scope() as session:
        rows = list(session.execute(text("SELECT 42;")))
        # Raise inside the scope so session_scope sees the failure, as before
        if rows != [(42,)]:
            raise Exception("Failed to connect to DB")

110 

111 

def _terminate_children(children: list[Process]) -> None:
    """Best-effort stop of child processes that were already started when parent startup failed.

    multiprocessing automatically joins every non-daemonic child at interpreter exit, and the API
    workers only stop when they receive a signal. Without this, a startup error in the parent would
    hang the process instead of letting it exit (and be restarted).
    """
    started = [child for child in children if child.pid is not None]
    for child in started:
        if child.is_alive():
            child.terminate()
    for child in started:
        child.join(GRACEFUL_SHUTDOWN_TIMEOUT)
        if child.is_alive():
            logger.warning(f"Child {child.name} still alive after SIGTERM, killing it")
            child.kill()
            child.join()


def main() -> None:
    """Main-process entry point.

    Runs migrations and one-off setup, then starts the child processes selected by ``config.ROLE``
    (scheduler, background workers, API workers), the Prometheus server, tracing, and (for the
    api/all roles) the in-process media server. It then hands control to ``supervise``. Exits with
    status 1 when ``supervise`` returns a non-None value (presumably the crashed child).

    If startup fails after some children were spawned, those children are stopped before the error
    is re-raised, so the process can actually exit.
    """
    logger.info("Running DB migrations")

    apply_migrations()

    get_main_i18next()  # Force eager loading of translations

    if config.ADD_DUMMY_DATA:
        add_dummy_data()

    logger.info("Starting")

    children: list[Process] = []
    media_server = None

    try:
        if config.ROLE in ["scheduler", "all"]:
            scheduler = start_jobs_scheduler()
            scheduler.name = "scheduler"
            children.append(scheduler)

        if config.ROLE in ["worker", "all"]:
            for i in range(config.BACKGROUND_WORKER_COUNT):
                worker = start_jobs_worker(i)
                worker.name = f"worker-{i}"
                children.append(worker)

        # The multiprocessing start method is forkserver/spawn (Python 3.14 default; never
        # fork), so each worker runs its own per-process init — don't pin set_start_method("fork") to "simplify"
        # this, that reintroduces fork-after-threads hazards.
        if config.ROLE in ["api", "all"]:
            for port in range(API_BASE_PORT, API_BASE_PORT + API_WORKER_COUNT):
                api_worker = start_api_worker(port)
                api_worker.name = f"api-{port}"
                children.append(api_worker)

        create_prometheus_server(8000)

        # Must precede setup_tracing(), which reads the `trace_sample_ratio` flag.
        setup_experimentation()

        setup_tracing()

        if config.ROLE in ["api", "all"]:
            media_server = create_media_server(port=MEDIA_PORT)
            media_server.start()
            logger.info(f"Media server serving on {MEDIA_PORT}")
    except BaseException:
        # Don't leave already-spawned children running; they would block interpreter exit.
        _terminate_children(children)
        raise

    logger.info("App started, supervising child processes")
    crashed = supervise(children, parent_servers=[media_server] if media_server is not None else [])

    if crashed is not None:
        sys.exit(1)

164 

165 

# Both the main process and processes created via multiprocessing (which run this module as
# "__mp_main__") need the per-process init; only the main process goes on to run main().
if __name__ in ("__main__", "__mp_main__"):
    common_init()
    if __name__ == "__main__":
        main()