Coverage for yaptide/celery/tasks.py: 90%

139 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-22 07:31 +0000

1import contextlib 

2from dataclasses import dataclass 

3import logging 

4import tempfile 

5from datetime import datetime 

6from pathlib import Path 

7import threading 

8from typing import Optional 

9 

10from yaptide.celery.utils.pymc import (average_estimators, command_to_run_fluka, command_to_run_shieldhit, 

11 execute_simulation_subprocess, get_fluka_estimators, get_shieldhit_estimators, 

12 get_tmp_dir, read_file, read_file_offline, read_fluka_file) 

13from yaptide.celery.utils.requests import (send_simulation_logfiles, send_simulation_results, send_task_update) 

14from yaptide.celery.simulation_worker import celery_app 

15from yaptide.utils.enums import EntityState 

16from yaptide.utils.sim_utils import (check_and_convert_payload_to_files_dict, estimators_to_list, simulation_logfiles, 

17 write_simulation_input_files) 

18 

19 

@celery_app.task
def convert_input_files(payload_dict: dict) -> dict:
    """Convert a simulation request payload into simulator input files.

    Args:
        payload_dict: the request payload handed to ``check_and_convert_payload_to_files_dict``.

    Returns:
        ``{"input_files": files}`` where ``files`` is the dict produced by
        ``check_and_convert_payload_to_files_dict`` for ``payload_dict``.
    """
    converted_files = check_and_convert_payload_to_files_dict(payload_dict=payload_dict)
    result = {"input_files": converted_files}
    return result

25 

26 

@celery_app.task(bind=True)
def run_single_simulation(self,
                          files_dict: dict,
                          task_id: int,
                          update_key: str = '',
                          simulation_id: Optional[int] = None,
                          keep_tmp_files: bool = False,
                          sim_type: str = 'shieldhit') -> dict:
    """Run a single simulation task inside a fresh temporary working directory.

    The input files are written to the working directory, the simulator selected by
    ``sim_type`` is executed there and the backend is notified about the task state.

    Args:
        files_dict: input files to write into the working directory (as consumed by
            ``write_simulation_input_files``).
        task_id: id of this task, used in log messages and backend updates.
        update_key: key sent along with every update to the backend.
        simulation_id: id of the simulation this task belongs to.
        keep_tmp_files: when True the working directory is not removed after the run.
        sim_type: simulator to run, either ``'shieldhit'`` or ``'fluka'``.

    Returns:
        When the simulation produced no estimators: a dict with ``logfiles``, ``stdout``,
        ``stderr``, ``simulation_id`` and ``update_key``.
        Otherwise: a dict with ``estimators``, ``simulation_id`` and ``update_key``; the
        estimators of all tasks are merged later by ``merge_results``.

    Raises:
        ValueError: if ``sim_type`` is not a supported simulator.
    """
    # Resolve the simulator up front: an unknown sim_type used to surface only later as an
    # UnboundLocalError, after a working directory had already been created and filled.
    simulation_runners = {
        'shieldhit': run_single_simulation_for_shieldhit,
        'fluka': run_single_simulation_for_fluka,
    }
    if sim_type not in simulation_runners:
        raise ValueError(f"Unsupported simulation type: {sim_type!r}, "
                         f"expected one of {sorted(simulation_runners)}")
    run_simulation = simulation_runners[sim_type]

    logging.info("Running simulation, simulation_id: %s, task_id: %d", simulation_id, task_id)

    logging.info("Sending initial update for task %d, setting celery id %s", task_id, self.request.id)
    send_task_update(simulation_id, task_id, update_key, {"celery_id": self.request.id})

    # The base temporary directory comes from get_tmp_dir so that callers (e.g. pytest)
    # have some control over where the working directory is created.
    tmp_dir = get_tmp_dir()
    logging.info("Temporary directory is: %s", tmp_dir)

    # A fresh working directory is created inside tmp_dir. With keep_tmp_files it is created
    # via mkdtemp and left in place; otherwise TemporaryDirectory removes it on exit.
    work_dir_context = (contextlib.nullcontext(tempfile.mkdtemp(dir=tmp_dir))
                        if keep_tmp_files else tempfile.TemporaryDirectory(dir=tmp_dir))
    with work_dir_context as tmp_work_dir:
        write_simulation_input_files(files_dict=files_dict, output_dir=Path(tmp_work_dir))
        logging.debug("Generated input files: %s", files_dict.keys())

        simulation_result = run_simulation(tmp_work_dir, task_id, update_key, simulation_id)

        # no estimators means the simulation produced no output
        if not simulation_result.estimators_dict:
            # first notify the backend that this task has failed
            logging.info("Simulation failed for task %d, sending update that it has failed", task_id)
            update_dict = {"task_state": EntityState.FAILED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}
            send_task_update(simulation_id, task_id, update_key, update_dict)

            logfiles = simulation_logfiles(path=Path(tmp_work_dir))
            logging.info("Simulation failed, logfiles: %s", logfiles.keys())
            # Logfiles are deliberately NOT sent to the backend from here (this used to be done
            # with send_simulation_logfiles). When several tasks fail, a task that crashed before
            # producing any logfile would overwrite the useful logfiles of an earlier failed task
            # with empty ones. Instead they are returned as the task result, and merge_results
            # collects the logfiles of all tasks.
            return {
                "logfiles": logfiles,
                "stdout": simulation_result.command_stdout,
                "stderr": simulation_result.command_stderr,
                "simulation_id": simulation_id,
                "update_key": update_key
            }

        logging.debug("Converting simulation results to JSON")
        estimators = estimators_to_list(estimators_dict=simulation_result.estimators_dict, dir_path=Path(tmp_work_dir))

        # The monitoring thread may not have sent the last update, so send the final state
        # here to guarantee that end_time and the COMPLETED state are recorded.
        # A completed task is reported as having simulated all of its requested primaries.
        # NOTE(review): for FLUKA both primaries counters are currently always 0 (offline log
        # reading is not implemented), so this update reports 0 primaries — confirm intended.
        end_time = datetime.utcnow().isoformat(sep=" ")
        update_dict = {
            "task_state": EntityState.COMPLETED.value,
            "end_time": end_time,
            "simulated_primaries": simulation_result.requested_primaries,
            "requested_primaries": simulation_result.requested_primaries
        }
        send_task_update(simulation_id, task_id, update_key, update_dict)

        # the estimators of all tasks are merged by the subsequent merge_results task
        return {"estimators": estimators, "simulation_id": simulation_id, "update_key": update_key}

117 

118 

@dataclass
class SimulationTaskResult:
    """Outcome of running one simulator process for a single task."""

    # whether the simulator subprocess reported success
    process_exit_success: bool
    # captured standard output and standard error of the simulator subprocess
    command_stdout: str
    command_stderr: str
    # primaries counters read from the simulator log; left at 0 when they were not read
    simulated_primaries: int
    requested_primaries: int
    # estimators produced by the run; empty when the simulation produced no output
    estimators_dict: dict

129 

130 

def run_single_simulation_for_shieldhit(tmp_work_dir: str,
                                        task_id: int,
                                        update_key: str = '',
                                        simulation_id: Optional[int] = None) -> SimulationTaskResult:
    """Run SHIELD-HIT12A in ``tmp_work_dir`` and collect its results.

    While the simulator runs, a monitoring thread (started by ``monitor_shieldhit`` only when
    both ``update_key`` and ``simulation_id`` are given) reports progress to the backend.

    Args:
        tmp_work_dir: working directory containing the simulation input files.
        task_id: id of this task, used for the command, log file name and updates.
        update_key: key sent along with updates to the backend.
        simulation_id: id of the simulation; ``None`` disables monitoring.

    Returns:
        SimulationTaskResult with the process status, its stdout/stderr, the primaries
        counters read from the log file (0 when no monitor was started) and the estimators
        found in ``tmp_work_dir``.
    """
    command_as_list = command_to_run_shieldhit(dir_path=Path(tmp_work_dir), task_id=task_id)
    logging.info("Command to run SHIELD-HIT12A: %s", " ".join(command_as_list))

    simulated_primaries, requested_primaries = 0, 0
    event = threading.Event()

    # None when no monitor was started (missing update_key or simulation_id)
    task_monitor = monitor_shieldhit(event, tmp_work_dir, task_id, update_key, simulation_id)
    try:
        logging.info("Running SHIELD-HIT12A process in %s", tmp_work_dir)
        process_exit_success, command_stdout, command_stderr = execute_simulation_subprocess(
            dir_path=Path(tmp_work_dir), command_as_list=command_as_list)
        logging.info("SHIELD-HIT12A process finished with status %s", process_exit_success)
    finally:
        # Stop the monitoring thread even when running the simulation raised; otherwise the
        # thread would never receive the stop signal (event) and join would never happen.
        if task_monitor:
            logging.debug("Terminating monitoring process for task %d", task_id)
            event.set()
            task_monitor.task.join()
            logging.debug("Monitoring process for task %d terminated", task_id)

    if task_monitor:
        # the monitor may not have picked up the final progress, so read the log file once more
        simulated_primaries, requested_primaries = read_file_offline(task_monitor.path_to_monitor)

    # both the simulation and the monitoring thread are finished, the estimators can be read
    estimators_dict = get_shieldhit_estimators(dir_path=Path(tmp_work_dir))

    return SimulationTaskResult(process_exit_success=process_exit_success,
                                command_stdout=command_stdout,
                                command_stderr=command_stderr,
                                simulated_primaries=simulated_primaries,
                                requested_primaries=requested_primaries,
                                estimators_dict=estimators_dict)

171 

172 

def run_single_simulation_for_fluka(tmp_work_dir: str,
                                    task_id: int,
                                    update_key: str = '',
                                    simulation_id: Optional[int] = None) -> SimulationTaskResult:
    """Run FLUKA in ``tmp_work_dir`` and collect its results.

    While the simulator runs, a monitoring thread (started by ``monitor_fluka`` only when
    both ``update_key`` and ``simulation_id`` are given) reports progress to the backend.

    Args:
        tmp_work_dir: working directory containing the simulation input files.
        task_id: id of this task, used for the command and updates.
        update_key: key sent along with updates to the backend.
        simulation_id: id of the simulation; ``None`` disables monitoring.

    Returns:
        SimulationTaskResult with the process status, its stdout/stderr, primaries counters
        (always 0 for now, see TODO below) and the estimators found in ``tmp_work_dir``.
    """
    command_as_list = command_to_run_fluka(dir_path=Path(tmp_work_dir), task_id=task_id)
    logging.info("Command to run FLUKA: %s", " ".join(command_as_list))

    simulated_primaries, requested_primaries = 0, 0
    event = threading.Event()

    # None when no monitor was started (missing update_key or simulation_id)
    task_monitor = monitor_fluka(event, tmp_work_dir, task_id, update_key, simulation_id)
    try:
        logging.info("Running Fluka process in %s", tmp_work_dir)
        process_exit_success, command_stdout, command_stderr = execute_simulation_subprocess(
            dir_path=Path(tmp_work_dir), command_as_list=command_as_list)
        logging.info("Fluka process finished with status %s", process_exit_success)
    finally:
        # Stop the monitoring thread even when running the simulation raised; otherwise the
        # thread would never receive the stop signal (event) and join would never happen.
        if task_monitor:
            logging.debug("Terminating monitoring process for task %s", task_id)
            event.set()
            task_monitor.task.join()
            logging.debug("Monitoring process for task %s terminated", task_id)

    # TODO: if the watcher did not catch the final progress, read the FLUKA log file after the
    # run and send the last update to the backend (FLUKA copies the file back from its
    # temporary directory to the main directory). Until then the primaries counters stay 0.

    # both the simulation and the monitoring thread are finished, the estimators can be read
    estimators_dict = get_fluka_estimators(dir_path=Path(tmp_work_dir))

    return SimulationTaskResult(process_exit_success=process_exit_success,
                                command_stdout=command_stdout,
                                command_stderr=command_stderr,
                                simulated_primaries=simulated_primaries,
                                requested_primaries=requested_primaries,
                                estimators_dict=estimators_dict)

214 

215 

@celery_app.task
def merge_results(results: list[dict]) -> dict:
    """Merge the results of all tasks of one simulation and send them to the backend.

    Results of failed tasks carry ``logfiles``; these are merged into a single dict.
    Results of successful tasks carry ``estimators``, which are averaged.

    Args:
        results: the dicts returned by ``run_single_simulation`` for every task.
            ``simulation_id`` and ``update_key`` are taken from the first result providing them.

    Returns:
        dict with ``end_time``; additionally ``logfiles`` and/or ``estimators`` when sending
        them to the backend failed, so that they are not lost.
    """
    logging.debug("Merging results from %d tasks", len(results))
    simulation_id = None
    update_key = None
    logfiles = {}
    averaged_estimators = None
    # Number of successful results already folded into averaged_estimators. This, not the
    # position in `results`, is the count for the running average: failed results (with
    # logfiles) contribute no estimators and must not skew the weights.
    averaged_count = 0

    for result in results:
        if simulation_id is None:
            simulation_id = result.get("simulation_id")
        if update_key is None:
            update_key = result.get("update_key")

        if "logfiles" in result:
            logfiles.update(result["logfiles"])
            continue

        estimators: list[dict] = result.get("estimators", [])
        if averaged_estimators is None:
            # first successful result: nothing to average yet
            averaged_estimators = estimators
        else:
            averaged_estimators = average_estimators(averaged_estimators, estimators, averaged_count)
        averaged_count += 1

    final_result = {"end_time": datetime.utcnow().isoformat(sep=" ")}

    if logfiles and not send_simulation_logfiles(
            simulation_id=simulation_id, update_key=update_key, logfiles=logfiles):
        final_result["logfiles"] = logfiles

    if averaged_estimators:
        # send results to the backend and mark whole simulation as completed
        sending_results_ok = send_simulation_results(simulation_id=simulation_id,
                                                     update_key=update_key,
                                                     estimators=averaged_estimators)
        if not sending_results_ok:
            final_result["estimators"] = averaged_estimators

    return final_result

256 

257 

@dataclass
class MonitorTask:
    """Handle to a running progress-monitoring thread."""

    # file (SHIELD-HIT12A) or directory (FLUKA) watched by the thread
    path_to_monitor: Path
    # the thread doing the monitoring; the caller joins it when the simulation ends
    task: threading.Thread

265 

def monitor_shieldhit(event: threading.Event, tmp_work_dir: str, task_id: int, update_key: str,
                      simulation_id: Optional[int]) -> Optional[MonitorTask]:
    """Start a thread reporting SHIELD-HIT12A progress to the backend.

    The thread runs ``read_file`` on ``shieldhit_<task_id:04d>.log`` inside ``tmp_work_dir``.
    ``event`` is handed to ``read_file``; the caller sets it to stop monitoring.
    Updates can only be submitted when both ``update_key`` and ``simulation_id`` are known,
    so otherwise no thread is started.

    Args:
        event: signal used by the caller to stop the monitoring thread.
        tmp_work_dir: working directory of the simulation.
        task_id: id of the task, also used in the log file name.
        update_key: key sent along with updates; empty disables monitoring.
        simulation_id: id of the simulation (an int, as elsewhere in this module);
            ``None`` disables monitoring.

    Returns:
        MonitorTask with the monitored log path and the started thread, or None when
        monitoring was not started.
    """
    path_to_monitor = Path(tmp_work_dir) / f"shieldhit_{task_id:04d}.log"
    if not update_key or simulation_id is None:
        logging.info("No monitoring processes started for task %d", task_id)
        return None

    # pass the current effective logging level, presumably so the thread logs consistently
    current_logging_level = logging.getLogger().getEffectiveLevel()
    task = threading.Thread(target=read_file,
                            kwargs=dict(event=event,
                                        filepath=path_to_monitor,
                                        simulation_id=simulation_id,
                                        task_id=task_id,
                                        update_key=update_key,
                                        logging_level=current_logging_level))
    task.start()
    logging.info("Started monitoring process for task %d", task_id)
    return MonitorTask(path_to_monitor=path_to_monitor, task=task)

288 

289 

def monitor_fluka(event: threading.Event, tmp_work_dir: str, task_id: int, update_key: str,
                  simulation_id: int) -> Optional[MonitorTask]:
    """Start a thread reporting FLUKA progress to the backend.

    FLUKA creates a directory named after the PID of its process, so the exact log path is
    not known in advance; the whole working directory is handed to ``read_fluka_file``
    instead. ``event`` is passed to that function; the caller sets it to stop monitoring.
    Updates can only be submitted when both ``update_key`` and ``simulation_id`` are known,
    so otherwise no thread is started.

    Returns:
        MonitorTask with the monitored directory and the started thread, or None when
        monitoring was not started.
    """
    dir_to_monitor = Path(tmp_work_dir)
    can_report_progress = bool(update_key) and simulation_id is not None
    if not can_report_progress:
        logging.info("No monitoring processes started for task %d", task_id)
        return None

    monitor_kwargs = {
        "event": event,
        "dirpath": dir_to_monitor,
        "simulation_id": simulation_id,
        "task_id": task_id,
        "update_key": update_key,
        "logging_level": logging.getLogger().getEffectiveLevel(),
    }
    monitor_thread = threading.Thread(target=read_fluka_file, kwargs=monitor_kwargs)
    monitor_thread.start()
    logging.info("Started monitoring process for task %d", task_id)
    return MonitorTask(path_to_monitor=dir_to_monitor, task=monitor_thread)