Coverage for yaptide/batch/batch_methods.py: 15%
213 statements
coverage.py v7.6.10, created at 2025-03-31 19:18 +0000
import io
import json
import logging
import os
import tempfile
import requests
from datetime import datetime
from pathlib import Path
from zipfile import ZipFile
import sqlalchemy as db

import pymchelper
from fabric import Connection, Result
from paramiko import RSAKey

from yaptide.batch.shieldhit_string_templates import (ARRAY_SHIELDHIT_BASH, COLLECT_SHIELDHIT_BASH, SUBMIT_SHIELDHIT)
from yaptide.batch.fluka_string_templates import (ARRAY_FLUKA_BASH, COLLECT_FLUKA_BASH, SUBMIT_FLUKA)
from yaptide.batch.utils.utils import (convert_dict_to_sbatch_options, extract_sbatch_header)
from yaptide.persistence.models import (BatchSimulationModel, ClusterModel, KeycloakUserModel, UserModel)
from yaptide.utils.enums import EntityState, SimulationType
from yaptide.utils.sim_utils import write_simulation_input_files

from yaptide.admin.db_manage import TableTypes, connect_to_db
from yaptide.utils.helper_worker import celery_app


def get_user(db_con, metadata, userId):
    """Queries database for user"""
    users = metadata.tables[TableTypes.User.name]
    keycloackUsers = metadata.tables[TableTypes.KeycloakUser.name]
    stmt = db.select(users,
                     keycloackUsers).select_from(users).join(keycloackUsers,
                                                             KeycloakUserModel.id == UserModel.id).filter_by(id=userId)
    try:
        user: KeycloakUserModel = db_con.execute(stmt).first()
    except Exception:
        logging.error('Error getting user object with id: %s from database', str(userId))
        return None
    return user


def get_cluster(db_con, metadata, clusterId):
    """Queries database for cluster"""
    clusters = metadata.tables[TableTypes.Cluster.name]
    stmt = db.select(clusters).filter_by(id=clusterId)
    try:
        cluster: ClusterModel = db_con.execute(stmt).first()
    except Exception:
        logging.error('Error getting cluster object with id: %s from database', str(clusterId))
        return None
    return cluster


def get_connection(user: KeycloakUserModel, cluster: ClusterModel) -> Connection:
    """Returns connection object to cluster"""
    pkey = RSAKey.from_private_key(io.StringIO(user.private_key))
    pkey.load_certificate(user.cert)

    con = Connection(host=f"{user.username}@{cluster.cluster_name}",
                     connect_kwargs={
                         "pkey": pkey,
                         "allow_agent": False,
                         "look_for_keys": False
                     })
    return con


def post_update(dict_to_send):
    """For sending requests with information to flask"""
    flask_url = os.environ.get("BACKEND_INTERNAL_URL")
    return requests.Session().post(url=f"{flask_url}/jobs", json=dict_to_send)
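
# Illustrative call to post_update(); the exact keys vary per call site, but every update
# built in this module carries at least "sim_id" and "update_key" (values below are made up):
#
#   post_update({
#       "sim_id": 42,
#       "job_state": EntityState.FAILED.value,
#       "update_key": "some-update-key",
#       "log": {"error": "reason for failure"},
#   })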


@celery_app.task()
def submit_job(  # skipcq: PY-R1000
        payload_dict: dict, files_dict: dict, userId: int, clusterId: int, sim_id: int, update_key: str):
    """Submits job to cluster"""
    utc_now = int(datetime.utcnow().timestamp() * 1e6)
    try:
        # Connecting to the database and querying objects looks like this
        # because the celery task runs outside the flask application context
        db_con, metadata, _ = connect_to_db()
    except Exception as e:
        logging.error('Async worker couldn\'t connect to db. Error message:"%s"', str(e))

    user = get_user(db_con=db_con, metadata=metadata, userId=userId)
    cluster = get_cluster(db_con=db_con, metadata=metadata, clusterId=clusterId)

    if user.cert is None or user.private_key is None:
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "update_key": update_key,
            "log": {
                "error": f"User {user.username} has no certificate or private key"
            }
        }
        post_update(dict_to_send)
        return

    try:
        con = get_connection(user=user, cluster=cluster)
        fabric_result: Result = con.run("echo $SCRATCH", hide=True)
    except Exception as e:
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "update_key": update_key,
            "log": {
                "error": str(e)
            }
        }
        post_update(dict_to_send)
        return

    scratch = fabric_result.stdout.split()[0]
    logging.debug("Scratch directory: %s", scratch)

    job_dir = f"{scratch}/yaptide_runs/{utc_now}"
    logging.debug("Job directory: %s", job_dir)

    try:
        con.run(f"mkdir -p {job_dir}")
    except Exception:
        dict_to_send = {"sim_id": sim_id, "job_state": EntityState.FAILED.value, "update_key": update_key}
        post_update(dict_to_send)
        return
    with tempfile.TemporaryDirectory() as tmp_dir_path:
        logging.debug("Preparing simulation input in: %s", tmp_dir_path)
        zip_path = Path(tmp_dir_path) / "input.zip"
        write_simulation_input_files(files_dict=files_dict, output_dir=Path(tmp_dir_path))
        logging.debug("Zipping simulation input to %s", zip_path)
        with ZipFile(zip_path, mode="w") as archive:
            for file in Path(tmp_dir_path).iterdir():
                if file.name == "input.zip":
                    continue
                archive.write(file, arcname=file.name)
        con.put(zip_path, job_dir)
        logging.debug("Transferring simulation input %s to %s", zip_path, job_dir)

    WATCHER_SCRIPT = Path(__file__).parent.resolve() / "watcher.py"
    SIMULATION_DATA_SENDER_SCRIPT = Path(__file__).parent.resolve() / "simulation_data_sender.py"

    logging.debug("Transferring watcher script %s to %s", WATCHER_SCRIPT, job_dir)
    con.put(WATCHER_SCRIPT, job_dir)
    logging.debug("Transferring result sender script %s to %s", SIMULATION_DATA_SENDER_SCRIPT, job_dir)
    con.put(SIMULATION_DATA_SENDER_SCRIPT, job_dir)

    submit_file, sh_files = prepare_script_files(payload_dict=payload_dict,
                                                 job_dir=job_dir,
                                                 sim_id=sim_id,
                                                 update_key=update_key,
                                                 con=con)

    array_id = collect_id = None
    if not submit_file.startswith(job_dir):
        logging.error("Invalid submit file path: %s", submit_file)
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "update_key": update_key,
            "log": {
                "error": "Job submission failed due to invalid submit file path"
            }
        }
        post_update(dict_to_send)
        return
    fabric_result: Result = con.run(f'sh {submit_file}', hide=True)
    submit_stdout = fabric_result.stdout
    submit_stderr = fabric_result.stderr
    for line in submit_stdout.split("\n"):
        if line.startswith("Job id"):
            try:
                array_id = int(line.split()[-1])
            except (ValueError, IndexError):
                logging.error("Could not parse array id from line: %s", line)
        if line.startswith("Collect id"):
            try:
                collect_id = int(line.split()[-1])
            except (ValueError, IndexError):
                logging.error("Could not parse collect id from line: %s", line)

    if array_id is None or collect_id is None:
        logging.debug("Job submission failed")
        logging.debug("Sbatch stdout: %s", submit_stdout)
        logging.debug("Sbatch stderr: %s", submit_stderr)
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "update_key": update_key,
            "log": {
                "message": "Job submission failed",
                "submit_stdout": submit_stdout,
                "sh_files": sh_files,
                "submit_stderr": submit_stderr
            }
        }
        post_update(dict_to_send)
        return

    dict_to_send = {
        "sim_id": sim_id,
        "update_key": update_key,
        "job_dir": job_dir,
        "array_id": array_id,
        "collect_id": collect_id,
        "submit_stdout": submit_stdout,
        "sh_files": sh_files
    }
    post_update(dict_to_send)
    return
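
# Note: the id parsing in submit_job() assumes the generated submit script prints lines of
# the form shown below (values are illustrative), with the SLURM job id of the array job
# and of the collect job as the last whitespace-separated token:
#
#   Job id 1234567
#   Collect id 1234568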


def prepare_script_files(payload_dict: dict, job_dir: str, sim_id: int, update_key: str,
                         con: Connection) -> tuple[str, dict]:
    """Prepares script files to run them on cluster"""
    submit_file = f'{job_dir}/yaptide_submitter.sh'
    array_file = f'{job_dir}/array_script.sh'
    collect_file = f'{job_dir}/collect_script.sh'

    array_options = convert_dict_to_sbatch_options(payload_dict=payload_dict, target_key="array_options")
    array_header = extract_sbatch_header(payload_dict=payload_dict, target_key="array_header")

    collect_options = convert_dict_to_sbatch_options(payload_dict=payload_dict, target_key="collect_options")
    collect_header = extract_sbatch_header(payload_dict=payload_dict, target_key="collect_header")

    backend_url = os.environ.get("BACKEND_EXTERNAL_URL", "")

    if payload_dict['sim_type'] == SimulationType.FLUKA.value:
        submit_template = SUBMIT_FLUKA
        array_template = ARRAY_FLUKA_BASH
        collect_template = COLLECT_FLUKA_BASH
    elif payload_dict['sim_type'] == SimulationType.SHIELDHIT.value:
        submit_template = SUBMIT_SHIELDHIT
        array_template = ARRAY_SHIELDHIT_BASH
        collect_template = COLLECT_SHIELDHIT_BASH
    else:
        # Ready for future simulators
        submit_template = ""
        array_template = ""
        collect_template = ""

    submit_script = submit_template.format(array_options=array_options,
                                           collect_options=collect_options,
                                           root_dir=job_dir,
                                           n_tasks=str(payload_dict["ntasks"]),
                                           convertmc_version=pymchelper.__version__)
    array_script = array_template.format(array_header=array_header,
                                         root_dir=job_dir,
                                         sim_id=sim_id,
                                         update_key=update_key,
                                         backend_url=backend_url)
    collect_script = collect_template.format(collect_header=collect_header,
                                             root_dir=job_dir,
                                             remove_output_from_workspace="true",
                                             sim_id=sim_id,
                                             update_key=update_key,
                                             backend_url=backend_url)

    con.run(f'echo \'{array_script}\' >> {array_file}')
    con.run(f'chmod +x {array_file}')
    con.run(f'echo \'{submit_script}\' >> {submit_file}')
    con.run(f'chmod +x {submit_file}')
    con.run(f'echo \'{collect_script}\' >> {collect_file}')
    con.run(f'chmod +x {collect_file}')

    return submit_file, {"submit": submit_script, "array": array_script, "collect": collect_script}
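
# prepare_script_files() expects payload_dict to carry at least "sim_type" and "ntasks";
# the sbatch options and headers for the array and collect jobs are extracted by
# convert_dict_to_sbatch_options() and extract_sbatch_header() using the target keys
# "array_options"/"array_header" and "collect_options"/"collect_header". Where exactly those
# values live inside payload_dict is defined by those helpers, not by this module.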


def get_job_status(simulation: BatchSimulationModel, user: KeycloakUserModel, cluster: ClusterModel) -> dict:
    """Get SLURM job status"""
    array_id = simulation.array_id
    collect_id = simulation.collect_id

    con = get_connection(user=user, cluster=cluster)

    fabric_result: Result = con.run(f'sacct -j {array_id} --format State', hide=True)
    job_state = fabric_result.stdout.split()[-1].split()[0]

    fabric_result: Result = con.run(f'sacct -j {collect_id} --format State', hide=True)
    collect_state = fabric_result.stdout.split()[-1].split()[0]

    if job_state == "FAILED" or collect_state == "FAILED":
        return {"job_state": EntityState.FAILED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}
    if collect_state == "COMPLETED":
        return {"job_state": EntityState.COMPLETED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}
    if collect_state == "RUNNING":
        logging.debug("Collect job is in RUNNING state")
        return {"job_state": EntityState.MERGING_RUNNING.value}
    if job_state == "COMPLETED" and collect_state == "PENDING":
        logging.debug("Collect job is in PENDING state")
        return {"job_state": EntityState.MERGING_QUEUED.value}
    if job_state == "RUNNING":
        logging.debug("Main job is in RUNNING state")
    if job_state == "PENDING":
        logging.debug("Main job is in PENDING state")

    return {"job_state": EntityState.RUNNING.value}
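
# get_job_status() treats the last whitespace-separated token of the sacct output as the job
# state. A typical `sacct -j <id> --format State` call prints roughly the following (header,
# separator, one state per job step), so the last token is the state of the last listed step:
#
#        State
#   ----------
#    COMPLETED
#    COMPLETED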


def get_job_results(simulation: BatchSimulationModel, user: KeycloakUserModel, cluster: ClusterModel) -> dict:
    """Returns simulation results"""
    job_dir = simulation.job_dir
    collect_id = simulation.collect_id

    con = get_connection(user=user, cluster=cluster)

    fabric_result: Result = con.run(f'sacct -j {collect_id} --format State', hide=True)
    collect_state = fabric_result.stdout.split()[-1].split()[0]

    if collect_state == "COMPLETED":
        fabric_result: Result = con.run(f'ls -f {job_dir}/output | grep .json', hide=True)
        result_estimators = []
        with tempfile.TemporaryDirectory() as tmp_dir_path:
            for filename in fabric_result.stdout.split():
                file_path = Path(tmp_dir_path, filename)
                with open(file_path, "wb") as writer:
                    con.get(f'{job_dir}/output/{filename}', writer)
                with open(file_path, "r") as json_file:
                    est_dict = json.load(json_file)
                    est_dict["name"] = filename.split('.')[0]
                    result_estimators.append(est_dict)

        return {"estimators": result_estimators}
    return {"message": "Results not available"}


def delete_job(simulation: BatchSimulationModel, user: KeycloakUserModel,
               cluster: ClusterModel) -> tuple[dict, int]:  # skipcq: PYL-W0613
    """Dummy version of delete_job"""
    job_dir = simulation.job_dir
    array_id = simulation.array_id
    collect_id = simulation.collect_id

    try:
        con = get_connection(user=user, cluster=cluster)

        con.run(f'scancel {array_id}')
        con.run(f'scancel {collect_id}')
        con.run(f'rm -rf {job_dir}')
    except Exception as e:  # skipcq: PYL-W0703
        logging.error(e)
        return {"message": "Job cancellation failed"}, 500

    return {"message": "Job canceled"}, 200