Coverage for yaptide/celery/tasks.py: 90%
152 statements
coverage.py v7.6.10, created at 2025-01-04 00:31 +0000
import contextlib
import logging
import tempfile
import threading
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Optional

from yaptide.batch.batch_methods import post_update
from yaptide.celery.simulation_worker import celery_app
from yaptide.celery.utils.pymc import (average_estimators, command_to_run_fluka, command_to_run_shieldhit,
                                       execute_simulation_subprocess, get_fluka_estimators, get_shieldhit_estimators,
                                       get_tmp_dir, read_file, read_file_offline, read_fluka_file)
from yaptide.celery.utils.requests import (send_simulation_logfiles, send_simulation_results, send_task_update)
from yaptide.utils.enums import EntityState
from yaptide.utils.sim_utils import (check_and_convert_payload_to_files_dict, estimators_to_list, simulation_logfiles,
                                     write_simulation_input_files)


@celery_app.task
def convert_input_files(payload_dict: dict) -> dict:
    """Function converting a JSON payload into a dictionary of simulation input files"""
    files_dict = check_and_convert_payload_to_files_dict(payload_dict=payload_dict)
    return {"input_files": files_dict}
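
# illustrative usage (payload_dict here is a hypothetical placeholder): when scheduled
# through Celery the task runs asynchronously, e.g.
#   async_result = convert_input_files.delay(payload_dict)
#   files = async_result.get()["input_files"]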


@celery_app.task(bind=True)
def run_single_simulation(self,
                          files_dict: dict,
                          task_id: int,
                          update_key: str = '',
                          simulation_id: Optional[int] = None,
                          keep_tmp_files: bool = False,
                          sim_type: str = 'shieldhit') -> dict:
36 """Function running single simulation"""
37 # for the purpose of running this function in pytest we would like to have some control
38 # on the temporary directory used by the function
40 logging.info("Running simulation, simulation_id: %s, task_id: %d", simulation_id, task_id)
42 logging.info("Sending initial update for task %d, setting celery id %s", task_id, self.request.id)
43 send_task_update(simulation_id, task_id, update_key, {"celery_id": self.request.id})
45 # we would like to have some control on the temporary directory used by the function
46 tmp_dir = get_tmp_dir()
47 logging.info("Temporary directory is: %s", tmp_dir)
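    # contextlib.nullcontext wraps the plain tempfile.mkdtemp path so that, unlike
    # tempfile.TemporaryDirectory, nothing is removed on exit; this is what keeps
    # the working directory on disk when keep_tmp_files is requested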
    # use the selected temporary directory to create a working directory
    with (contextlib.nullcontext(tempfile.mkdtemp(dir=tmp_dir)) if keep_tmp_files else
          tempfile.TemporaryDirectory(dir=tmp_dir)) as tmp_work_dir:

        write_simulation_input_files(files_dict=files_dict, output_dir=Path(tmp_work_dir))
        logging.debug("Generated input files: %s", files_dict.keys())

        if sim_type == 'shieldhit':
            simulation_result = run_single_simulation_for_shieldhit(tmp_work_dir, task_id, update_key, simulation_id)
        elif sim_type == 'fluka':
            simulation_result = run_single_simulation_for_fluka(tmp_work_dir, task_id, update_key, simulation_id)
        else:
            # without this guard simulation_result would be unbound for an unknown sim_type
            raise ValueError(f"Unsupported sim_type: {sim_type}")

        # there is no simulation output
        if not simulation_result.estimators_dict:
            # first we notify the backend that the task with the simulation has failed
            logging.info("Simulation failed for task %d, sending update that it has failed", task_id)
            update_dict = {"task_state": EntityState.FAILED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}
            send_task_update(simulation_id, task_id, update_key, update_dict)

            # then we collect the logfiles, which would normally be sent to the backend
            logfiles = simulation_logfiles(path=Path(tmp_work_dir))
            logging.info("Simulation failed, logfiles: %s", logfiles.keys())
            # sending them from here is broken, as there may be several logfiles
            # for some of the tasks; imagine the following sequence of events:
            # task 1 fails with a useful message in its logfile
            # (e.g. the SHIELD-HIT12A binary crashed after 100 primaries)
            # and the useful logfiles are sent to the backend;
            # task 2 fails later, but its SHIELD-HIT12A binary crashes
            # at the very beginning of the simulation, without producing any logfiles;
            # its empty logfiles are sent to the backend as well,
            # overwriting the useful ones from task 1;
            # we therefore temporarily disable sending logfiles to the backend:
            # if logfiles:
            #     sending_logfiles_status = send_simulation_logfiles(simulation_id=simulation_id,
            #                                                        update_key=update_key,
            #                                                        logfiles=logfiles)
            #     if not sending_logfiles_status:
            #         logging.error("Sending logfiles failed for task %s", task_id)

            # finally we return from the celery task, passing the logfiles and stdout/stderr as the result
            return {
                "logfiles": logfiles,
                "stdout": simulation_result.command_stdout,
                "stderr": simulation_result.command_stderr,
                "simulation_id": simulation_id,
                "update_key": update_key
            }

        # otherwise we have simulation output
        logging.debug("Converting simulation results to JSON")
        estimators = estimators_to_list(estimators_dict=simulation_result.estimators_dict, dir_path=Path(tmp_work_dir))

        # we do not know whether the monitoring process has sent the last update,
        # so we send it here to make sure that we have the end_time and COMPLETED state
        end_time = datetime.utcnow().isoformat(sep=" ")
        update_dict = {
            "task_state": EntityState.COMPLETED.value,
            "end_time": end_time,
            "simulated_primaries": simulation_result.requested_primaries,
            "requested_primaries": simulation_result.requested_primaries
        }
        send_task_update(simulation_id, task_id, update_key, update_dict)

        # finally return from the celery task, passing the estimators and stdout/stderr as the result;
        # the estimators will be merged by a subsequent celery task
        return {"estimators": estimators, "simulation_id": simulation_id, "update_key": update_key}
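

# a hedged sketch (not part of this module) of how these tasks are typically composed:
# a Celery chord fans out one run_single_simulation per task and funnels the list of
# results through set_merging_queued_state into merge_results; n_tasks, files_dict,
# key and sim_id below are hypothetical placeholders
#
#   from celery import chord
#   async_result = chord(
#       run_single_simulation.s(files_dict, task_id=i, update_key=key, simulation_id=sim_id)
#       for i in range(n_tasks))(set_merging_queued_state.s() | merge_results.s())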


@dataclass
class SimulationTaskResult:
    """Class representing the result of a single simulation task"""

    process_exit_success: bool
    command_stdout: str
    command_stderr: str
    simulated_primaries: int
    requested_primaries: int
    estimators_dict: dict


def run_single_simulation_for_shieldhit(tmp_work_dir: str,
                                        task_id: int,
                                        update_key: str = '',
                                        simulation_id: Optional[int] = None) -> SimulationTaskResult:
    """Function running a single simulation for SHIELD-HIT12A"""
    command_as_list = command_to_run_shieldhit(dir_path=Path(tmp_work_dir), task_id=task_id)
    logging.info("Command to run SHIELD-HIT12A: %s", " ".join(command_as_list))

    command_stdout, command_stderr = '', ''
    simulated_primaries, requested_primaries = 0, 0
    event = threading.Event()
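    # the Event acts as a stop flag shared with the watcher thread started below:
    # the watcher keeps reading the log until the event is set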

    # start the monitoring process if possible;
    # task_monitor is None if the monitor was not started
    task_monitor = monitor_shieldhit(event, tmp_work_dir, task_id, update_key, simulation_id)

    # run the simulation
    logging.info("Running SHIELD-HIT12A process in %s", tmp_work_dir)
    process_exit_success, command_stdout, command_stderr = execute_simulation_subprocess(
        dir_path=Path(tmp_work_dir), command_as_list=command_as_list)
    logging.info("SHIELD-HIT12A process finished with status %s", process_exit_success)

    # terminate the monitoring process
    if task_monitor:
        logging.debug("Terminating monitoring process for task %d", task_id)
        event.set()
        task_monitor.task.join()
        logging.debug("Monitoring process for task %d terminated", task_id)
        # in case the watcher didn't catch the final state, read the log file
        # offline to get the last known primary counts
        simulated_primaries, requested_primaries = read_file_offline(task_monitor.path_to_monitor)

    # both simulation execution and monitoring process are finished now, we can read the estimators
    estimators_dict = get_shieldhit_estimators(dir_path=Path(tmp_work_dir))

    return SimulationTaskResult(process_exit_success=process_exit_success,
                                command_stdout=command_stdout,
                                command_stderr=command_stderr,
                                simulated_primaries=simulated_primaries,
                                requested_primaries=requested_primaries,
                                estimators_dict=estimators_dict)


def run_single_simulation_for_fluka(tmp_work_dir: str,
                                    task_id: int,
                                    update_key: str = '',
                                    simulation_id: Optional[int] = None) -> SimulationTaskResult:
    """Function running a single simulation for FLUKA"""
    command_as_list = command_to_run_fluka(dir_path=Path(tmp_work_dir), task_id=task_id)
    logging.info("Command to run FLUKA: %s", " ".join(command_as_list))

    command_stdout, command_stderr = '', ''
    simulated_primaries, requested_primaries = 0, 0
    event = threading.Event()
    # start the monitoring process if possible;
    # task_monitor is None if the monitor was not started
    task_monitor = monitor_fluka(event, tmp_work_dir, task_id, update_key, simulation_id)

    # run the simulation
    logging.info("Running FLUKA process in %s", tmp_work_dir)
    process_exit_success, command_stdout, command_stderr = execute_simulation_subprocess(
        dir_path=Path(tmp_work_dir), command_as_list=command_as_list)
    logging.info("FLUKA process finished with status %s", process_exit_success)

    # terminate the monitoring process
    if task_monitor:
        logging.debug("Terminating monitoring process for task %s", task_id)
        event.set()
        task_monitor.task.join()
        logging.debug("Monitoring process for task %s terminated", task_id)
    # TO BE IMPLEMENTED:
    # if the watcher didn't finish yet, read the log file and send the last update to the backend;
    # FLUKA copies the file back to the main directory from its temporary directory
    # only after the simulation has finished

    # both simulation execution and monitoring process are finished now, we can read the estimators
    estimators_dict = get_fluka_estimators(dir_path=Path(tmp_work_dir))

    return SimulationTaskResult(process_exit_success=process_exit_success,
                                command_stdout=command_stdout,
                                command_stderr=command_stderr,
                                simulated_primaries=simulated_primaries,
                                requested_primaries=requested_primaries,
                                estimators_dict=estimators_dict)


@celery_app.task
def set_merging_queued_state(results: list[dict]) -> list[dict]:
    """Celery task to set simulation state as MERGING_QUEUED"""
    logging.debug("Setting simulation state to MERGING_QUEUED")
    simulation_id = results[0].get("simulation_id", None)
    update_key = results[0].get("update_key", None)
    if simulation_id and update_key:
        dict_to_send = {
            "sim_id": simulation_id,
            "job_state": EntityState.MERGING_QUEUED.value,
            "update_key": update_key
        }
        post_update(dict_to_send)
    return results


@celery_app.task
def merge_results(results: list[dict]) -> dict:
    """Merge results from multiple simulation tasks"""
    logging.debug("Merging results from %d tasks", len(results))
    logfiles = {}

    averaged_estimators = None
    simulation_id = results[0].pop("simulation_id", None)
    update_key = results[0].pop("update_key", None)
    if simulation_id and update_key:
        dict_to_send = {
            "sim_id": simulation_id,
            "job_state": EntityState.MERGING_RUNNING.value,
            "update_key": update_key
        }
        post_update(dict_to_send)
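
    # each result carries either logfiles (failed task) or estimators (successful task);
    # logfiles are collected, estimators are folded into a running average
    # (the index i presumably tells average_estimators how many results are already merged)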
    for i, result in enumerate(results):
        if simulation_id is None:
            simulation_id = result.pop("simulation_id", None)
        if update_key is None:
            update_key = result.pop("update_key", None)
        if "logfiles" in result:
            logfiles.update(result["logfiles"])
            continue

        if averaged_estimators is None:
            averaged_estimators: list[dict] = result.get("estimators", [])
            # there is nothing to average yet
            continue

        averaged_estimators = average_estimators(averaged_estimators, result.get("estimators", []), i)

    final_result = {"end_time": datetime.utcnow().isoformat(sep=" ")}

    if len(logfiles.keys()) > 0 and not send_simulation_logfiles(
            simulation_id=simulation_id, update_key=update_key, logfiles=logfiles):
        final_result["logfiles"] = logfiles

    if averaged_estimators:
        # send results to the backend and mark the whole simulation as completed
        sending_results_ok = send_simulation_results(simulation_id=simulation_id,
                                                     update_key=update_key,
                                                     estimators=averaged_estimators)
        if not sending_results_ok:
            final_result["estimators"] = averaged_estimators

    return final_result


@dataclass
class MonitorTask:
    """Class representing a monitoring task"""

    path_to_monitor: Path
    task: threading.Thread


def monitor_shieldhit(event: threading.Event, tmp_work_dir: str, task_id: int, update_key: str,
                      simulation_id: int) -> Optional[MonitorTask]:
    """Function monitoring the progress of a SHIELD-HIT12A simulation"""
    # we would like to monitor the progress of the simulation;
    # this is done by reading the log file and sending updates to the backend;
    # if we have update_key and simulation_id, the monitoring task can submit the updates
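    # the watcher target read_file (defined in yaptide.celery.utils.pymc) is assumed
    # to tail the log file, report progress via send_task_update, and return once
    # `event` is set or the log reports completion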
    path_to_monitor = Path(tmp_work_dir) / f"shieldhit_{task_id:04d}.log"
    if update_key and simulation_id is not None:
        current_logging_level = logging.getLogger().getEffectiveLevel()
        task = threading.Thread(target=read_file,
                                kwargs=dict(event=event,
                                            filepath=path_to_monitor,
                                            simulation_id=simulation_id,
                                            task_id=task_id,
                                            update_key=update_key,
                                            logging_level=current_logging_level))
        task.start()
        logging.info("Started monitoring process for task %d", task_id)
        return MonitorTask(path_to_monitor=path_to_monitor, task=task)

    logging.info("No monitoring process started for task %d", task_id)
    return None


def monitor_fluka(event: threading.Event, tmp_work_dir: str, task_id: int, update_key: str,
                  simulation_id: int) -> Optional[MonitorTask]:
    """Function running the monitoring process for a FLUKA simulation"""
    # we would like to monitor the progress of the simulation;
    # this is done by reading the log file and sending updates to the backend;
    # if we have update_key and simulation_id, the monitoring task can submit the updates;
    # we monitor a directory instead of a single file path, because the FLUKA
    # simulator creates a subdirectory with the PID of its process in the name
    dir_to_monitor = Path(tmp_work_dir)
    if update_key and simulation_id is not None:
        current_logging_level = logging.getLogger().getEffectiveLevel()
        task = threading.Thread(target=read_fluka_file,
                                kwargs=dict(event=event,
                                            dirpath=dir_to_monitor,
                                            simulation_id=simulation_id,
                                            task_id=task_id,
                                            update_key=update_key,
                                            logging_level=current_logging_level))

        task.start()
        logging.info("Started monitoring process for task %d", task_id)
        return MonitorTask(path_to_monitor=dir_to_monitor, task=task)

    logging.info("No monitoring process started for task %d", task_id)
    return None