Coverage for yaptide/batch/batch_methods.py: 15%

199 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-22 07:31 +0000

1import io 

2import json 

3import logging 

4import os 

5import tempfile 

6import requests 

7from datetime import datetime 

8from pathlib import Path 

9from zipfile import ZipFile 

10import sqlalchemy as db 

11 

12import pymchelper 

13from fabric import Connection, Result 

14from paramiko import RSAKey 

15 

16from yaptide.batch.string_templates import (ARRAY_SHIELDHIT_BASH, COLLECT_BASH, SUBMIT_SHIELDHIT) 

17from yaptide.batch.utils.utils import (convert_dict_to_sbatch_options, extract_sbatch_header) 

18from yaptide.persistence.models import (BatchSimulationModel, ClusterModel, KeycloakUserModel, UserModel) 

19from yaptide.utils.enums import EntityState 

20from yaptide.utils.sim_utils import write_simulation_input_files 

21 

22from yaptide.admin.db_manage import TableTypes, connect_to_db 

23from yaptide.utils.helper_worker import celery_app 

24 

25 

def get_user(db_con, metadata, userId):
    """Query the database for the user with the given id.

    Joins the user table with the Keycloak user table so the returned row
    carries both records. Returns the first matching row, or None when the
    query fails (the error is logged, not raised).
    """
    users = metadata.tables[TableTypes.User.name]
    keycloak_users = metadata.tables[TableTypes.KeycloakUser.name]
    stmt = db.select(users,
                     keycloak_users).select_from(users).join(keycloak_users,
                                                             KeycloakUserModel.id == UserModel.id).filter_by(id=userId)
    try:
        user: KeycloakUserModel = db_con.execute(stmt).first()
    except Exception:
        # BUG FIX: log message previously read "wiht" instead of "with"
        logging.error('Error getting user object with id: %s from database', str(userId))
        return None
    return user

39 

40 

def get_cluster(db_con, metadata, clusterId):
    """Query the database for the cluster with the given id.

    Returns the first matching row, or None when the query fails
    (the error is logged, not raised).
    """
    # DOC FIX: the original docstring claimed this queried for a *user*.
    clusters = metadata.tables[TableTypes.Cluster.name]
    stmt = db.select(clusters).filter_by(id=clusterId)
    try:
        cluster: ClusterModel = db_con.execute(stmt).first()
    except Exception:
        logging.error('Error getting cluster object with id: %s from database', str(clusterId))
        return None
    return cluster

51 

52 

def get_connection(user: KeycloakUserModel, cluster: ClusterModel) -> Connection:
    """Returns connection object to cluster"""
    # Build the key from the user's PEM-encoded private key and attach
    # the associated certificate before opening the SSH connection.
    private_key = RSAKey.from_private_key(io.StringIO(user.private_key))
    private_key.load_certificate(user.cert)

    connect_kwargs = {
        "pkey": private_key,
        "allow_agent": False,
        "look_for_keys": False,
    }
    return Connection(host=f"{user.username}@{cluster.cluster_name}", connect_kwargs=connect_kwargs)

65 

66 

def post_update(dict_to_send):
    """POST job-state information back to the flask backend.

    Returns the requests.Response from the backend's /jobs endpoint.
    """
    flask_url = os.environ.get("BACKEND_INTERNAL_URL")
    # BUG FIX: the session was previously created and never closed,
    # leaking its connection pool; manage it with a context manager.
    with requests.Session() as session:
        return session.post(url=f"{flask_url}/jobs", json=dict_to_send)

71 

72 

@celery_app.task()
def submit_job(  # skipcq: PY-R1000
        payload_dict: dict, files_dict: dict, userId: int, clusterId: int, sim_id: int, update_key: str):
    """Submit a simulation job to the cluster via SSH and SLURM.

    Fetches the user and cluster records, uploads the zipped simulation
    input plus helper scripts to a fresh scratch directory, runs the
    generated submit script, parses the SLURM array/collect job ids from
    its output and reports progress/failure back to flask via post_update.
    """
    utc_now = int(datetime.utcnow().timestamp() * 1e6)
    try:
        db_con, metadata, _ = connect_to_db(
        )  # Connection to database and quering objects looks like that because celery task works outside flask context
    except Exception as e:
        logging.error('Async worker couldn\'t connect to db. Error message:"%s"', str(e))
        # BUG FIX: previously execution fell through after the failed connect,
        # so the lines below raised NameError on the undefined db_con/metadata.
        return

    user = get_user(db_con=db_con, metadata=metadata, userId=userId)
    cluster = get_cluster(db_con=db_con, metadata=metadata, clusterId=clusterId)

    # BUG FIX: get_user/get_cluster return None on query failure; without
    # this guard the attribute accesses below raised AttributeError.
    if user is None or cluster is None:
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "log": {
                "error": "Could not fetch user or cluster from database"
            }
        }
        post_update(dict_to_send)
        return

    if user.cert is None or user.private_key is None:
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "log": {
                "error": f"User {user.username} has no certificate or private key"
            }
        }
        post_update(dict_to_send)
        return

    try:
        con = get_connection(user=user, cluster=cluster)
        # Resolve the cluster-side scratch directory for this user.
        fabric_result: Result = con.run("echo $SCRATCH", hide=True)
    except Exception as e:
        dict_to_send = {"sim_id": sim_id, "job_state": EntityState.FAILED.value, "log": {"error": str(e)}}
        post_update(dict_to_send)
        return

    scratch = fabric_result.stdout.split()[0]
    logging.debug("Scratch directory: %s", scratch)

    # Timestamp-based directory keeps concurrent submissions separate.
    job_dir = f"{scratch}/yaptide_runs/{utc_now}"
    logging.debug("Job directory: %s", job_dir)

    try:
        con.run(f"mkdir -p {job_dir}")
    except Exception as e:
        # BUG FIX: the caught exception was previously dropped; include it
        # so the backend log explains why submission failed.
        dict_to_send = {"sim_id": sim_id, "job_state": EntityState.FAILED.value, "log": {"error": str(e)}}
        post_update(dict_to_send)
        return
    with tempfile.TemporaryDirectory() as tmp_dir_path:
        logging.debug("Preparing simulation input in: %s", tmp_dir_path)
        zip_path = Path(tmp_dir_path) / "input.zip"
        write_simulation_input_files(files_dict=files_dict, output_dir=Path(tmp_dir_path))
        logging.debug("Zipping simulation input to %s", zip_path)
        with ZipFile(zip_path, mode="w") as archive:
            for file in Path(tmp_dir_path).iterdir():
                # Do not zip the archive into itself.
                if file.name == "input.zip":
                    continue
                archive.write(file, arcname=file.name)
        con.put(zip_path, job_dir)
        logging.debug("Transfering simulation input %s to %s", zip_path, job_dir)

    WATCHER_SCRIPT = Path(__file__).parent.resolve() / "watcher.py"
    RESULT_SENDER_SCRIPT = Path(__file__).parent.resolve() / "result_sender.py"

    logging.debug("Transfering watcher script %s to %s", WATCHER_SCRIPT, job_dir)
    con.put(WATCHER_SCRIPT, job_dir)
    logging.debug("Transfering result sender script %s to %s", RESULT_SENDER_SCRIPT, job_dir)
    con.put(RESULT_SENDER_SCRIPT, job_dir)

    submit_file, sh_files = prepare_script_files(payload_dict=payload_dict,
                                                 job_dir=job_dir,
                                                 sim_id=sim_id,
                                                 update_key=update_key,
                                                 con=con)

    array_id = collect_id = None
    # Defensive check: the submit script must live inside the job directory
    # we just created; anything else indicates path tampering or a bug.
    if not submit_file.startswith(job_dir):
        logging.error("Invalid submit file path: %s", submit_file)
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "log": {
                "error": "Job submission failed due to invalid submit file path"
            }
        }
        post_update(dict_to_send)
        return
    fabric_result: Result = con.run(f'sh {submit_file}', hide=True)
    submit_stdout = fabric_result.stdout
    submit_stderr = fabric_result.stderr
    # The submit script prints "Job id <n>" and "Collect id <n>" lines;
    # parse both SLURM ids from its stdout.
    for line in submit_stdout.split("\n"):
        if line.startswith("Job id"):
            try:
                array_id = int(line.split()[-1])
            except (ValueError, IndexError):
                logging.error("Could not parse array id from line: %s", line)
        if line.startswith("Collect id"):
            try:
                collect_id = int(line.split()[-1])
            except (ValueError, IndexError):
                logging.error("Could not parse collect id from line: %s", line)

    if array_id is None or collect_id is None:
        logging.debug("Job submission failed")
        logging.debug("Sbatch stdout: %s", submit_stdout)
        logging.debug("Sbatch stderr: %s", submit_stderr)
        dict_to_send = {
            "sim_id": sim_id,
            "job_state": EntityState.FAILED.value,
            "log": {
                "message": "Job submission failed",
                "submit_stdout": submit_stdout,
                "sh_files": sh_files,
                "submit_stderr": submit_stderr
            }
        }
        post_update(dict_to_send)
        return

    dict_to_send = {
        "sim_id": sim_id,
        "job_dir": job_dir,
        "array_id": array_id,
        "collect_id": collect_id,
        "submit_stdout": submit_stdout,
        "sh_files": sh_files
    }
    post_update(dict_to_send)
    return

199 

200 

def prepare_script_files(payload_dict: dict, job_dir: str, sim_id: int, update_key: str,
                         con: Connection) -> tuple[str, dict]:
    """Prepares script files to run them on cluster"""
    # Remote paths of the three generated shell scripts.
    paths = {
        "submit": f'{job_dir}/yaptide_submitter.sh',
        "array": f'{job_dir}/array_script.sh',
        "collect": f'{job_dir}/collect_script.sh',
    }

    # sbatch options and headers come from the user-supplied payload.
    array_options = convert_dict_to_sbatch_options(payload_dict=payload_dict, target_key="array_options")
    array_header = extract_sbatch_header(payload_dict=payload_dict, target_key="array_header")
    collect_options = convert_dict_to_sbatch_options(payload_dict=payload_dict, target_key="collect_options")
    collect_header = extract_sbatch_header(payload_dict=payload_dict, target_key="collect_header")

    backend_url = os.environ.get("BACKEND_EXTERNAL_URL", "")

    # Render each script from its template.
    scripts = {
        "submit": SUBMIT_SHIELDHIT.format(array_options=array_options,
                                          collect_options=collect_options,
                                          root_dir=job_dir,
                                          n_tasks=str(payload_dict["ntasks"]),
                                          convertmc_version=pymchelper.__version__),
        "array": ARRAY_SHIELDHIT_BASH.format(array_header=array_header,
                                             root_dir=job_dir,
                                             sim_id=sim_id,
                                             update_key=update_key,
                                             backend_url=backend_url),
        "collect": COLLECT_BASH.format(collect_header=collect_header,
                                       root_dir=job_dir,
                                       clear_bdos="true",
                                       sim_id=sim_id,
                                       update_key=update_key,
                                       backend_url=backend_url),
    }

    # Write each script on the cluster and mark it executable
    # (same order as before: array, submit, collect).
    for key in ("array", "submit", "collect"):
        con.run(f'echo \'{scripts[key]}\' >> {paths[key]}')
        con.run(f'chmod +x {paths[key]}')

    return paths["submit"], scripts

241 

242 

def get_job_status(simulation: BatchSimulationModel, user: KeycloakUserModel, cluster: ClusterModel) -> dict:
    """Get SLURM job status"""
    con = get_connection(user=user, cluster=cluster)

    def last_state(slurm_id) -> str:
        # Ask sacct for the job's state column and take the final token.
        result: Result = con.run(f'sacct -j {slurm_id} --format State', hide=True)
        return result.stdout.split()[-1].split()[0]

    job_state = last_state(simulation.array_id)
    collect_state = last_state(simulation.collect_id)

    # A failure of either job fails the whole simulation; only the
    # collect job finishing marks it completed.
    if job_state == "FAILED" or collect_state == "FAILED":
        return {"job_state": EntityState.FAILED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}
    if collect_state == "COMPLETED":
        return {"job_state": EntityState.COMPLETED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}

    if collect_state == "RUNNING":
        logging.debug("Collect job is in RUNNING state")
    if job_state == "RUNNING":
        logging.debug("Main job is in RUNNING state")
    if collect_state == "PENDING":
        logging.debug("Collect job is in PENDING state")
    if job_state == "PENDING":
        logging.debug("Main job is in PENDING state")

    return {"job_state": EntityState.RUNNING.value}

270 

271 

def get_job_results(simulation: BatchSimulationModel, user: KeycloakUserModel, cluster: ClusterModel) -> dict:
    """Fetch simulation result estimators from the cluster.

    Only available once the collect job has COMPLETED: downloads every
    .json file from the job's output directory, parses it, tags it with
    its base name and returns {"estimators": [...]}. Otherwise returns
    a "Results not available" message.
    """
    job_dir = simulation.job_dir
    collect_id = simulation.collect_id

    con = get_connection(user=user, cluster=cluster)

    fabric_result: Result = con.run(f'sacct -j {collect_id} --format State', hide=True)
    collect_state = fabric_result.stdout.split()[-1].split()[0]

    if collect_state == "COMPLETED":
        fabric_result: Result = con.run(f'ls -f {job_dir}/output | grep .json', hide=True)
        result_estimators = []
        with tempfile.TemporaryDirectory() as tmp_dir_path:
            for filename in fabric_result.stdout.split():
                file_path = Path(tmp_dir_path, filename)
                with open(file_path, "wb") as writer:
                    # BUG FIX: the remote path previously used a literal
                    # placeholder instead of the file being iterated, so
                    # no actual result file was ever downloaded.
                    con.get(f'{job_dir}/output/{filename}', writer)
                with open(file_path, "r") as json_file:
                    est_dict = json.load(json_file)
                    # Estimator name = filename without its extension.
                    est_dict["name"] = filename.split('.')[0]
                    result_estimators.append(est_dict)

        return {"estimators": result_estimators}
    return {"message": "Results not available"}

297 

298 

def delete_job(simulation: BatchSimulationModel, user: KeycloakUserModel,
               cluster: ClusterModel) -> tuple[dict, int]:  # skipcq: PYL-W0613
    """Dummy version of delete_job"""
    try:
        con = get_connection(user=user, cluster=cluster)

        # Cancel both SLURM jobs, then remove the whole job directory.
        for command in (f'scancel {simulation.array_id}',
                        f'scancel {simulation.collect_id}',
                        f'rm -rf {simulation.job_dir}'):
            con.run(command)
    except Exception as e:  # skipcq: PYL-W0703
        logging.error(e)
        return {"message": "Job cancelation failed"}, 500

    return {"message": "Job canceled"}, 200