Coverage for yaptide/batch/batch_methods.py: 15%

145 statements  

coverage.py v7.4.4, created at 2024-07-01 12:55 +0000

import io
import json
import logging
import os
import tempfile
from datetime import datetime
from pathlib import Path
from zipfile import ZipFile

import pymchelper
from fabric import Connection, Result
from paramiko import RSAKey

from yaptide.batch.string_templates import (ARRAY_SHIELDHIT_BASH, COLLECT_BASH, SUBMIT_SHIELDHIT)
from yaptide.batch.utils.utils import (convert_dict_to_sbatch_options, extract_sbatch_header)
from yaptide.persistence.models import (BatchSimulationModel, ClusterModel, KeycloakUserModel)
from yaptide.utils.enums import EntityState
from yaptide.utils.sim_utils import write_simulation_input_files


def get_connection(user: KeycloakUserModel, cluster: ClusterModel) -> Connection:
    """Returns connection object to cluster"""
    pkey = RSAKey.from_private_key(io.StringIO(user.private_key))
    pkey.load_certificate(user.cert)

    con = Connection(host=f"{user.username}@{cluster.cluster_name}",
                     connect_kwargs={
                         "pkey": pkey,
                         "allow_agent": False,
                         "look_for_keys": False
                     })
    return con

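For illustration, the certificate-backed key handling above can be pulled out on its own. A minimal sketch, assuming pem_text holds the user's RSA private key and cert_text the matching SSH certificate (both hypothetical placeholders for the values stored on KeycloakUserModel):

# Hypothetical helper mirroring the key handling in get_connection().
import io
from paramiko import RSAKey

def load_signed_key(pem_text: str, cert_text: str) -> RSAKey:
    # Parse the private key from an in-memory string instead of a file on disk.
    pkey = RSAKey.from_private_key(io.StringIO(pem_text))
    # Attach the SSH certificate, as get_connection() does with user.cert.
    pkey.load_certificate(cert_text)
    return pkey
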

def submit_job(payload_dict: dict, files_dict: dict, user: KeycloakUserModel, cluster: ClusterModel, sim_id: int,
               update_key: str) -> dict:
    """Submits job to cluster"""
    utc_now = int(datetime.utcnow().timestamp() * 1e6)

    if user.cert is None or user.private_key is None:
        return {"message": f"User {user.username} has no certificate or private key"}
    con = get_connection(user=user, cluster=cluster)

    fabric_result: Result = con.run("echo $SCRATCH", hide=True)
    scratch = fabric_result.stdout.split()[0]
    logging.debug("Scratch directory: %s", scratch)

    job_dir = f"{scratch}/yaptide_runs/{utc_now}"
    logging.debug("Job directory: %s", job_dir)

    con.run(f"mkdir -p {job_dir}")
    with tempfile.TemporaryDirectory() as tmp_dir_path:
        logging.debug("Preparing simulation input in: %s", tmp_dir_path)
        zip_path = Path(tmp_dir_path) / "input.zip"
        write_simulation_input_files(files_dict=files_dict, output_dir=Path(tmp_dir_path))
        logging.debug("Zipping simulation input to %s", zip_path)
        with ZipFile(zip_path, mode="w") as archive:
            for file in Path(tmp_dir_path).iterdir():
                if file.name == "input.zip":
                    continue
                archive.write(file, arcname=file.name)
        con.put(zip_path, job_dir)
        logging.debug("Transferring simulation input %s to %s", zip_path, job_dir)

    WATCHER_SCRIPT = Path(__file__).parent.resolve() / "watcher.py"
    RESULT_SENDER_SCRIPT = Path(__file__).parent.resolve() / "result_sender.py"

    logging.debug("Transferring watcher script %s to %s", WATCHER_SCRIPT, job_dir)
    con.put(WATCHER_SCRIPT, job_dir)
    logging.debug("Transferring result sender script %s to %s", RESULT_SENDER_SCRIPT, job_dir)
    con.put(RESULT_SENDER_SCRIPT, job_dir)

    submit_file, sh_files = prepare_script_files(payload_dict=payload_dict,
                                                 job_dir=job_dir,
                                                 sim_id=sim_id,
                                                 update_key=update_key,
                                                 con=con)

    array_id = collect_id = None
    fabric_result: Result = con.run(f'sh {submit_file}', hide=True)
    submit_stdout = fabric_result.stdout
    submit_stderr = fabric_result.stderr
    for line in submit_stdout.split("\n"):
        if line.startswith("Job id"):
            try:
                array_id = int(line.split()[-1])
            except (ValueError, IndexError):
                logging.error("Could not parse array id from line: %s", line)
        if line.startswith("Collect id"):
            try:
                collect_id = int(line.split()[-1])
            except (ValueError, IndexError):
                logging.error("Could not parse collect id from line: %s", line)

    if array_id is None or collect_id is None:
        logging.debug("Job submission failed")
        logging.debug("Sbatch stdout: %s", submit_stdout)
        logging.debug("Sbatch stderr: %s", submit_stderr)
        return {"message": "Job submission failed", "submit_stdout": submit_stdout, "sh_files": sh_files}
    return {
        "message": "Job submitted",
        "job_dir": job_dir,
        "array_id": array_id,
        "collect_id": collect_id,
        "submit_stdout": submit_stdout,
        "sh_files": sh_files
    }

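The id parsing above relies on yaptide_submitter.sh printing a "Job id ..." and a "Collect id ..." line. A self-contained sketch of that parsing (the sample stdout is an assumed example, not captured output):

# Sketch of the submitter-output parsing used in submit_job().
# Only the "Job id" / "Collect id" prefixes come from the code above;
# the concrete ids are made up.
sample_stdout = "Job id 123456\nCollect id 123457\n"

array_id = collect_id = None
for line in sample_stdout.split("\n"):
    if line.startswith("Job id"):
        array_id = int(line.split()[-1])
    if line.startswith("Collect id"):
        collect_id = int(line.split()[-1])

assert (array_id, collect_id) == (123456, 123457)
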

def prepare_script_files(payload_dict: dict, job_dir: str, sim_id: int, update_key: str,
                         con: Connection) -> tuple[str, dict]:
    """Prepares script files to run them on cluster"""
    submit_file = f'{job_dir}/yaptide_submitter.sh'
    array_file = f'{job_dir}/array_script.sh'
    collect_file = f'{job_dir}/collect_script.sh'

    array_options = convert_dict_to_sbatch_options(payload_dict=payload_dict, target_key="array_options")
    array_header = extract_sbatch_header(payload_dict=payload_dict, target_key="array_header")

    collect_options = convert_dict_to_sbatch_options(payload_dict=payload_dict, target_key="collect_options")
    collect_header = extract_sbatch_header(payload_dict=payload_dict, target_key="collect_header")

    backend_url = os.environ.get("BACKEND_EXTERNAL_URL", "")

    submit_script = SUBMIT_SHIELDHIT.format(array_options=array_options,
                                            collect_options=collect_options,
                                            root_dir=job_dir,
                                            n_tasks=str(payload_dict["ntasks"]),
                                            convertmc_version=pymchelper.__version__)
    array_script = ARRAY_SHIELDHIT_BASH.format(array_header=array_header,
                                               root_dir=job_dir,
                                               sim_id=sim_id,
                                               update_key=update_key,
                                               backend_url=backend_url)
    collect_script = COLLECT_BASH.format(collect_header=collect_header,
                                         root_dir=job_dir,
                                         clear_bdos="true",
                                         sim_id=sim_id,
                                         update_key=update_key,
                                         backend_url=backend_url)

    con.run(f'echo \'{array_script}\' >> {array_file}')
    con.run(f'chmod +x {array_file}')
    con.run(f'echo \'{submit_script}\' >> {submit_file}')
    con.run(f'chmod +x {submit_file}')
    con.run(f'echo \'{collect_script}\' >> {collect_file}')
    con.run(f'chmod +x {collect_file}')

    return submit_file, {"submit": submit_script, "array": array_script, "collect": collect_script}

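For orientation, the payload_dict consumed above is expected to carry the task count plus sbatch options under the target keys used in the calls. A hypothetical example (option names and values are illustrative, not yaptide defaults):

# Illustrative payload_dict shape assumed by prepare_script_files().
example_payload = {
    "ntasks": 4,
    # Passed through convert_dict_to_sbatch_options(); the option
    # names and values here are made-up examples.
    "array_options": {"time": "00:30:00", "mem-per-cpu": "4G"},
    "collect_options": {"time": "00:10:00"},
    # "array_header" / "collect_header" may also be present; their exact
    # shape is handled by extract_sbatch_header() and is not assumed here.
}
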

def get_job_status(simulation: BatchSimulationModel, user: KeycloakUserModel, cluster: ClusterModel) -> dict:
    """Get SLURM job status"""
    array_id = simulation.array_id
    collect_id = simulation.collect_id

    con = get_connection(user=user, cluster=cluster)

    fabric_result: Result = con.run(f'sacct -j {array_id} --format State', hide=True)
    job_state = fabric_result.stdout.split()[-1].split()[0]

    fabric_result: Result = con.run(f'sacct -j {collect_id} --format State', hide=True)
    collect_state = fabric_result.stdout.split()[-1].split()[0]

    if job_state == "FAILED" or collect_state == "FAILED":
        return {"job_state": EntityState.FAILED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}
    if collect_state == "COMPLETED":
        return {"job_state": EntityState.COMPLETED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}
    if collect_state == "RUNNING":
        logging.debug("Collect job is in RUNNING state")
    if job_state == "RUNNING":
        logging.debug("Main job is in RUNNING state")
    if collect_state == "PENDING":
        logging.debug("Collect job is in PENDING state")
    if job_state == "PENDING":
        logging.debug("Main job is in PENDING state")

    return {"job_state": EntityState.RUNNING.value}

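The state extraction above keeps the last whitespace-separated token of the sacct output. A self-contained sketch with sample output (the layout is an assumption about sacct -j <id> --format State; real output depends on the SLURM setup):

# Sketch of the sacct State parsing used in get_job_status().
# Assumes sacct prints a header, a separator line and one state per job step.
sample_sacct_stdout = """     State
----------
 COMPLETED
 COMPLETED
"""

collect_state = sample_sacct_stdout.split()[-1].split()[0]
assert collect_state == "COMPLETED"
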

def get_job_results(simulation: BatchSimulationModel, user: KeycloakUserModel, cluster: ClusterModel) -> dict:
    """Returns simulation results"""
    job_dir = simulation.job_dir
    collect_id = simulation.collect_id

    con = get_connection(user=user, cluster=cluster)

    fabric_result: Result = con.run(f'sacct -j {collect_id} --format State', hide=True)
    collect_state = fabric_result.stdout.split()[-1].split()[0]

    if collect_state == "COMPLETED":
        fabric_result: Result = con.run(f'ls -f {job_dir}/output | grep .json', hide=True)
        result_estimators = []
        with tempfile.TemporaryDirectory() as tmp_dir_path:
            for filename in fabric_result.stdout.split():
                file_path = Path(tmp_dir_path, filename)
                with open(file_path, "wb") as writer:
                    con.get(f'{job_dir}/output/{filename}', writer)
                with open(file_path, "r") as json_file:
                    est_dict = json.load(json_file)
                    est_dict["name"] = filename.split('.')[0]
                    result_estimators.append(est_dict)

        return {"estimators": result_estimators}
    return {"message": "Results not available"}

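Each downloaded JSON file becomes one estimator entry named after its filename stem. A small sketch (the filename and file contents are hypothetical):

# Sketch of the filename-to-estimator-name mapping used in get_job_results().
filename = "dose_z.json"
est_dict = {"metadata": {}}  # stand-in for the json.load(...) result
est_dict["name"] = filename.split('.')[0]
assert est_dict["name"] == "dose_z"
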

def delete_job(simulation: BatchSimulationModel, user: KeycloakUserModel,
               cluster: ClusterModel) -> tuple[dict, int]:  # skipcq: PYL-W0613
    """Dummy version of delete_job"""
    job_dir = simulation.job_dir
    array_id = simulation.array_id
    collect_id = simulation.collect_id

    try:
        con = get_connection(user=user, cluster=cluster)

        con.run(f'scancel {array_id}')
        con.run(f'scancel {collect_id}')
        con.run(f'rm -rf {job_dir}')
    except Exception as e:  # skipcq: PYL-W0703
        logging.error(e)
        return {"message": "Job cancellation failed"}, 500

    return {"message": "Job canceled"}, 200