Coverage for yaptide/batch/batch_methods.py: 15%
199 statements
coverage.py v7.6.4, created at 2024-11-22 07:31 +0000
import io
import json
import logging
import os
import tempfile
import requests
from datetime import datetime
from pathlib import Path
from zipfile import ZipFile
import sqlalchemy as db

import pymchelper
from fabric import Connection, Result
from paramiko import RSAKey

from yaptide.batch.string_templates import (ARRAY_SHIELDHIT_BASH, COLLECT_BASH, SUBMIT_SHIELDHIT)
from yaptide.batch.utils.utils import (convert_dict_to_sbatch_options, extract_sbatch_header)
from yaptide.persistence.models import (BatchSimulationModel, ClusterModel, KeycloakUserModel, UserModel)
from yaptide.utils.enums import EntityState
from yaptide.utils.sim_utils import write_simulation_input_files

from yaptide.admin.db_manage import TableTypes, connect_to_db
from yaptide.utils.helper_worker import celery_app


def get_user(db_con, metadata, userId):
    """Queries database for user"""
    users = metadata.tables[TableTypes.User.name]
    keycloackUsers = metadata.tables[TableTypes.KeycloakUser.name]
    stmt = db.select(users,
                     keycloackUsers).select_from(users).join(keycloackUsers,
                                                             KeycloakUserModel.id == UserModel.id).filter_by(id=userId)
    try:
        user: KeycloakUserModel = db_con.execute(stmt).first()
    except Exception:
        logging.error('Error getting user object with id: %s from database', str(userId))
        return None
    return user


def get_cluster(db_con, metadata, clusterId):
    """Queries database for cluster"""
    clusters = metadata.tables[TableTypes.Cluster.name]
    stmt = db.select(clusters).filter_by(id=clusterId)
    try:
        cluster: ClusterModel = db_con.execute(stmt).first()
    except Exception:
        logging.error('Error getting cluster object with id: %s from database', str(clusterId))
        return None
    return cluster


def get_connection(user: KeycloakUserModel, cluster: ClusterModel) -> Connection:
    """Returns connection object to cluster"""
    pkey = RSAKey.from_private_key(io.StringIO(user.private_key))
    pkey.load_certificate(user.cert)

    con = Connection(host=f"{user.username}@{cluster.cluster_name}",
                     connect_kwargs={
                         "pkey": pkey,
                         "allow_agent": False,
                         "look_for_keys": False
                     })
    return con


def post_update(dict_to_send):
    """For sending requests with information to flask"""
    flask_url = os.environ.get("BACKEND_INTERNAL_URL")
    return requests.Session().post(url=f"{flask_url}/jobs", json=dict_to_send)
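
# The celery task below drives the whole submission flow: it loads the user and cluster
# records, opens an SSH connection to the cluster, uploads the zipped simulation input and
# helper scripts, generates the sbatch scripts and reports progress back via post_update().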


@celery_app.task()
def submit_job(  # skipcq: PY-R1000
        payload_dict: dict, files_dict: dict, userId: int, clusterId: int, sim_id: int, update_key: str):
    """Submits job to cluster"""
    utc_now = int(datetime.utcnow().timestamp() * 1e6)
    try:
        # The database connection and object queries are done by hand because the celery task runs outside the Flask application context
        db_con, metadata, _ = connect_to_db()
    except Exception as e:
        logging.error('Async worker couldn\'t connect to db. Error message:"%s"', str(e))

    user = get_user(db_con=db_con, metadata=metadata, userId=userId)
    cluster = get_cluster(db_con=db_con, metadata=metadata, clusterId=clusterId)

    if user.cert is None or user.private_key is None:
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "log": {
                "error": f"User {user.username} has no certificate or private key"
            }
        }
        post_update(dict_to_send)
        return
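
    # Open an SSH connection to the cluster and locate the user's $SCRATCH directory,
    # in which a unique per-job working directory is created.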
    try:
        con = get_connection(user=user, cluster=cluster)
        fabric_result: Result = con.run("echo $SCRATCH", hide=True)
    except Exception as e:
        dict_to_send = {"sim_id": sim_id, "job_state": EntityState.FAILED.value, "log": {"error": str(e)}}
        post_update(dict_to_send)
        return

    scratch = fabric_result.stdout.split()[0]
    logging.debug("Scratch directory: %s", scratch)

    job_dir = f"{scratch}/yaptide_runs/{utc_now}"
    logging.debug("Job directory: %s", job_dir)

    try:
        con.run(f"mkdir -p {job_dir}")
    except Exception as e:
        dict_to_send = {"sim_id": sim_id, "job_state": EntityState.FAILED.value}
        post_update(dict_to_send)
        return
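
    # Write the simulation input files to a local temporary directory, zip them and upload
    # the archive together with the watcher and result sender helper scripts to the job directory.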
    with tempfile.TemporaryDirectory() as tmp_dir_path:
        logging.debug("Preparing simulation input in: %s", tmp_dir_path)
        zip_path = Path(tmp_dir_path) / "input.zip"
        write_simulation_input_files(files_dict=files_dict, output_dir=Path(tmp_dir_path))
        logging.debug("Zipping simulation input to %s", zip_path)
        with ZipFile(zip_path, mode="w") as archive:
            for file in Path(tmp_dir_path).iterdir():
                if file.name == "input.zip":
                    continue
                archive.write(file, arcname=file.name)
        con.put(zip_path, job_dir)
        logging.debug("Transferring simulation input %s to %s", zip_path, job_dir)

    WATCHER_SCRIPT = Path(__file__).parent.resolve() / "watcher.py"
    RESULT_SENDER_SCRIPT = Path(__file__).parent.resolve() / "result_sender.py"

    logging.debug("Transferring watcher script %s to %s", WATCHER_SCRIPT, job_dir)
    con.put(WATCHER_SCRIPT, job_dir)
    logging.debug("Transferring result sender script %s to %s", RESULT_SENDER_SCRIPT, job_dir)
    con.put(RESULT_SENDER_SCRIPT, job_dir)
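
    # Generate the submit, array and collect shell scripts directly on the cluster; the returned
    # submit file path is validated against the job directory before it is executed.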
    submit_file, sh_files = prepare_script_files(payload_dict=payload_dict,
                                                 job_dir=job_dir,
                                                 sim_id=sim_id,
                                                 update_key=update_key,
                                                 con=con)

    array_id = collect_id = None
    if not submit_file.startswith(job_dir):
        logging.error("Invalid submit file path: %s", submit_file)
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "log": {
                "error": "Job submission failed due to invalid submit file path"
            }
        }
        post_update(dict_to_send)
        return
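
    # Run the submit script and parse the SLURM array and collect job ids from its stdout.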
    fabric_result: Result = con.run(f'sh {submit_file}', hide=True)
    submit_stdout = fabric_result.stdout
    submit_stderr = fabric_result.stderr
    for line in submit_stdout.split("\n"):
        if line.startswith("Job id"):
            try:
                array_id = int(line.split()[-1])
            except (ValueError, IndexError):
                logging.error("Could not parse array id from line: %s", line)
        if line.startswith("Collect id"):
            try:
                collect_id = int(line.split()[-1])
            except (ValueError, IndexError):
                logging.error("Could not parse collect id from line: %s", line)
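
    # Missing ids mean the submission failed; report the full sbatch output back to the backend.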
    if array_id is None or collect_id is None:
        logging.debug("Job submission failed")
        logging.debug("Sbatch stdout: %s", submit_stdout)
        logging.debug("Sbatch stderr: %s", submit_stderr)
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "log": {
                "message": "Job submission failed",
                "submit_stdout": submit_stdout,
                "sh_files": sh_files,
                "submit_stderr": submit_stderr
            }
        }
        post_update(dict_to_send)
        return

    dict_to_send = {
        "sim_id": sim_id,
        "job_dir": job_dir,
        "array_id": array_id,
        "collect_id": collect_id,
        "submit_stdout": submit_stdout,
        "sh_files": sh_files
    }
    post_update(dict_to_send)
    return


def prepare_script_files(payload_dict: dict, job_dir: str, sim_id: int, update_key: str,
                         con: Connection) -> tuple[str, dict]:
    """Prepares script files to run them on cluster"""
    submit_file = f'{job_dir}/yaptide_submitter.sh'
    array_file = f'{job_dir}/array_script.sh'
    collect_file = f'{job_dir}/collect_script.sh'

    array_options = convert_dict_to_sbatch_options(payload_dict=payload_dict, target_key="array_options")
    array_header = extract_sbatch_header(payload_dict=payload_dict, target_key="array_header")

    collect_options = convert_dict_to_sbatch_options(payload_dict=payload_dict, target_key="collect_options")
    collect_header = extract_sbatch_header(payload_dict=payload_dict, target_key="collect_header")

    backend_url = os.environ.get("BACKEND_EXTERNAL_URL", "")

    submit_script = SUBMIT_SHIELDHIT.format(array_options=array_options,
                                            collect_options=collect_options,
                                            root_dir=job_dir,
                                            n_tasks=str(payload_dict["ntasks"]),
                                            convertmc_version=pymchelper.__version__)
    array_script = ARRAY_SHIELDHIT_BASH.format(array_header=array_header,
                                               root_dir=job_dir,
                                               sim_id=sim_id,
                                               update_key=update_key,
                                               backend_url=backend_url)
    collect_script = COLLECT_BASH.format(collect_header=collect_header,
                                         root_dir=job_dir,
                                         clear_bdos="true",
                                         sim_id=sim_id,
                                         update_key=update_key,
                                         backend_url=backend_url)

    con.run(f'echo \'{array_script}\' >> {array_file}')
    con.run(f'chmod +x {array_file}')
    con.run(f'echo \'{submit_script}\' >> {submit_file}')
    con.run(f'chmod +x {submit_file}')
    con.run(f'echo \'{collect_script}\' >> {collect_file}')
    con.run(f'chmod +x {collect_file}')

    return submit_file, {"submit": submit_script, "array": array_script, "collect": collect_script}
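
# The functions below reuse get_connection() to track a submitted simulation: job state is
# read from SLURM via sacct, JSON estimator results are fetched from the job directory, and
# delete_job() cancels the jobs with scancel and removes the job directory.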


def get_job_status(simulation: BatchSimulationModel, user: KeycloakUserModel, cluster: ClusterModel) -> dict:
    """Get SLURM job status"""
    array_id = simulation.array_id
    collect_id = simulation.collect_id

    con = get_connection(user=user, cluster=cluster)

    fabric_result: Result = con.run(f'sacct -j {array_id} --format State', hide=True)
    job_state = fabric_result.stdout.split()[-1].split()[0]

    fabric_result: Result = con.run(f'sacct -j {collect_id} --format State', hide=True)
    collect_state = fabric_result.stdout.split()[-1].split()[0]

    if job_state == "FAILED" or collect_state == "FAILED":
        return {"job_state": EntityState.FAILED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}
    if collect_state == "COMPLETED":
        return {"job_state": EntityState.COMPLETED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}
    if collect_state == "RUNNING":
        logging.debug("Collect job is in RUNNING state")
    if job_state == "RUNNING":
        logging.debug("Main job is in RUNNING state")
    if collect_state == "PENDING":
        logging.debug("Collect job is in PENDING state")
    if job_state == "PENDING":
        logging.debug("Main job is in PENDING state")

    return {"job_state": EntityState.RUNNING.value}


def get_job_results(simulation: BatchSimulationModel, user: KeycloakUserModel, cluster: ClusterModel) -> dict:
    """Returns simulation results"""
    job_dir = simulation.job_dir
    collect_id = simulation.collect_id

    con = get_connection(user=user, cluster=cluster)

    fabric_result: Result = con.run(f'sacct -j {collect_id} --format State', hide=True)
    collect_state = fabric_result.stdout.split()[-1].split()[0]

    if collect_state == "COMPLETED":
        fabric_result: Result = con.run(f'ls -f {job_dir}/output | grep .json', hide=True)
        result_estimators = []
        with tempfile.TemporaryDirectory() as tmp_dir_path:
            for filename in fabric_result.stdout.split():
                file_path = Path(tmp_dir_path, filename)
                with open(file_path, "wb") as writer:
                    con.get(f'{job_dir}/output/{filename}', writer)
                with open(file_path, "r") as json_file:
                    est_dict = json.load(json_file)
                    est_dict["name"] = filename.split('.')[0]
                    result_estimators.append(est_dict)

        return {"estimators": result_estimators}
    return {"message": "Results not available"}


def delete_job(simulation: BatchSimulationModel, user: KeycloakUserModel,
               cluster: ClusterModel) -> tuple[dict, int]:  # skipcq: PYL-W0613
    """Dummy version of delete_job"""
    job_dir = simulation.job_dir
    array_id = simulation.array_id
    collect_id = simulation.collect_id

    try:
        con = get_connection(user=user, cluster=cluster)

        con.run(f'scancel {array_id}')
        con.run(f'scancel {collect_id}')
        con.run(f'rm -rf {job_dir}')
    except Exception as e:  # skipcq: PYL-W0703
        logging.error(e)
        return {"message": "Job cancellation failed"}, 500

    return {"message": "Job canceled"}, 200