Coverage for yaptide/batch/batch_methods.py: 15%

213 statements  

coverage.py v7.6.10, created at 2025-03-31 19:18 +0000

import io
import json
import logging
import os
import tempfile
import requests
from datetime import datetime
from pathlib import Path
from zipfile import ZipFile
import sqlalchemy as db

import pymchelper
from fabric import Connection, Result
from paramiko import RSAKey

from yaptide.batch.shieldhit_string_templates import (ARRAY_SHIELDHIT_BASH, COLLECT_SHIELDHIT_BASH, SUBMIT_SHIELDHIT)
from yaptide.batch.fluka_string_templates import (ARRAY_FLUKA_BASH, COLLECT_FLUKA_BASH, SUBMIT_FLUKA)
from yaptide.batch.utils.utils import (convert_dict_to_sbatch_options, extract_sbatch_header)
from yaptide.persistence.models import (BatchSimulationModel, ClusterModel, KeycloakUserModel, UserModel)
from yaptide.utils.enums import EntityState, SimulationType
from yaptide.utils.sim_utils import write_simulation_input_files

from yaptide.admin.db_manage import TableTypes, connect_to_db
from yaptide.utils.helper_worker import celery_app

def get_user(db_con, metadata, userId):
    """Queries database for user"""
    users = metadata.tables[TableTypes.User.name]
    keycloakUsers = metadata.tables[TableTypes.KeycloakUser.name]
    stmt = db.select(users,
                     keycloakUsers).select_from(users).join(keycloakUsers,
                                                            KeycloakUserModel.id == UserModel.id).filter_by(id=userId)
    try:
        user: KeycloakUserModel = db_con.execute(stmt).first()
    except Exception:
        logging.error('Error getting user object with id: %s from database', str(userId))
        return None
    return user

def get_cluster(db_con, metadata, clusterId):
    """Queries database for cluster"""
    clusters = metadata.tables[TableTypes.Cluster.name]
    stmt = db.select(clusters).filter_by(id=clusterId)
    try:
        cluster: ClusterModel = db_con.execute(stmt).first()
    except Exception:
        logging.error('Error getting cluster object with id: %s from database', str(clusterId))
        return None
    return cluster

def get_connection(user: KeycloakUserModel, cluster: ClusterModel) -> Connection:
    """Returns connection object to cluster"""
    pkey = RSAKey.from_private_key(io.StringIO(user.private_key))
    pkey.load_certificate(user.cert)
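    # The private key is rebuilt from the stored PEM string and the user's SSH certificate is
    # attached to it; agent lookup and default key discovery are disabled when connecting, so
    # only this credential is offered to the cluster.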

    con = Connection(host=f"{user.username}@{cluster.cluster_name}",
                     connect_kwargs={
                         "pkey": pkey,
                         "allow_agent": False,
                         "look_for_keys": False
                     })
    return con

def post_update(dict_to_send):
    """Sends a job status update to the Flask backend"""
    flask_url = os.environ.get("BACKEND_INTERNAL_URL")
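    # Posts to the backend's internal /jobs endpoint; payloads built elsewhere in this module
    # carry fields such as sim_id, job_state, update_key and an optional log dict.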

    return requests.Session().post(url=f"{flask_url}/jobs", json=dict_to_send)

@celery_app.task()
def submit_job(  # skipcq: PY-R1000
        payload_dict: dict, files_dict: dict, userId: int, clusterId: int, sim_id: int, update_key: str):
    """Submits job to cluster"""
    utc_now = int(datetime.utcnow().timestamp() * 1e6)
    try:
        # The database connection and object queries are done this way because this celery task
        # runs outside the Flask application context.
        db_con, metadata, _ = connect_to_db()
    except Exception as e:
        logging.error('Async worker couldn\'t connect to db. Error message:"%s"', str(e))
        # Without a database connection the user and cluster cannot be looked up, so stop here.
        return
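    # Look up the submitting user (including SSH credentials) and the target cluster in the database.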

    user = get_user(db_con=db_con, metadata=metadata, userId=userId)
    cluster = get_cluster(db_con=db_con, metadata=metadata, clusterId=clusterId)

    # get_user() and get_cluster() return None when the lookup fails; without both objects there
    # is nothing to submit, so abort early instead of failing later with an AttributeError.
    if user is None or cluster is None:
        logging.error("Could not fetch user or cluster object from database")
        return

    if user.cert is None or user.private_key is None:
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "update_key": update_key,
            "log": {
                "error": f"User {user.username} has no certificate or private key"
            }
        }
        post_update(dict_to_send)
        return

    try:
        con = get_connection(user=user, cluster=cluster)
        fabric_result: Result = con.run("echo $SCRATCH", hide=True)
    except Exception as e:
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "update_key": update_key,
            "log": {
                "error": str(e)
            }
        }
        post_update(dict_to_send)
        return

    scratch = fabric_result.stdout.split()[0]
    logging.debug("Scratch directory: %s", scratch)

    job_dir = f"{scratch}/yaptide_runs/{utc_now}"
    logging.debug("Job directory: %s", job_dir)
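    # Each submission gets its own working directory under $SCRATCH/yaptide_runs/, named with
    # the submission timestamp in microseconds.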

    try:
        con.run(f"mkdir -p {job_dir}")
    except Exception as e:
        logging.error("Creating job directory %s failed: %s", job_dir, str(e))
        dict_to_send = {"sim_id": sim_id, "job_state": EntityState.FAILED.value, "update_key": update_key}
        post_update(dict_to_send)
        return

    with tempfile.TemporaryDirectory() as tmp_dir_path:
        logging.debug("Preparing simulation input in: %s", tmp_dir_path)
        zip_path = Path(tmp_dir_path) / "input.zip"
        write_simulation_input_files(files_dict=files_dict, output_dir=Path(tmp_dir_path))
        logging.debug("Zipping simulation input to %s", zip_path)
        with ZipFile(zip_path, mode="w") as archive:
            for file in Path(tmp_dir_path).iterdir():
                if file.name == "input.zip":
                    continue
                archive.write(file, arcname=file.name)
        con.put(zip_path, job_dir)
        logging.debug("Transferring simulation input %s to %s", zip_path, job_dir)

    WATCHER_SCRIPT = Path(__file__).parent.resolve() / "watcher.py"
    SIMULATION_DATA_SENDER_SCRIPT = Path(__file__).parent.resolve() / "simulation_data_sender.py"

    logging.debug("Transferring watcher script %s to %s", WATCHER_SCRIPT, job_dir)
    con.put(WATCHER_SCRIPT, job_dir)
    logging.debug("Transferring result sender script %s to %s", SIMULATION_DATA_SENDER_SCRIPT, job_dir)
    con.put(SIMULATION_DATA_SENDER_SCRIPT, job_dir)
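    # watcher.py and simulation_data_sender.py are uploaded next to the input; they are assumed
    # to be invoked by the generated batch scripts to report progress and send results back to
    # the backend.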

    submit_file, sh_files = prepare_script_files(payload_dict=payload_dict,
                                                 job_dir=job_dir,
                                                 sim_id=sim_id,
                                                 update_key=update_key,
                                                 con=con)

    array_id = collect_id = None
    if not submit_file.startswith(job_dir):
        logging.error("Invalid submit file path: %s", submit_file)
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "update_key": update_key,
            "log": {
                "error": "Job submission failed due to invalid submit file path"
            }
        }
        post_update(dict_to_send)
        return

    fabric_result: Result = con.run(f'sh {submit_file}', hide=True)
    submit_stdout = fabric_result.stdout
    submit_stderr = fabric_result.stderr
    for line in submit_stdout.split("\n"):
        if line.startswith("Job id"):
            try:
                array_id = int(line.split()[-1])
            except (ValueError, IndexError):
                logging.error("Could not parse array id from line: %s", line)
        if line.startswith("Collect id"):
            try:
                collect_id = int(line.split()[-1])
            except (ValueError, IndexError):
                logging.error("Could not parse collect id from line: %s", line)
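    # Expected submitter stdout parsed above (format assumed from the submit script templates), e.g.:
    #   Job id 12345678
    #   Collect id 12345679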

    if array_id is None or collect_id is None:
        logging.debug("Job submission failed")
        logging.debug("Sbatch stdout: %s", submit_stdout)
        logging.debug("Sbatch stderr: %s", submit_stderr)
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "update_key": update_key,
            "log": {
                "message": "Job submission failed",
                "submit_stdout": submit_stdout,
                "sh_files": sh_files,
                "submit_stderr": submit_stderr
            }
        }
        post_update(dict_to_send)
        return

    dict_to_send = {
        "sim_id": sim_id,
        "update_key": update_key,
        "job_dir": job_dir,
        "array_id": array_id,
        "collect_id": collect_id,
        "submit_stdout": submit_stdout,
        "sh_files": sh_files
    }
    post_update(dict_to_send)
    return

def prepare_script_files(payload_dict: dict, job_dir: str, sim_id: int, update_key: str,
                         con: Connection) -> tuple[str, dict]:
    """Prepares the script files used to run the job on the cluster"""
    submit_file = f'{job_dir}/yaptide_submitter.sh'
    array_file = f'{job_dir}/array_script.sh'
    collect_file = f'{job_dir}/collect_script.sh'

    array_options = convert_dict_to_sbatch_options(payload_dict=payload_dict, target_key="array_options")
    array_header = extract_sbatch_header(payload_dict=payload_dict, target_key="array_header")

    collect_options = convert_dict_to_sbatch_options(payload_dict=payload_dict, target_key="collect_options")
    collect_header = extract_sbatch_header(payload_dict=payload_dict, target_key="collect_header")
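    # Illustrative payload_dict fragment consumed by this function (exact values come from the API request):
    # {"sim_type": SimulationType.SHIELDHIT.value, "ntasks": 4,
    #  "array_options": {...}, "array_header": "...",
    #  "collect_options": {...}, "collect_header": "..."}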

    backend_url = os.environ.get("BACKEND_EXTERNAL_URL", "")

    if payload_dict['sim_type'] == SimulationType.FLUKA.value:
        submit_template = SUBMIT_FLUKA
        array_template = ARRAY_FLUKA_BASH
        collect_template = COLLECT_FLUKA_BASH
    elif payload_dict['sim_type'] == SimulationType.SHIELDHIT.value:
        submit_template = SUBMIT_SHIELDHIT
        array_template = ARRAY_SHIELDHIT_BASH
        collect_template = COLLECT_SHIELDHIT_BASH
    else:
        # Ready for future simulators
        submit_template = ""
        array_template = ""
        collect_template = ""

    submit_script = submit_template.format(array_options=array_options,
                                           collect_options=collect_options,
                                           root_dir=job_dir,
                                           n_tasks=str(payload_dict["ntasks"]),
                                           convertmc_version=pymchelper.__version__)
    array_script = array_template.format(array_header=array_header,
                                         root_dir=job_dir,
                                         sim_id=sim_id,
                                         update_key=update_key,
                                         backend_url=backend_url)
    collect_script = collect_template.format(collect_header=collect_header,
                                             root_dir=job_dir,
                                             remove_output_from_workspace="true",
                                             sim_id=sim_id,
                                             update_key=update_key,
                                             backend_url=backend_url)

    con.run(f'echo \'{array_script}\' >> {array_file}')
    con.run(f'chmod +x {array_file}')
    con.run(f'echo \'{submit_script}\' >> {submit_file}')
    con.run(f'chmod +x {submit_file}')
    con.run(f'echo \'{collect_script}\' >> {collect_file}')
    con.run(f'chmod +x {collect_file}')
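    # The three scripts are created on the cluster by echo-appending their contents into the
    # freshly made job directory and are then marked executable.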

    return submit_file, {"submit": submit_script, "array": array_script, "collect": collect_script}

def get_job_status(simulation: BatchSimulationModel, user: KeycloakUserModel, cluster: ClusterModel) -> dict:
    """Get SLURM job status"""
    array_id = simulation.array_id
    collect_id = simulation.collect_id

    con = get_connection(user=user, cluster=cluster)

    fabric_result: Result = con.run(f'sacct -j {array_id} --format State', hide=True)
    job_state = fabric_result.stdout.split()[-1].split()[0]

    fabric_result: Result = con.run(f'sacct -j {collect_id} --format State', hide=True)
    collect_state = fabric_result.stdout.split()[-1].split()[0]
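    # `sacct -j <id> --format State` typically prints a header line, a dashed separator and one
    # state per job step; the last whitespace-separated token is used as the overall state, e.g.:
    #      State
    #   ----------
    #    COMPLETED
    #    COMPLETED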

    if job_state == "FAILED" or collect_state == "FAILED":
        return {"job_state": EntityState.FAILED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}
    if collect_state == "COMPLETED":
        return {"job_state": EntityState.COMPLETED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}
    if collect_state == "RUNNING":
        logging.debug("Collect job is in RUNNING state")
        return {"job_state": EntityState.MERGING_RUNNING.value}
    if job_state == "COMPLETED" and collect_state == "PENDING":
        logging.debug("Collect job is in PENDING state")
        return {"job_state": EntityState.MERGING_QUEUED.value}
    if job_state == "RUNNING":
        logging.debug("Main job is in RUNNING state")
    if job_state == "PENDING":
        logging.debug("Main job is in PENDING state")

    return {"job_state": EntityState.RUNNING.value}

def get_job_results(simulation: BatchSimulationModel, user: KeycloakUserModel, cluster: ClusterModel) -> dict:
    """Returns simulation results"""
    job_dir = simulation.job_dir
    collect_id = simulation.collect_id

    con = get_connection(user=user, cluster=cluster)

    fabric_result: Result = con.run(f'sacct -j {collect_id} --format State', hide=True)
    collect_state = fabric_result.stdout.split()[-1].split()[0]
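    # Results are only fetched once the collect (merge) job has completed; each JSON file in the
    # job's output directory holds one estimator, and its name is taken from the file name.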

    if collect_state == "COMPLETED":
        fabric_result: Result = con.run(f'ls -f {job_dir}/output | grep .json', hide=True)
        result_estimators = []
        with tempfile.TemporaryDirectory() as tmp_dir_path:
            for filename in fabric_result.stdout.split():
                file_path = Path(tmp_dir_path, filename)
                with open(file_path, "wb") as writer:
                    con.get(f'{job_dir}/output/{filename}', writer)
                with open(file_path, "r") as json_file:
                    est_dict = json.load(json_file)
                    est_dict["name"] = filename.split('.')[0]
                    result_estimators.append(est_dict)

        return {"estimators": result_estimators}
    return {"message": "Results not available"}

def delete_job(simulation: BatchSimulationModel, user: KeycloakUserModel,
               cluster: ClusterModel) -> tuple[dict, int]:  # skipcq: PYL-W0613
    """Cancels the SLURM jobs and removes the job directory on the cluster"""
    job_dir = simulation.job_dir
    array_id = simulation.array_id
    collect_id = simulation.collect_id

    try:
        con = get_connection(user=user, cluster=cluster)

        con.run(f'scancel {array_id}')
        con.run(f'scancel {collect_id}')
        con.run(f'rm -rf {job_dir}')
    except Exception as e:  # skipcq: PYL-W0703
        logging.error(e)
        return {"message": "Job cancellation failed"}, 500

    return {"message": "Job canceled"}, 200