Coverage for yaptide/batch/batch_methods.py: 16%

201 statements  

coverage.py v7.6.10, created at 2025-01-04 00:31 +0000

import io
import json
import logging
import os
import tempfile
import requests
from datetime import datetime
from pathlib import Path
from zipfile import ZipFile
import sqlalchemy as db

import pymchelper
from fabric import Connection, Result
from paramiko import RSAKey

from yaptide.batch.string_templates import (ARRAY_SHIELDHIT_BASH, COLLECT_BASH, SUBMIT_SHIELDHIT)
from yaptide.batch.utils.utils import (convert_dict_to_sbatch_options, extract_sbatch_header)
from yaptide.persistence.models import (BatchSimulationModel, ClusterModel, KeycloakUserModel, UserModel)
from yaptide.utils.enums import EntityState
from yaptide.utils.sim_utils import write_simulation_input_files

from yaptide.admin.db_manage import TableTypes, connect_to_db
from yaptide.utils.helper_worker import celery_app


def get_user(db_con, metadata, userId):
    """Queries database for user"""
    users = metadata.tables[TableTypes.User.name]
    keycloackUsers = metadata.tables[TableTypes.KeycloakUser.name]
    stmt = db.select(users,
                     keycloackUsers).select_from(users).join(keycloackUsers,
                                                             KeycloakUserModel.id == UserModel.id).filter_by(id=userId)
    try:
        user: KeycloakUserModel = db_con.execute(stmt).first()
    except Exception:
        logging.error('Error getting user object with id: %s from database', str(userId))
        return None
    return user


def get_cluster(db_con, metadata, clusterId):
    """Queries database for cluster"""
    clusters = metadata.tables[TableTypes.Cluster.name]
    stmt = db.select(clusters).filter_by(id=clusterId)
    try:
        cluster: ClusterModel = db_con.execute(stmt).first()
    except Exception:
        logging.error('Error getting cluster object with id: %s from database', str(clusterId))
        return None
    return cluster


def get_connection(user: KeycloakUserModel, cluster: ClusterModel) -> Connection:
    """Returns connection object to cluster"""
    pkey = RSAKey.from_private_key(io.StringIO(user.private_key))
    pkey.load_certificate(user.cert)

    con = Connection(host=f"{user.username}@{cluster.cluster_name}",
                     connect_kwargs={
                         "pkey": pkey,
                         "allow_agent": False,
                         "look_for_keys": False
                     })
    return con
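# Note on get_connection(): the RSA key is paired with an SSH certificate
# (user.cert) via paramiko's RSAKey.load_certificate, so the cluster is expected
# to accept certificate-based authentication; allow_agent/look_for_keys are
# disabled so only this in-memory key is ever offered. Illustrative use,
# mirroring the calls later in this module:
#
#     con = get_connection(user=user, cluster=cluster)
#     result = con.run("hostname", hide=True)  # fabric Result with .stdout/.stderr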


def post_update(dict_to_send):
    """Sends a status update to the Flask backend"""
    flask_url = os.environ.get("BACKEND_INTERNAL_URL")
    return requests.Session().post(url=f"{flask_url}/jobs", json=dict_to_send)
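# post_update() is how the worker reports job state back to the backend; the
# payloads assembled in submit_job() below look roughly like this (illustrative):
#
#     {"sim_id": 123, "job_state": EntityState.FAILED.value,
#      "update_key": "...", "log": {"error": "..."}}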


@celery_app.task()
def submit_job(  # skipcq: PY-R1000
        payload_dict: dict, files_dict: dict, userId: int, clusterId: int, sim_id: int, update_key: str):
    """Submits job to cluster"""
    utc_now = int(datetime.utcnow().timestamp() * 1e6)
    try:
        # Connection to database and querying of objects looks like this because
        # the celery task works outside the flask context
        db_con, metadata, _ = connect_to_db()
    except Exception as e:
        logging.error('Async worker couldn\'t connect to db. Error message:"%s"', str(e))

    user = get_user(db_con=db_con, metadata=metadata, userId=userId)
    cluster = get_cluster(db_con=db_con, metadata=metadata, clusterId=clusterId)

    if user.cert is None or user.private_key is None:
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "update_key": update_key,
            "log": {
                "error": f"User {user.username} has no certificate or private key"
            }
        }
        post_update(dict_to_send)
        return

    try:
        con = get_connection(user=user, cluster=cluster)
        fabric_result: Result = con.run("echo $SCRATCH", hide=True)
    except Exception as e:
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "update_key": update_key,
            "log": {
                "error": str(e)
            }
        }
        post_update(dict_to_send)
        return

    scratch = fabric_result.stdout.split()[0]
    logging.debug("Scratch directory: %s", scratch)

    job_dir = f"{scratch}/yaptide_runs/{utc_now}"
    logging.debug("Job directory: %s", job_dir)
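    # $SCRATCH is assumed to point at the user's scratch filesystem on the
    # cluster (a common SLURM-site convention); every submission gets its own
    # directory there, keyed by the microsecond timestamp, e.g. (illustrative)
    # /some/scratch/path/yaptide_runs/1704326400000000.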

    try:
        con.run(f"mkdir -p {job_dir}")
    except Exception:
        dict_to_send = {"sim_id": sim_id, "job_state": EntityState.FAILED.value, "update_key": update_key}
        post_update(dict_to_send)
        return
    with tempfile.TemporaryDirectory() as tmp_dir_path:
        logging.debug("Preparing simulation input in: %s", tmp_dir_path)
        zip_path = Path(tmp_dir_path) / "input.zip"
        write_simulation_input_files(files_dict=files_dict, output_dir=Path(tmp_dir_path))
        logging.debug("Zipping simulation input to %s", zip_path)
        with ZipFile(zip_path, mode="w") as archive:
            for file in Path(tmp_dir_path).iterdir():
                if file.name == "input.zip":
                    continue
                archive.write(file, arcname=file.name)
        con.put(zip_path, job_dir)
        logging.debug("Transferring simulation input %s to %s", zip_path, job_dir)
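        # Fabric's Connection.put() uploads the archive over SFTP into job_dir;
        # the zip file itself is skipped while zipping so it does not nest into itself.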

    WATCHER_SCRIPT = Path(__file__).parent.resolve() / "watcher.py"
    SIMULATION_DATA_SENDER_SCRIPT = Path(__file__).parent.resolve() / "simulation_data_sender.py"

    logging.debug("Transferring watcher script %s to %s", WATCHER_SCRIPT, job_dir)
    con.put(WATCHER_SCRIPT, job_dir)
    logging.debug("Transferring result sender script %s to %s", SIMULATION_DATA_SENDER_SCRIPT, job_dir)
    con.put(SIMULATION_DATA_SENDER_SCRIPT, job_dir)
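    # watcher.py and simulation_data_sender.py are helper scripts shipped next to
    # this module; presumably they are invoked by the generated batch scripts on
    # the cluster to report progress and send results back to the backend.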

    submit_file, sh_files = prepare_script_files(payload_dict=payload_dict,
                                                 job_dir=job_dir,
                                                 sim_id=sim_id,
                                                 update_key=update_key,
                                                 con=con)

    array_id = collect_id = None
    if not submit_file.startswith(job_dir):
        logging.error("Invalid submit file path: %s", submit_file)
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "update_key": update_key,
            "log": {
                "error": "Job submission failed due to invalid submit file path"
            }
        }
        post_update(dict_to_send)
        return
    fabric_result: Result = con.run(f'sh {submit_file}', hide=True)
    submit_stdout = fabric_result.stdout
    submit_stderr = fabric_result.stderr
    for line in submit_stdout.split("\n"):
        if line.startswith("Job id"):
            try:
                array_id = int(line.split()[-1])
            except (ValueError, IndexError):
                logging.error("Could not parse array id from line: %s", line)
        if line.startswith("Collect id"):
            try:
                collect_id = int(line.split()[-1])
            except (ValueError, IndexError):
                logging.error("Could not parse collect id from line: %s", line)
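    # The submit script (rendered from SUBMIT_SHIELDHIT) is expected to print
    # lines such as "Job id 12345" and "Collect id 12346"; the last token of
    # each is taken as the sbatch job ID of the array job and of the collect job.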

    if array_id is None or collect_id is None:
        logging.debug("Job submission failed")
        logging.debug("Sbatch stdout: %s", submit_stdout)
        logging.debug("Sbatch stderr: %s", submit_stderr)
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "update_key": update_key,
            "log": {
                "message": "Job submission failed",
                "submit_stdout": submit_stdout,
                "sh_files": sh_files,
                "submit_stderr": submit_stderr
            }
        }
        post_update(dict_to_send)
        return

    dict_to_send = {
        "sim_id": sim_id,
        "update_key": update_key,
        "job_dir": job_dir,
        "array_id": array_id,
        "collect_id": collect_id,
        "submit_stdout": submit_stdout,
        "sh_files": sh_files
    }
    post_update(dict_to_send)
    return
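# On successful submission the backend receives job_dir, array_id and
# collect_id; get_job_status(), get_job_results() and delete_job() below read
# them back from the BatchSimulationModel.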


def prepare_script_files(payload_dict: dict, job_dir: str, sim_id: int, update_key: str,
                         con: Connection) -> tuple[str, dict]:
    """Prepares script files to run them on cluster"""
    submit_file = f'{job_dir}/yaptide_submitter.sh'
    array_file = f'{job_dir}/array_script.sh'
    collect_file = f'{job_dir}/collect_script.sh'

    array_options = convert_dict_to_sbatch_options(payload_dict=payload_dict, target_key="array_options")
    array_header = extract_sbatch_header(payload_dict=payload_dict, target_key="array_header")

    collect_options = convert_dict_to_sbatch_options(payload_dict=payload_dict, target_key="collect_options")
    collect_header = extract_sbatch_header(payload_dict=payload_dict, target_key="collect_header")

    backend_url = os.environ.get("BACKEND_EXTERNAL_URL", "")

    submit_script = SUBMIT_SHIELDHIT.format(array_options=array_options,
                                            collect_options=collect_options,
                                            root_dir=job_dir,
                                            n_tasks=str(payload_dict["ntasks"]),
                                            convertmc_version=pymchelper.__version__)
    array_script = ARRAY_SHIELDHIT_BASH.format(array_header=array_header,
                                               root_dir=job_dir,
                                               sim_id=sim_id,
                                               update_key=update_key,
                                               backend_url=backend_url)
    collect_script = COLLECT_BASH.format(collect_header=collect_header,
                                         root_dir=job_dir,
                                         clear_bdos="true",
                                         sim_id=sim_id,
                                         update_key=update_key,
                                         backend_url=backend_url)

    con.run(f'echo \'{array_script}\' >> {array_file}')
    con.run(f'chmod +x {array_file}')
    con.run(f'echo \'{submit_script}\' >> {submit_file}')
    con.run(f'chmod +x {submit_file}')
    con.run(f'echo \'{collect_script}\' >> {collect_file}')
    con.run(f'chmod +x {collect_file}')

    return submit_file, {"submit": submit_script, "array": array_script, "collect": collect_script}
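# prepare_script_files() materialises the three bash scripts on the cluster by
# echoing the rendered templates into files and marking them executable. Note
# that the script bodies are wrapped in single quotes for the remote shell, so
# the templates are assumed to contain no unescaped single quotes.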


def get_job_status(simulation: BatchSimulationModel, user: KeycloakUserModel, cluster: ClusterModel) -> dict:
    """Get SLURM job status"""
    array_id = simulation.array_id
    collect_id = simulation.collect_id

    con = get_connection(user=user, cluster=cluster)

    fabric_result: Result = con.run(f'sacct -j {array_id} --format State', hide=True)
    job_state = fabric_result.stdout.split()[-1].split()[0]

    fabric_result: Result = con.run(f'sacct -j {collect_id} --format State', hide=True)
    collect_state = fabric_result.stdout.split()[-1].split()[0]
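    # `sacct -j <id> --format State` prints a "State" header, a dashed separator
    # and one state per job step, e.g.:
    #
    #     State
    #     ----------
    #     COMPLETED
    #
    # Taking the last whitespace token therefore yields the state of the last
    # step listed.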

    if job_state == "FAILED" or collect_state == "FAILED":
        return {"job_state": EntityState.FAILED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}
    if collect_state == "COMPLETED":
        return {"job_state": EntityState.COMPLETED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}
    if collect_state == "RUNNING":
        logging.debug("Collect job is in RUNNING state")
        return {"job_state": EntityState.MERGING_RUNNING.value}
    if job_state == "COMPLETED" and collect_state == "PENDING":
        logging.debug("Collect job is in PENDING state")
        return {"job_state": EntityState.MERGING_QUEUED.value}
    if job_state == "RUNNING":
        logging.debug("Main job is in RUNNING state")
    if job_state == "PENDING":
        logging.debug("Main job is in PENDING state")

    return {"job_state": EntityState.RUNNING.value}


def get_job_results(simulation: BatchSimulationModel, user: KeycloakUserModel, cluster: ClusterModel) -> dict:
    """Returns simulation results"""
    job_dir = simulation.job_dir
    collect_id = simulation.collect_id

    con = get_connection(user=user, cluster=cluster)

    fabric_result: Result = con.run(f'sacct -j {collect_id} --format State', hide=True)
    collect_state = fabric_result.stdout.split()[-1].split()[0]

    if collect_state == "COMPLETED":
        fabric_result: Result = con.run(f'ls -f {job_dir}/output | grep .json', hide=True)
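        # Lists the JSON estimator files produced in <job_dir>/output; note that
        # grep is given ".json" as its pattern, where the dot is a regex wildcard,
        # so the match is effectively "any character followed by 'json'".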

        result_estimators = []
        with tempfile.TemporaryDirectory() as tmp_dir_path:
            for filename in fabric_result.stdout.split():
                file_path = Path(tmp_dir_path, filename)
                with open(file_path, "wb") as writer:
                    con.get(f'{job_dir}/output/{filename}', writer)
                with open(file_path, "r") as json_file:
                    est_dict = json.load(json_file)
                    est_dict["name"] = filename.split('.')[0]
                    result_estimators.append(est_dict)

        return {"estimators": result_estimators}
    return {"message": "Results not available"}


def delete_job(simulation: BatchSimulationModel, user: KeycloakUserModel,
               cluster: ClusterModel) -> tuple[dict, int]:  # skipcq: PYL-W0613
    """Dummy version of delete_job"""
    job_dir = simulation.job_dir
    array_id = simulation.array_id
    collect_id = simulation.collect_id

    try:
        con = get_connection(user=user, cluster=cluster)

        con.run(f'scancel {array_id}')
        con.run(f'scancel {collect_id}')
        con.run(f'rm -rf {job_dir}')
    except Exception as e:  # skipcq: PYL-W0703
        logging.error(e)
        return {"message": "Job cancellation failed"}, 500

    return {"message": "Job canceled"}, 200
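# delete_job() cancels both SLURM jobs with scancel and removes the job
# directory; it returns a (payload, status-code) tuple, e.g.
# ({"message": "Job canceled"}, 200), instead of raising on failure.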