Coverage for yaptide/celery/tasks.py: 90% (152 statements)

coverage.py v7.6.10, created at 2025-01-04 00:31 +0000

import contextlib
from dataclasses import dataclass
import logging
import tempfile
from datetime import datetime
from pathlib import Path
import threading
from typing import Optional

from yaptide.batch.batch_methods import post_update
from yaptide.celery.utils.pymc import (average_estimators, command_to_run_fluka, command_to_run_shieldhit,
                                       execute_simulation_subprocess, get_fluka_estimators, get_shieldhit_estimators,
                                       get_tmp_dir, read_file, read_file_offline, read_fluka_file)
from yaptide.celery.utils.requests import (send_simulation_logfiles, send_simulation_results, send_task_update)
from yaptide.celery.simulation_worker import celery_app
from yaptide.utils.enums import EntityState
from yaptide.utils.sim_utils import (check_and_convert_payload_to_files_dict, estimators_to_list, simulation_logfiles,
                                     write_simulation_input_files)

@celery_app.task
def convert_input_files(payload_dict: dict) -> dict:
    """Convert the JSON payload into a dictionary of simulation input files"""
    files_dict = check_and_convert_payload_to_files_dict(payload_dict=payload_dict)
    return {"input_files": files_dict}

@celery_app.task(bind=True)
def run_single_simulation(self,
                          files_dict: dict,
                          task_id: int,
                          update_key: str = '',
                          simulation_id: Optional[int] = None,
                          keep_tmp_files: bool = False,
                          sim_type: str = 'shieldhit') -> dict:
    """Function running a single simulation"""
    # for the purpose of running this function in pytest we would like to have some control
    # over the temporary directory used by the function

    logging.info("Running simulation, simulation_id: %s, task_id: %d", simulation_id, task_id)

    logging.info("Sending initial update for task %d, setting celery id %s", task_id, self.request.id)
    send_task_update(simulation_id, task_id, update_key, {"celery_id": self.request.id})

    # we would like to have some control over the temporary directory used by the function
    tmp_dir = get_tmp_dir()
    logging.info("Temporary directory is: %s", tmp_dir)

    # with tempfile.TemporaryDirectory(dir=tmp_dir) as tmp_dir_path:
    # use the selected temporary directory to create a working directory for the simulation;
    # if keep_tmp_files is set, the directory is not removed after the task finishes
    with (contextlib.nullcontext(tempfile.mkdtemp(dir=tmp_dir)) if keep_tmp_files else tempfile.TemporaryDirectory(
            dir=tmp_dir)) as tmp_work_dir:

        write_simulation_input_files(files_dict=files_dict, output_dir=Path(tmp_work_dir))
        logging.debug("Generated input files: %s", files_dict.keys())

        if sim_type == 'shieldhit':
            simulation_result = run_single_simulation_for_shieldhit(tmp_work_dir, task_id, update_key, simulation_id)
        elif sim_type == 'fluka':
            simulation_result = run_single_simulation_for_fluka(tmp_work_dir, task_id, update_key, simulation_id)

        # there is no simulation output
        if not simulation_result.estimators_dict:
            # first we notify the backend that the task with the simulation has failed
            logging.info("Simulation failed for task %d, sending update that it has failed", task_id)
            update_dict = {"task_state": EntityState.FAILED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}
            send_task_update(simulation_id, task_id, update_key, update_dict)

            # then we would send the logfiles to the backend, if available
            logfiles = simulation_logfiles(path=Path(tmp_work_dir))
            logging.info("Simulation failed, logfiles: %s", logfiles.keys())
            # sending the logfiles from here is broken, as there may be several logfiles for some of the tasks;
            # imagine the following sequence of events:
            # task 1 fails with a useful message in its logfile
            # (e.g. after 100 primaries the SHIELD-HIT12A binary crashed),
            # and the useful logfiles are sent to the backend;
            # task 2 fails later, but this time the SHIELD-HIT12A binary crashes
            # at the very beginning of the simulation, without producing any logfiles;
            # the logfiles are sent to the backend again, but this time they are empty,
            # so the useful logfiles are overwritten by the empty ones.
            # we therefore temporarily disable sending logfiles to the backend
            # if logfiles:
            #     sending_logfiles_status = send_simulation_logfiles(simulation_id=simulation_id,
            #                                                        update_key=update_key,
            #                                                        logfiles=logfiles)
            #     if not sending_logfiles_status:
            #         logging.error("Sending logfiles failed for task %s", task_id)

            # finally we return from the celery task, with the logfiles and stdout/stderr as result
            return {
                "logfiles": logfiles,
                "stdout": simulation_result.command_stdout,
                "stderr": simulation_result.command_stderr,
                "simulation_id": simulation_id,
                "update_key": update_key
            }

        # otherwise we have simulation output
        logging.debug("Converting simulation results to JSON")
        estimators = estimators_to_list(estimators_dict=simulation_result.estimators_dict, dir_path=Path(tmp_work_dir))

        # we do not know whether the monitoring process sent the last update,
        # so we send it here to make sure that we have the end_time and the COMPLETED state
        end_time = datetime.utcnow().isoformat(sep=" ")
        update_dict = {
            "task_state": EntityState.COMPLETED.value,
            "end_time": end_time,
            "simulated_primaries": simulation_result.requested_primaries,
            "requested_primaries": simulation_result.requested_primaries
        }
        send_task_update(simulation_id, task_id, update_key, update_dict)

        # finally return from the celery task, with the estimators and stdout/stderr as result;
        # the estimators will be merged by a subsequent celery task
        return {"estimators": estimators, "simulation_id": simulation_id, "update_key": update_key}

@dataclass
class SimulationTaskResult:
    """Class representing the result of a single simulation task"""

    process_exit_success: bool
    command_stdout: str
    command_stderr: str
    simulated_primaries: int
    requested_primaries: int
    estimators_dict: dict

def run_single_simulation_for_shieldhit(tmp_work_dir: str,
                                        task_id: int,
                                        update_key: str = '',
                                        simulation_id: Optional[int] = None) -> SimulationTaskResult:
    """Function running a single simulation for SHIELD-HIT12A"""
    command_as_list = command_to_run_shieldhit(dir_path=Path(tmp_work_dir), task_id=task_id)
    logging.info("Command to run SHIELD-HIT12A: %s", " ".join(command_as_list))

    command_stdout, command_stderr = '', ''
    simulated_primaries, requested_primaries = 0, 0
    event = threading.Event()

    # start the monitoring process if possible;
    # task_monitor is None if the monitor was not started
    task_monitor = monitor_shieldhit(event, tmp_work_dir, task_id, update_key, simulation_id)
    # run the simulation
    logging.info("Running SHIELD-HIT12A process in %s", tmp_work_dir)
    process_exit_success, command_stdout, command_stderr = execute_simulation_subprocess(
        dir_path=Path(tmp_work_dir), command_as_list=command_as_list)
    logging.info("SHIELD-HIT12A process finished with status %s", process_exit_success)

    # terminate the monitoring process
    if task_monitor:
        logging.debug("Terminating monitoring process for task %d", task_id)
        event.set()
        task_monitor.task.join()
        logging.debug("Monitoring process for task %d terminated", task_id)
    # if the watcher didn't finish yet, we need to read the log file and send the last update to the backend
    if task_monitor:
        simulated_primaries, requested_primaries = read_file_offline(task_monitor.path_to_monitor)

    # both the simulation and the monitoring process are finished now, so we can read the estimators
    estimators_dict = get_shieldhit_estimators(dir_path=Path(tmp_work_dir))

    return SimulationTaskResult(process_exit_success=process_exit_success,
                                command_stdout=command_stdout,
                                command_stderr=command_stderr,
                                simulated_primaries=simulated_primaries,
                                requested_primaries=requested_primaries,
                                estimators_dict=estimators_dict)

def run_single_simulation_for_fluka(tmp_work_dir: str,
                                    task_id: int,
                                    update_key: str = '',
                                    simulation_id: Optional[int] = None) -> SimulationTaskResult:
    """Function running a single simulation for FLUKA"""
    command_as_list = command_to_run_fluka(dir_path=Path(tmp_work_dir), task_id=task_id)
    logging.info("Command to run FLUKA: %s", " ".join(command_as_list))

    command_stdout, command_stderr = '', ''
    simulated_primaries, requested_primaries = 0, 0
    event = threading.Event()
    # start the monitoring process if possible;
    # task_monitor is None if the monitor was not started
    task_monitor = monitor_fluka(event, tmp_work_dir, task_id, update_key, simulation_id)

    # run the simulation
    logging.info("Running Fluka process in %s", tmp_work_dir)
    process_exit_success, command_stdout, command_stderr = execute_simulation_subprocess(
        dir_path=Path(tmp_work_dir), command_as_list=command_as_list)
    logging.info("Fluka process finished with status %s", process_exit_success)

    # terminate the monitoring process
    if task_monitor:
        logging.debug("Terminating monitoring process for task %s", task_id)
        event.set()
        task_monitor.task.join()
        logging.debug("Monitoring process for task %s terminated", task_id)
    # TO BE IMPLEMENTED:
    # if the watcher didn't finish yet, we need to read the log file and send the last update to the backend;
    # the FLUKA log file should be read after the simulation has finished,
    # as FLUKA copies the file back to the main directory from its temporary directory

    # both the simulation and the monitoring process are finished now, so we can read the estimators
    estimators_dict = get_fluka_estimators(dir_path=Path(tmp_work_dir))

    return SimulationTaskResult(process_exit_success=process_exit_success,
                                command_stdout=command_stdout,
                                command_stderr=command_stderr,
                                simulated_primaries=simulated_primaries,
                                requested_primaries=requested_primaries,
                                estimators_dict=estimators_dict)

@celery_app.task
def set_merging_queued_state(results: list[dict]) -> list[dict]:
    """Celery task to set simulation state as MERGING_QUEUED"""
    logging.debug("send_state")
    simulation_id = results[0].get("simulation_id", None)
    update_key = results[0].get("update_key", None)
    if simulation_id and update_key:
        dict_to_send = {
            "sim_id": simulation_id,
            "job_state": EntityState.MERGING_QUEUED.value,
            "update_key": update_key
        }
        post_update(dict_to_send)
    return results

@celery_app.task
def merge_results(results: list[dict]) -> dict:
    """Merge results from multiple simulation tasks"""
    logging.debug("Merging results from %d tasks", len(results))
    logfiles = {}

    averaged_estimators = None
    simulation_id = results[0].pop("simulation_id", None)
    update_key = results[0].pop("update_key", None)
    if simulation_id and update_key:
        dict_to_send = {
            "sim_id": simulation_id,
            "job_state": EntityState.MERGING_RUNNING.value,
            "update_key": update_key
        }
        post_update(dict_to_send)
    for i, result in enumerate(results):
        if simulation_id is None:
            simulation_id = result.pop("simulation_id", None)
        if update_key is None:
            update_key = result.pop("update_key", None)
        if "logfiles" in result:
            logfiles.update(result["logfiles"])
            continue

        if averaged_estimators is None:
            averaged_estimators: list[dict] = result.get("estimators", [])
            # there is nothing to average yet
            continue

        averaged_estimators = average_estimators(averaged_estimators, result.get("estimators", []), i)

    final_result = {"end_time": datetime.utcnow().isoformat(sep=" ")}

    if len(logfiles.keys()) > 0 and not send_simulation_logfiles(
            simulation_id=simulation_id, update_key=update_key, logfiles=logfiles):
        final_result["logfiles"] = logfiles

    if averaged_estimators:
        # send the results to the backend and mark the whole simulation as completed
        sending_results_ok = send_simulation_results(simulation_id=simulation_id,
                                                     update_key=update_key,
                                                     estimators=averaged_estimators)
        if not sending_results_ok:
            final_result["estimators"] = averaged_estimators

    return final_result

@dataclass
class MonitorTask:
    """Class representing a monitoring task"""

    path_to_monitor: Path
    task: threading.Thread

def monitor_shieldhit(event: threading.Event, tmp_work_dir: str, task_id: int, update_key: str,
                      simulation_id: int) -> Optional[MonitorTask]:
    """Function monitoring the progress of a SHIELD-HIT12A simulation"""
    # we would like to monitor the progress of the simulation;
    # this is done by reading the log file and sending updates to the backend.
    # if we have update_key and simulation_id, the monitoring task can submit updates to the backend
    path_to_monitor = Path(tmp_work_dir) / f"shieldhit_{task_id:04d}.log"
    if update_key and simulation_id is not None:
        current_logging_level = logging.getLogger().getEffectiveLevel()
        task = threading.Thread(target=read_file,
                                kwargs=dict(event=event,
                                            filepath=path_to_monitor,
                                            simulation_id=simulation_id,
                                            task_id=task_id,
                                            update_key=update_key,
                                            logging_level=current_logging_level))
        task.start()
        logging.info("Started monitoring process for task %d", task_id)
        return MonitorTask(path_to_monitor=path_to_monitor, task=task)

    logging.info("No monitoring processes started for task %d", task_id)
    return None

def monitor_fluka(event: threading.Event, tmp_work_dir: str, task_id: int, update_key: str,
                  simulation_id: int) -> Optional[MonitorTask]:
    """Function running the monitoring process for a Fluka simulation"""
    # we would like to monitor the progress of the simulation;
    # this is done by reading the log file and sending updates to the backend.
    # if we have update_key and simulation_id, the monitoring task can submit updates to the backend.
    # we monitor a directory instead of a single path, because the Fluka simulator creates a
    # directory with the PID of its process in its name
    dir_to_monitor = Path(tmp_work_dir)
    if update_key and simulation_id is not None:
        current_logging_level = logging.getLogger().getEffectiveLevel()
        task = threading.Thread(target=read_fluka_file,
                                kwargs=dict(event=event,
                                            dirpath=dir_to_monitor,
                                            simulation_id=simulation_id,
                                            task_id=task_id,
                                            update_key=update_key,
                                            logging_level=current_logging_level))

        task.start()
        logging.info("Started monitoring process for task %d", task_id)
        return MonitorTask(path_to_monitor=dir_to_monitor, task=task)

    logging.info("No monitoring processes started for task %d", task_id)
    return None
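
Usage note (not part of the covered module): the tasks above are building blocks for a larger Celery workflow. Below is a minimal sketch of how they might be composed with a group and a chord; the payload handling, the number of tasks, the helper name submit_example_job and the exact chaining are assumptions for illustration and do not reproduce the actual yaptide job submission code.

# Hypothetical composition of the tasks above into a workflow (illustrative assumptions only).
from celery import chord, group

from yaptide.celery.tasks import (convert_input_files, merge_results, run_single_simulation,
                                  set_merging_queued_state)


def submit_example_job(payload_dict: dict, simulation_id: int, update_key: str, ntasks: int = 2):
    """Sketch: convert the payload once, fan out per-task simulations, then merge the estimators."""
    # calling the task object directly runs convert_input_files in the current process
    files_dict = convert_input_files(payload_dict)["input_files"]

    # one run_single_simulation signature per task; the chord gathers their return values
    header = group(
        run_single_simulation.s(files_dict=files_dict,
                                task_id=task_id,
                                update_key=update_key,
                                simulation_id=simulation_id) for task_id in range(ntasks))

    # mark the job as queued for merging, then average the per-task estimators
    callback = set_merging_queued_state.s() | merge_results.s()
    return chord(header)(callback)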