# Coverage report source: yaptide/batch/batch_methods.py (16% coverage, 201 statements)
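# Batch (SLURM) backend for yaptide: submits simulation jobs to a remote cluster
# over SSH (fabric/paramiko), queries their state with sacct, fetches result files,
# and cancels jobs with scancel. Progress and results are reported to the flask
# backend by POSTing to its /jobs endpoint.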
import io
import json
import logging
import os
import tempfile
from datetime import datetime
from pathlib import Path
from zipfile import ZipFile

import requests
import sqlalchemy as db
import pymchelper
from fabric import Connection, Result
from paramiko import RSAKey

from yaptide.admin.db_manage import TableTypes, connect_to_db
from yaptide.batch.string_templates import (ARRAY_SHIELDHIT_BASH, COLLECT_BASH, SUBMIT_SHIELDHIT)
from yaptide.batch.utils.utils import (convert_dict_to_sbatch_options, extract_sbatch_header)
from yaptide.persistence.models import (BatchSimulationModel, ClusterModel, KeycloakUserModel, UserModel)
from yaptide.utils.enums import EntityState
from yaptide.utils.helper_worker import celery_app
from yaptide.utils.sim_utils import write_simulation_input_files
def get_user(db_con, metadata, userId):
    """Queries database for user"""
    users = metadata.tables[TableTypes.User.name]
    keycloakUsers = metadata.tables[TableTypes.KeycloakUser.name]
    stmt = db.select(users, keycloakUsers).select_from(users).join(
        keycloakUsers, KeycloakUserModel.id == UserModel.id).filter_by(id=userId)
    try:
        user: KeycloakUserModel = db_con.execute(stmt).first()
    except Exception:
        logging.error('Error getting user object with id: %s from database', str(userId))
        return None
    return user
def get_cluster(db_con, metadata, clusterId):
    """Queries database for cluster"""
    clusters = metadata.tables[TableTypes.Cluster.name]
    stmt = db.select(clusters).filter_by(id=clusterId)
    try:
        cluster: ClusterModel = db_con.execute(stmt).first()
    except Exception:
        logging.error('Error getting cluster object with id: %s from database', str(clusterId))
        return None
    return cluster
def get_connection(user: KeycloakUserModel, cluster: ClusterModel) -> Connection:
    """Returns connection object to cluster"""
    pkey = RSAKey.from_private_key(io.StringIO(user.private_key))
    pkey.load_certificate(user.cert)

    con = Connection(host=f"{user.username}@{cluster.cluster_name}",
                     connect_kwargs={
                         "pkey": pkey,
                         "allow_agent": False,
                         "look_for_keys": False
                     })
    return con
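# Illustrative usage of the connection (mirrors the calls made in submit_job below):
#   con = get_connection(user=user, cluster=cluster)
#   result = con.run("echo $SCRATCH", hide=True)
#   con.put(local_path, remote_dir)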
def post_update(dict_to_send):
    """For sending requests with information to flask"""
    flask_url = os.environ.get("BACKEND_INTERNAL_URL")
    return requests.Session().post(url=f"{flask_url}/jobs", json=dict_to_send)
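# Payloads posted to {BACKEND_INTERNAL_URL}/jobs follow the shapes built below in
# submit_job, e.g. on failure:
#   {"sim_id": ..., "job_state": EntityState.FAILED.value,
#    "update_key": ..., "log": {"error": "..."}}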
@celery_app.task()
def submit_job(  # skipcq: PY-R1000
        payload_dict: dict, files_dict: dict, userId: int, clusterId: int, sim_id: int, update_key: str):
    """Submits job to cluster"""
    utc_now = int(datetime.utcnow().timestamp() * 1e6)
    try:
        # Connecting to the database this way because the celery task runs outside the flask context
        db_con, metadata, _ = connect_to_db()
    except Exception as e:
        logging.error('Async worker couldn\'t connect to db. Error message: "%s"', str(e))
        return

    user = get_user(db_con=db_con, metadata=metadata, userId=userId)
    cluster = get_cluster(db_con=db_con, metadata=metadata, clusterId=clusterId)

    # get_user and get_cluster return None on failure; bail out before dereferencing
    if user is None or cluster is None:
        logging.error("Could not fetch user %s or cluster %s from database", str(userId), str(clusterId))
        return

    if user.cert is None or user.private_key is None:
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "update_key": update_key,
            "log": {
                "error": f"User {user.username} has no certificate or private key"
            }
        }
        post_update(dict_to_send)
        return
    try:
        con = get_connection(user=user, cluster=cluster)
        fabric_result: Result = con.run("echo $SCRATCH", hide=True)
    except Exception as e:
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "update_key": update_key,
            "log": {
                "error": str(e)
            }
        }
        post_update(dict_to_send)
        return

    scratch = fabric_result.stdout.split()[0]
    logging.debug("Scratch directory: %s", scratch)
    job_dir = f"{scratch}/yaptide_runs/{utc_now}"
    logging.debug("Job directory: %s", job_dir)

    try:
        con.run(f"mkdir -p {job_dir}")
    except Exception as e:
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "update_key": update_key,
            "log": {
                "error": str(e)
            }
        }
        post_update(dict_to_send)
        return
    with tempfile.TemporaryDirectory() as tmp_dir_path:
        logging.debug("Preparing simulation input in: %s", tmp_dir_path)
        zip_path = Path(tmp_dir_path) / "input.zip"
        write_simulation_input_files(files_dict=files_dict, output_dir=Path(tmp_dir_path))
        logging.debug("Zipping simulation input to %s", zip_path)
        with ZipFile(zip_path, mode="w") as archive:
            for file in Path(tmp_dir_path).iterdir():
                if file.name == "input.zip":
                    continue
                archive.write(file, arcname=file.name)
        logging.debug("Transferring simulation input %s to %s", zip_path, job_dir)
        con.put(zip_path, job_dir)
    WATCHER_SCRIPT = Path(__file__).parent.resolve() / "watcher.py"
    SIMULATION_DATA_SENDER_SCRIPT = Path(__file__).parent.resolve() / "simulation_data_sender.py"

    logging.debug("Transferring watcher script %s to %s", WATCHER_SCRIPT, job_dir)
    con.put(WATCHER_SCRIPT, job_dir)
    logging.debug("Transferring result sender script %s to %s", SIMULATION_DATA_SENDER_SCRIPT, job_dir)
    con.put(SIMULATION_DATA_SENDER_SCRIPT, job_dir)
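    # watcher.py and simulation_data_sender.py are helper scripts shipped alongside
    # this module; given the sim_id/update_key/backend_url wired into the script
    # templates below, they presumably track the running tasks and push progress and
    # results back to the backend.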
    submit_file, sh_files = prepare_script_files(payload_dict=payload_dict,
                                                 job_dir=job_dir,
                                                 sim_id=sim_id,
                                                 update_key=update_key,
                                                 con=con)

    array_id = collect_id = None
    if not submit_file.startswith(job_dir):
        logging.error("Invalid submit file path: %s", submit_file)
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "update_key": update_key,
            "log": {
                "error": "Job submission failed due to invalid submit file path"
            }
        }
        post_update(dict_to_send)
        return
    fabric_result: Result = con.run(f'sh {submit_file}', hide=True)
    submit_stdout = fabric_result.stdout
    submit_stderr = fabric_result.stderr
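    # The submit script is expected to print lines of the form "Job id <id>" and
    # "Collect id <id>"; parse the SLURM array and collect job ids from them.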
    for line in submit_stdout.split("\n"):
        if line.startswith("Job id"):
            try:
                array_id = int(line.split()[-1])
            except (ValueError, IndexError):
                logging.error("Could not parse array id from line: %s", line)
        if line.startswith("Collect id"):
            try:
                collect_id = int(line.split()[-1])
            except (ValueError, IndexError):
                logging.error("Could not parse collect id from line: %s", line)
    if array_id is None or collect_id is None:
        logging.debug("Job submission failed")
        logging.debug("Sbatch stdout: %s", submit_stdout)
        logging.debug("Sbatch stderr: %s", submit_stderr)
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "update_key": update_key,
            "log": {
                "message": "Job submission failed",
                "submit_stdout": submit_stdout,
                "sh_files": sh_files,
                "submit_stderr": submit_stderr
            }
        }
        post_update(dict_to_send)
        return

    dict_to_send = {
        "sim_id": sim_id,
        "update_key": update_key,
        "job_dir": job_dir,
        "array_id": array_id,
        "collect_id": collect_id,
        "submit_stdout": submit_stdout,
        "sh_files": sh_files
    }
    post_update(dict_to_send)
    return
def prepare_script_files(payload_dict: dict, job_dir: str, sim_id: int, update_key: str,
                         con: Connection) -> tuple[str, dict]:
    """Prepares script files to run them on cluster"""
    submit_file = f'{job_dir}/yaptide_submitter.sh'
    array_file = f'{job_dir}/array_script.sh'
    collect_file = f'{job_dir}/collect_script.sh'

    array_options = convert_dict_to_sbatch_options(payload_dict=payload_dict, target_key="array_options")
    array_header = extract_sbatch_header(payload_dict=payload_dict, target_key="array_header")

    collect_options = convert_dict_to_sbatch_options(payload_dict=payload_dict, target_key="collect_options")
    collect_header = extract_sbatch_header(payload_dict=payload_dict, target_key="collect_header")

    backend_url = os.environ.get("BACKEND_EXTERNAL_URL", "")

    submit_script = SUBMIT_SHIELDHIT.format(array_options=array_options,
                                            collect_options=collect_options,
                                            root_dir=job_dir,
                                            n_tasks=str(payload_dict["ntasks"]),
                                            convertmc_version=pymchelper.__version__)
    array_script = ARRAY_SHIELDHIT_BASH.format(array_header=array_header,
                                               root_dir=job_dir,
                                               sim_id=sim_id,
                                               update_key=update_key,
                                               backend_url=backend_url)
    collect_script = COLLECT_BASH.format(collect_header=collect_header,
                                         root_dir=job_dir,
                                         clear_bdos="true",
                                         sim_id=sim_id,
                                         update_key=update_key,
                                         backend_url=backend_url)
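    # Note: the scripts are written remotely via echo wrapped in single quotes, so
    # the generated script bodies must not themselves contain single quotes; ">>"
    # appends, which relies on each submission using a fresh, timestamped job_dir.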
    con.run(f'echo \'{array_script}\' >> {array_file}')
    con.run(f'chmod +x {array_file}')
    con.run(f'echo \'{submit_script}\' >> {submit_file}')
    con.run(f'chmod +x {submit_file}')
    con.run(f'echo \'{collect_script}\' >> {collect_file}')
    con.run(f'chmod +x {collect_file}')

    return submit_file, {"submit": submit_script, "array": array_script, "collect": collect_script}
def get_job_status(simulation: BatchSimulationModel, user: KeycloakUserModel, cluster: ClusterModel) -> dict:
    """Get SLURM job status"""
    array_id = simulation.array_id
    collect_id = simulation.collect_id

    con = get_connection(user=user, cluster=cluster)
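    # "sacct -j <id> --format State" prints a header, a separator line and one state
    # per job step; splitting on whitespace and taking the last token gives the state
    # of the last listed step.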
    fabric_result: Result = con.run(f'sacct -j {array_id} --format State', hide=True)
    job_state = fabric_result.stdout.split()[-1]

    fabric_result: Result = con.run(f'sacct -j {collect_id} --format State', hide=True)
    collect_state = fabric_result.stdout.split()[-1]

    if job_state == "FAILED" or collect_state == "FAILED":
        return {"job_state": EntityState.FAILED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}
    if collect_state == "COMPLETED":
        return {"job_state": EntityState.COMPLETED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}
    if collect_state == "RUNNING":
        logging.debug("Collect job is in RUNNING state")
        return {"job_state": EntityState.MERGING_RUNNING.value}
    if job_state == "COMPLETED" and collect_state == "PENDING":
        logging.debug("Collect job is in PENDING state")
        return {"job_state": EntityState.MERGING_QUEUED.value}
    if job_state == "RUNNING":
        logging.debug("Main job is in RUNNING state")
    if job_state == "PENDING":
        logging.debug("Main job is in PENDING state")

    return {"job_state": EntityState.RUNNING.value}
def get_job_results(simulation: BatchSimulationModel, user: KeycloakUserModel, cluster: ClusterModel) -> dict:
    """Returns simulation results"""
    job_dir = simulation.job_dir
    collect_id = simulation.collect_id

    con = get_connection(user=user, cluster=cluster)

    fabric_result: Result = con.run(f'sacct -j {collect_id} --format State', hide=True)
    collect_state = fabric_result.stdout.split()[-1]

    if collect_state == "COMPLETED":
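        # Caveat: in "grep .json" the dot is a regex wildcard, so this matches any
        # filename containing "json", not only files ending in ".json".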
        fabric_result: Result = con.run(f'ls -f {job_dir}/output | grep .json', hide=True)
        result_estimators = []
        with tempfile.TemporaryDirectory() as tmp_dir_path:
            for filename in fabric_result.stdout.split():
                file_path = Path(tmp_dir_path, filename)
                with open(file_path, "wb") as writer:
                    con.get(f'{job_dir}/output/{filename}', writer)
                with open(file_path, "r") as json_file:
                    est_dict = json.load(json_file)
                    est_dict["name"] = filename.split('.')[0]
                    result_estimators.append(est_dict)

        return {"estimators": result_estimators}
    return {"message": "Results not available"}
def delete_job(simulation: BatchSimulationModel, user: KeycloakUserModel,
               cluster: ClusterModel) -> tuple[dict, int]:  # skipcq: PYL-W0613
    """Dummy version of delete_job"""
    job_dir = simulation.job_dir
    array_id = simulation.array_id
    collect_id = simulation.collect_id

    try:
        con = get_connection(user=user, cluster=cluster)

        con.run(f'scancel {array_id}')
        con.run(f'scancel {collect_id}')
        con.run(f'rm -rf {job_dir}')
    except Exception as e:  # skipcq: PYL-W0703
        logging.error(e)
        return {"message": "Job cancellation failed"}, 500

    return {"message": "Job canceled"}, 200