Coverage for yaptide/batch/batch_methods.py: 15% (145 statements)
coverage.py v7.4.4, created at 2024-07-01 12:55 +0000
1import io
2import json
3import logging
4import os
5import tempfile
6from datetime import datetime
7from pathlib import Path
8from zipfile import ZipFile
10import pymchelper
11from fabric import Connection, Result
12from paramiko import RSAKey
14from yaptide.batch.string_templates import (ARRAY_SHIELDHIT_BASH, COLLECT_BASH, SUBMIT_SHIELDHIT)
15from yaptide.batch.utils.utils import (convert_dict_to_sbatch_options, extract_sbatch_header)
16from yaptide.persistence.models import (BatchSimulationModel, ClusterModel, KeycloakUserModel)
17from yaptide.utils.enums import EntityState
18from yaptide.utils.sim_utils import write_simulation_input_files
def get_connection(user: KeycloakUserModel, cluster: ClusterModel) -> Connection:
    """Builds a Fabric SSH connection to *cluster*, authenticated as *user*.

    The user's PEM-encoded private key and SSH certificate are loaded into a
    paramiko RSAKey; agent and on-disk key lookup are disabled so only this
    in-memory key is ever offered.
    """
    private_key = RSAKey.from_private_key(io.StringIO(user.private_key))
    private_key.load_certificate(user.cert)

    connection_kwargs = {
        "pkey": private_key,
        "allow_agent": False,
        "look_for_keys": False,
    }
    return Connection(host=f"{user.username}@{cluster.cluster_name}", connect_kwargs=connection_kwargs)
def submit_job(payload_dict: dict, files_dict: dict, user: KeycloakUserModel, cluster: ClusterModel, sim_id: int,
               update_key: str) -> dict:
    """Submits job to cluster

    Creates a timestamped job directory on the cluster scratch space, uploads
    the zipped simulation input plus the watcher/result-sender helper scripts,
    generates and runs the sbatch submit script, and parses the resulting
    SLURM array and collect job ids from its stdout.

    Returns a dict describing either the successful submission (job dir, both
    SLURM ids, submit stdout, generated scripts) or the failure reason.
    """
    timestamp_us = int(datetime.utcnow().timestamp() * 1e6)

    # Without both credential parts we cannot even open the SSH connection.
    if user.cert is None or user.private_key is None:
        return {"message": f"User {user.username} has no certificate or private key"}

    con = get_connection(user=user, cluster=cluster)

    scratch = con.run("echo $SCRATCH", hide=True).stdout.split()[0]
    logging.debug("Scratch directory: %s", scratch)

    job_dir = f"{scratch}/yaptide_runs/{timestamp_us}"
    logging.debug("Job directory: %s", job_dir)
    con.run(f"mkdir -p {job_dir}")

    with tempfile.TemporaryDirectory() as tmp_dir_path:
        logging.debug("Preparing simulation input in: %s", tmp_dir_path)
        zip_path = Path(tmp_dir_path) / "input.zip"
        write_simulation_input_files(files_dict=files_dict, output_dir=Path(tmp_dir_path))
        logging.debug("Zipping simulation input to %s", zip_path)
        with ZipFile(zip_path, mode="w") as archive:
            # Archive every generated input file; skip the archive itself.
            for entry in Path(tmp_dir_path).iterdir():
                if entry.name != "input.zip":
                    archive.write(entry, arcname=entry.name)
        # Upload while the temporary directory (and zip) still exists.
        con.put(zip_path, job_dir)
        logging.debug("Transfering simulation input %s to %s", zip_path, job_dir)

    scripts_dir = Path(__file__).parent.resolve()
    watcher_script = scripts_dir / "watcher.py"
    result_sender_script = scripts_dir / "result_sender.py"

    logging.debug("Transfering watcher script %s to %s", watcher_script, job_dir)
    con.put(watcher_script, job_dir)
    logging.debug("Transfering result sender script %s to %s", result_sender_script, job_dir)
    con.put(result_sender_script, job_dir)

    submit_file, sh_files = prepare_script_files(payload_dict=payload_dict,
                                                 job_dir=job_dir,
                                                 sim_id=sim_id,
                                                 update_key=update_key,
                                                 con=con)

    array_id = None
    collect_id = None
    run_result: Result = con.run(f'sh {submit_file}', hide=True)
    submit_stdout = run_result.stdout
    submit_stderr = run_result.stderr
    # The submit script prints "Job id <n>" and "Collect id <n>" lines.
    for line in submit_stdout.split("\n"):
        if line.startswith("Job id"):
            try:
                array_id = int(line.split()[-1])
            except (ValueError, IndexError):
                logging.error("Could not parse array id from line: %s", line)
        elif line.startswith("Collect id"):
            try:
                collect_id = int(line.split()[-1])
            except (ValueError, IndexError):
                logging.error("Could not parse collect id from line: %s", line)

    if array_id is None or collect_id is None:
        logging.debug("Job submission failed")
        logging.debug("Sbatch stdout: %s", submit_stdout)
        logging.debug("Sbatch stderr: %s", submit_stderr)
        return {"message": "Job submission failed", "submit_stdout": submit_stdout, "sh_files": sh_files}

    return {
        "message": "Job submitted",
        "job_dir": job_dir,
        "array_id": array_id,
        "collect_id": collect_id,
        "submit_stdout": submit_stdout,
        "sh_files": sh_files
    }
def prepare_script_files(payload_dict: dict, job_dir: str, sim_id: int, update_key: str,
                         con: Connection) -> tuple[str, dict]:
    """Prepares script files to run them on cluster

    Renders the submit, array and collect shell scripts from their templates
    (sbatch options/headers come from *payload_dict*), writes them into the
    remote *job_dir* and marks them executable.

    Returns the remote path of the submit script and a dict with the rendered
    script texts keyed by "submit", "array" and "collect".
    """
    submit_file = f'{job_dir}/yaptide_submitter.sh'
    array_file = f'{job_dir}/array_script.sh'
    collect_file = f'{job_dir}/collect_script.sh'

    backend_url = os.environ.get("BACKEND_EXTERNAL_URL", "")

    submit_script = SUBMIT_SHIELDHIT.format(
        array_options=convert_dict_to_sbatch_options(payload_dict=payload_dict, target_key="array_options"),
        collect_options=convert_dict_to_sbatch_options(payload_dict=payload_dict, target_key="collect_options"),
        root_dir=job_dir,
        n_tasks=str(payload_dict["ntasks"]),
        convertmc_version=pymchelper.__version__)
    array_script = ARRAY_SHIELDHIT_BASH.format(
        array_header=extract_sbatch_header(payload_dict=payload_dict, target_key="array_header"),
        root_dir=job_dir,
        sim_id=sim_id,
        update_key=update_key,
        backend_url=backend_url)
    collect_script = COLLECT_BASH.format(
        collect_header=extract_sbatch_header(payload_dict=payload_dict, target_key="collect_header"),
        root_dir=job_dir,
        clear_bdos="true",
        sim_id=sim_id,
        update_key=update_key,
        backend_url=backend_url)

    # Materialize each script on the remote side and make it executable
    # (same write order as before: array, submit, collect).
    for script_text, remote_path in ((array_script, array_file),
                                     (submit_script, submit_file),
                                     (collect_script, collect_file)):
        con.run(f'echo \'{script_text}\' >> {remote_path}')
        con.run(f'chmod +x {remote_path}')

    return submit_file, {"submit": submit_script, "array": array_script, "collect": collect_script}
def get_job_status(simulation: BatchSimulationModel, user: KeycloakUserModel, cluster: ClusterModel) -> dict:
    """Get SLURM job status

    Queries sacct for the array job and the collect job. Reports FAILED if
    either failed, COMPLETED once the collect job finished, and RUNNING
    otherwise (logging RUNNING/PENDING states for diagnostics).
    """
    con = get_connection(user=user, cluster=cluster)

    def _sacct_state(slurm_id) -> str:
        """Returns the last State token reported by sacct for *slurm_id*."""
        sacct_result: Result = con.run(f'sacct -j {slurm_id} --format State', hide=True)
        return sacct_result.stdout.split()[-1].split()[0]

    job_state = _sacct_state(simulation.array_id)
    collect_state = _sacct_state(simulation.collect_id)

    if "FAILED" in (job_state, collect_state):
        return {"job_state": EntityState.FAILED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}
    if collect_state == "COMPLETED":
        return {"job_state": EntityState.COMPLETED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}

    # Purely informational logging; any non-terminal state maps to RUNNING.
    for state_name in ("RUNNING", "PENDING"):
        if collect_state == state_name:
            logging.debug("Collect job is in %s state", state_name)
        if job_state == state_name:
            logging.debug("Main job is in %s state", state_name)

    return {"job_state": EntityState.RUNNING.value}
def get_job_results(simulation: BatchSimulationModel, user: KeycloakUserModel, cluster: ClusterModel) -> dict:
    """Returns simulation results

    Once sacct reports the collect job as COMPLETED, downloads every JSON
    estimator file from the job's remote output directory into a temporary
    local directory, parses each one and tags it with its base filename.

    Returns {"estimators": [...]} on success, or a "Results not available"
    message while the collect job has not completed.
    """
    job_dir = simulation.job_dir
    collect_id = simulation.collect_id

    con = get_connection(user=user, cluster=cluster)

    fabric_result: Result = con.run(f'sacct -j {collect_id} --format State', hide=True)
    collect_state = fabric_result.stdout.split()[-1].split()[0]

    if collect_state == "COMPLETED":
        # List the JSON result files produced by the collect stage.
        fabric_result: Result = con.run(f'ls -f {job_dir}/output | grep .json', hide=True)
        result_estimators = []
        with tempfile.TemporaryDirectory() as tmp_dir_path:
            for filename in fabric_result.stdout.split():
                file_path = Path(tmp_dir_path, filename)
                with open(file_path, "wb") as writer:
                    # Fix: fetch the file currently being iterated — the remote
                    # path previously contained a broken literal instead of the
                    # {filename} interpolation, so the listed files were never
                    # actually downloaded.
                    con.get(f'{job_dir}/output/{filename}', writer)
                with open(file_path, "r") as json_file:
                    est_dict = json.load(json_file)
                    # Estimator name is the filename without its extension.
                    est_dict["name"] = filename.split('.')[0]
                    result_estimators.append(est_dict)

        return {"estimators": result_estimators}
    return {"message": "Results not available"}
def delete_job(simulation: BatchSimulationModel, user: KeycloakUserModel,
               cluster: ClusterModel) -> tuple[dict, int]:  # skipcq: PYL-W0613
    """Dummy version of delete_job"""
    job_dir = simulation.job_dir
    slurm_ids = (simulation.array_id, simulation.collect_id)

    try:
        connection = get_connection(user=user, cluster=cluster)
        # Cancel the array job first, then the collect job, then wipe the dir.
        for slurm_id in slurm_ids:
            connection.run(f'scancel {slurm_id}')
        connection.run(f'rm -rf {job_dir}')
    except Exception as e:  # skipcq: PYL-W0703
        logging.error(e)
        return {"message": "Job cancelation failed"}, 500

    return {"message": "Job canceled"}, 200