Coverage for yaptide/routes/common_sim_routes.py: 61%

237 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-04 00:31 +0000

1import logging 

2from collections import Counter 

3from datetime import datetime 

4from typing import List 

5 

6from flask import request, current_app as app 

7from flask_restful import Resource 

8from marshmallow import Schema, fields 

9 

10from yaptide.persistence.db_methods import ( 

11 add_object_to_db, fetch_estimator_by_sim_id_and_est_name, fetch_estimator_by_sim_id_and_file_name, 

12 fetch_estimator_id_by_sim_id_and_est_name, fetch_estimators_by_sim_id, fetch_input_by_sim_id, 

13 fetch_logfiles_by_sim_id, fetch_page_by_est_id_and_page_number, fetch_pages_by_est_id_and_page_numbers, 

14 fetch_pages_by_estimator_id, fetch_simulation_by_job_id, fetch_simulation_by_sim_id, fetch_simulation_id_by_job_id, 

15 fetch_tasks_by_sim_id, make_commit_to_db, update_simulation_state) 

16from yaptide.persistence.models import (EstimatorModel, LogfilesModel, PageModel, UserModel) 

17from yaptide.routes.utils.decorators import requires_auth 

18from yaptide.routes.utils.response_templates import yaptide_response 

19from yaptide.routes.utils.utils import check_if_job_is_owned_and_exist 

20from yaptide.routes.utils.tokens import decode_auth_token 

21from yaptide.utils.enums import EntityState, InputType 

22 

23 

class JobsResource(Resource):
    """Class responsible for managing common jobs"""

    class APIParametersSchema(Schema):
        """Class specifies API parameters for GET and DELETE request"""

        # required=True: handlers read param_dict['job_id'] directly, so a missing
        # job_id must fail validation (HTTP 400) instead of raising KeyError (HTTP 500)
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning info about job

        Validates the ``job_id`` query parameter, checks that ``user`` owns the job and
        reports the job state together with the status dict of each of its tasks.
        While the job is not in a completed/failed/merging state, its state is
        recomputed from the task states and persisted with ``update_simulation_state``.

        Args:
            user: the authenticated user, injected by ``requires_auth``.

        Returns:
            A ``yaptide_response``: 400 on invalid parameters, the code returned by
            ``check_if_job_is_owned_and_exist`` on ownership/existence errors, otherwise
            200 with ``job_state`` (and ``job_tasks_status`` unless the state is unknown).
        """
        schema = JobsResource.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        # get job_id from request parameters and check if user owns this job
        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_simulation_by_job_id(job_id=job_id)
        if simulation.job_state == EntityState.UNKNOWN.value:
            return yaptide_response(message="Job state is unknown",
                                    code=200,
                                    content={"job_state": simulation.job_state})

        tasks = fetch_tasks_by_sim_id(sim_id=simulation.id)

        job_tasks_status = [task.get_status_dict() for task in tasks]

        # completed, failed and merging jobs are reported with their stored state as-is
        if simulation.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value,
                                    EntityState.MERGING_QUEUED.value, EntityState.MERGING_RUNNING.value):
            return yaptide_response(message=f"Job state: {simulation.job_state}",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                        "job_tasks_status": job_tasks_status,
                                    })

        # otherwise derive the job state from the states of its tasks
        job_info = {"job_state": simulation.job_state}
        status_counter = Counter([task["task_state"] for task in job_tasks_status])
        # NOTE(review): with no tasks at all this first condition holds (0 == 0) and the job is marked PENDING
        if status_counter[EntityState.PENDING.value] == len(job_tasks_status):
            job_info["job_state"] = EntityState.PENDING.value
        elif status_counter[EntityState.FAILED.value] == len(job_tasks_status):
            job_info["job_state"] = EntityState.FAILED.value
        elif status_counter[EntityState.RUNNING.value] > 0:
            job_info["job_state"] = EntityState.RUNNING.value
        elif job_id.endswith("BATCH") and status_counter[EntityState.COMPLETED.value] == len(job_tasks_status):
            # job ids ending with "BATCH" (presumably batch-system jobs) wait for result merging
            job_info["job_state"] = EntityState.MERGING_QUEUED.value

        update_simulation_state(simulation=simulation, update_dict=job_info)

        job_info["job_tasks_status"] = job_tasks_status

        return yaptide_response(message=f"Job state: {job_info['job_state']}", code=200, content=job_info)

    @staticmethod
    def post():
        """Handles requests for updating simulation informations in db

        Not protected by user auth: the request is authorized by ``update_key``, a token
        which must decode (``payload_key_to_return="simulation_id"``) to ``sim_id``.
        Expected JSON body: ``{"sim_id": <int>, "update_key": <string>, ...}``; the whole
        payload is passed to ``update_simulation_state`` and an optional non-empty
        ``log`` entry is stored as a ``LogfilesModel``.

        Returns:
            A ``yaptide_response``: 400 for an incomplete body or an invalid update key,
            501 when the simulation does not exist, 202 after a successful update.
        """
        payload_dict: dict = request.get_json(force=True)
        # reject bodies that are not JSON objects or lack the keys read below;
        # previously these surfaced as AttributeError/KeyError (HTTP 500)
        if not isinstance(payload_dict, dict) or not {"sim_id", "update_key"}.issubset(payload_dict):
            return yaptide_response(message="Incomplete JSON data", code=400)

        sim_id: int = payload_dict["sim_id"]
        app.logger.info(f"sim_id {sim_id}")
        simulation = fetch_simulation_by_sim_id(sim_id=sim_id)

        if not simulation:
            app.logger.info(f"sim_id {sim_id} simulation not found ")
            # NOTE(review): 501 (Not Implemented) is unusual for a missing resource;
            # kept unchanged because job-side callers may depend on it
            return yaptide_response(message=f"Simulation {sim_id} does not exist", code=501)

        decoded_token = decode_auth_token(payload_dict["update_key"], payload_key_to_return="simulation_id")
        if decoded_token != sim_id:
            return yaptide_response(message="Invalid update key", code=400)

        update_simulation_state(simulation, payload_dict)
        if payload_dict.get("log"):
            logfiles = LogfilesModel(simulation_id=simulation.id)
            logfiles.data = payload_dict["log"]
            add_object_to_db(logfiles)

        return yaptide_response(message="Task updated", code=202)

107 

108 

def get_single_estimator(sim_id: int, estimator_name: str):
    """Build the response for a single estimator of a simulation.

    The estimator is looked up by simulation ID and estimator name; its pages are
    fetched separately by estimator ID.

    Returns:
        A 200 ``yaptide_response`` whose content holds the estimator's ``metadata``,
        ``name`` and the data of each page, or a 404 response when no matching
        estimator exists.
    """
    found = fetch_estimator_by_sim_id_and_est_name(sim_id=sim_id, est_name=estimator_name)
    if found:
        estimator_pages = fetch_pages_by_estimator_id(est_id=found.id)
        content = {
            "metadata": found.data,
            "name": found.name,
            "pages": [pg.data for pg in estimator_pages],
        }
        return yaptide_response(message=f"Estimator '{estimator_name}' for simulation: {sim_id}",
                                code=200,
                                content=content)
    return yaptide_response(message="Estimator not found", code=404)

121 

122 

def get_all_estimators(sim_id: int):
    """Build the response containing every estimator stored for a simulation.

    Args:
        sim_id: ID of the simulation whose estimators are returned.

    Returns:
        A 200 ``yaptide_response`` with content ``{"estimators": [...]}``, where each
        entry holds the estimator's ``metadata``, ``name`` and the data of the pages
        reachable through its ``pages`` attribute; a 404 response when the simulation
        has no estimators.
    """
    estimators = fetch_estimators_by_sim_id(sim_id=sim_id)
    # truthiness covers an empty result and, unlike len(), also a None result
    if not estimators:
        return yaptide_response(message="Results are unavailable", code=404)

    logging.debug("Returning results from database")
    result_estimators = [{
        "metadata": estimator.data,
        "name": estimator.name,
        "pages": [page.data for page in estimator.pages],
    } for estimator in estimators]
    return yaptide_response(message=f"Results for simulation: {sim_id}",
                            code=200,
                            content={"estimators": result_estimators})

141 

142 

def prepare_create_or_update_estimator_in_db(sim_id: int, name: str, estimator_dict: dict):
    """Stage an estimator of simulation ``sim_id`` in the DB session without committing.

    The estimator is matched by its file name, ``estimator_dict["name"]``. When none is
    found, a new ``EstimatorModel`` called ``name`` is created for that file name. In
    both cases its data is replaced with ``estimator_dict["metadata"]`` and it is added
    to the session with ``make_commit=False``.

    Returns:
        The staged ``EstimatorModel`` instance.
    """
    file_name = estimator_dict["name"]
    estimator = (fetch_estimator_by_sim_id_and_file_name(sim_id=sim_id, file_name=file_name)
                 or EstimatorModel(name=name, file_name=file_name, simulation_id=sim_id))
    estimator.data = estimator_dict["metadata"]
    add_object_to_db(estimator, make_commit=False)
    return estimator

151 

152 

def prepare_create_or_update_pages_in_db(sim_id: int, estimator_dict):
    """Stage the pages of one estimator in the DB session without committing.

    The owning estimator is looked up by ``sim_id`` and file name
    (``estimator_dict["name"]``). If it is not found, accessing ``estimator.id`` raises
    AttributeError. For each dict in ``estimator_dict["pages"]``, the existing page
    with the same page number is reused, or a new ``PageModel`` is built from the
    page's ``dimensions`` and ``metadata["name"]``. The page data is always replaced
    with the page dict. Only newly created pages are passed to ``add_object_to_db``;
    existing ones are presumably already tracked by the session.
    """
    estimator = fetch_estimator_by_sim_id_and_file_name(sim_id=sim_id, file_name=estimator_dict["name"])
    for page_dict in estimator_dict["pages"]:
        page_metadata = page_dict["metadata"]
        number = int(page_metadata["page_number"])
        page = fetch_page_by_est_id_and_page_number(est_id=estimator.id, page_number=number)
        if page:
            # existing page: only refresh its data
            page.data = page_dict
            continue
        page = PageModel(page_number=number,
                         estimator_id=estimator.id,
                         page_dimension=int(page_dict['dimensions']),
                         page_name=str(page_metadata["name"]))
        page.data = page_dict
        add_object_to_db(page, make_commit=False)

170 

171 

def parse_page_numbers(param: str) -> List[int]:
    """Parse a page-range string such as ``'1-3,5'`` into a sorted list of unique page numbers.

    Comma-separated entries are either single numbers (``'5'``) or inclusive ranges
    (``'1-3'``). Whitespace around numbers is ignored and duplicate pages are collapsed.

    Args:
        param: the raw page specification.

    Returns:
        The page numbers in ascending order, without duplicates.

    Raises:
        ValueError: if an entry is empty or not an integer / ``start-end`` pair, or if a
            range's start is greater than its end (previously such a range was
            silently ignored).

    >>> parse_page_numbers('1-3,5')
    [1, 2, 3, 5]
    """
    pages = set()
    for part in param.split(','):
        if '-' in part:
            # partition keeps everything after the first '-' in end_text,
            # so input like '1-2-3' fails in int() with ValueError
            start_text, _, end_text = part.partition('-')
            start, end = int(start_text), int(end_text)
            if start > end:
                raise ValueError(f"Invalid page range '{part.strip()}': start is greater than end")
            pages.update(range(start, end + 1))
        else:
            pages.add(int(part))
    return sorted(pages)

182 

183 

class ResultsResource(Resource):
    """Class responsible for managing results"""

    @staticmethod
    def post():
        """
        Method for saving results
        Used by the jobs at the end of simulation
        Structure required by this method to work properly:
        {
            "simulation_id": <int>,
            "update_key": <string>,
            "estimators": <list of estimator dicts>
        }
        Each estimator dict provides "name" (stored as the estimator file name),
        "metadata" and "pages"; each page dict provides "dimensions" and a
        "metadata" dict with "page_number" and "name".
        For editor inputs the estimators are expected to be sorted alphabetically
        by the names of the scoring outputs. Estimators are committed first, then
        their pages, and finally the simulation is marked as completed.
        """
        payload_dict: dict = request.get_json(force=True)
        # a JSON body that is not an object (e.g. a list or null) would otherwise
        # crash on the key check with HTTP 500
        if not isinstance(payload_dict, dict) or {"simulation_id", "update_key", "estimators"} != set(payload_dict):
            return yaptide_response(message="Incomplete JSON data", code=400)

        sim_id = payload_dict["simulation_id"]
        simulation = fetch_simulation_by_sim_id(sim_id=sim_id)

        if not simulation:
            return yaptide_response(message="Simulation does not exist", code=400)

        decoded_token = decode_auth_token(payload_dict["update_key"], payload_key_to_return="simulation_id")
        if decoded_token != sim_id:
            return yaptide_response(message="Invalid update key", code=400)

        if simulation.input_type == InputType.EDITOR.value:
            outputs = simulation.inputs[0].data["input_json"]["scoringManager"]["outputs"]
            sorted_estimator_names = sorted([output["name"] for output in outputs])
            for output in outputs:
                name = output["name"]
                # estimator_dict is sorted alphabetically by names,
                # that's why we can match indexes from sorted_estimator_names
                estimator_dict_index = sorted_estimator_names.index(name)
                estimator_dict = payload_dict["estimators"][estimator_dict_index]
                prepare_create_or_update_estimator_in_db(sim_id=simulation.id, name=name, estimator_dict=estimator_dict)
        elif simulation.input_type == InputType.FILES.value:
            for estimator_dict in payload_dict["estimators"]:
                prepare_create_or_update_estimator_in_db(sim_id=simulation.id,
                                                         name=estimator_dict["name"],
                                                         estimator_dict=estimator_dict)

        # commit estimators, so the page step below can look them up
        make_commit_to_db()

        for estimator_dict in payload_dict["estimators"]:
            prepare_create_or_update_pages_in_db(sim_id=simulation.id, estimator_dict=estimator_dict)

        # commit pages
        make_commit_to_db()

        logging.debug("Marking simulation as completed")
        # naive UTC timestamp; NOTE: datetime.utcnow() is deprecated since Python 3.12
        update_dict = {"job_state": EntityState.COMPLETED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}
        update_simulation_state(simulation=simulation, update_dict=update_dict)

        logging.debug("Marking simulation tasks as completed")

        return yaptide_response(message="Results saved", code=202)

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # required=True: get() reads param_dict['job_id'] directly, so a missing
        # job_id must fail validation (HTTP 400) instead of raising KeyError (HTTP 500)
        job_id = fields.String(required=True)
        estimator_name = fields.String(load_default=None)
        page_number = fields.Integer(load_default=None)
        page_numbers = fields.String(load_default=None)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning job status and results.
        If `estimator_name` parameter is provided,
        the response will include results only for that specific estimator,
        otherwise it will return all estimators for the given job.
        If `page_number` or `page_numbers` are provided, the response will include only specific pages;
        `page_number` takes precedence when both are given. `page_numbers` accepts ranges such as '1-3,5'.

        Returns 400 for invalid parameters (including a malformed `page_numbers`),
        and 404 when the simulation, the estimator or the requested page does not exist.
        """
        schema = ResultsResource.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        job_id = param_dict['job_id']
        estimator_name = param_dict['estimator_name']
        page_number = param_dict.get('page_number')
        page_numbers = param_dict.get('page_numbers')

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation_id = fetch_simulation_id_by_job_id(job_id=job_id)
        if not simulation_id:
            return yaptide_response(message="Simulation does not exist", code=404)

        # without an estimator name, return all estimators
        if estimator_name is None:
            return get_all_estimators(sim_id=simulation_id)

        if page_number is None and page_numbers is None:
            return get_single_estimator(sim_id=simulation_id, estimator_name=estimator_name)

        estimator_id = fetch_estimator_id_by_sim_id_and_est_name(sim_id=simulation_id, est_name=estimator_name)
        if not estimator_id:
            return yaptide_response(message="Estimator not found", code=404)

        if page_number is not None:
            page = fetch_page_by_est_id_and_page_number(est_id=estimator_id, page_number=page_number)
            # a missing page previously crashed on `page.data` (HTTP 500)
            if not page:
                return yaptide_response(message="Page not found", code=404)
            result = {"page": page.data}
            return yaptide_response(message="Page retrieved successfully", code=200, content=result)

        # only page_numbers can be set at this point
        try:
            parsed_page_numbers = parse_page_numbers(page_numbers)
        except ValueError:
            return yaptide_response(message="Wrong parameters",
                                    code=400,
                                    content={"page_numbers": ["Invalid page numbers format"]})
        pages = fetch_pages_by_est_id_and_page_numbers(est_id=estimator_id, page_numbers=parsed_page_numbers)
        result = {"pages": [page.data for page in pages]}
        return yaptide_response(message="Pages retrieved successfully", code=200, content=result)

301 

302 

class InputsResource(Resource):
    """Class responsible for returning simulation input"""

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # required=True: get() reads param_dict['job_id'] directly, so a missing
        # job_id must fail validation (HTTP 400) instead of raising KeyError (HTTP 500)
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning simulation input

        Args:
            user: the authenticated user, injected by ``requires_auth``.

        Returns:
            A ``yaptide_response``: 400 on invalid parameters, the code returned by
            ``check_if_job_is_owned_and_exist`` on ownership/existence errors, 404 when
            no input is stored for the simulation, otherwise 200 with ``{"input": ...}``.
        """
        schema = InputsResource.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)
        job_id = param_dict['job_id']

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_simulation_by_job_id(job_id=job_id)

        input_model = fetch_input_by_sim_id(sim_id=simulation.id)
        if not input_model:
            return yaptide_response(message="Input of simulation is unavailable", code=404)

        return yaptide_response(message="Input of simulation", code=200, content={"input": input_model.data})

333 

334 

class LogfilesResource(Resource):
    """Class responsible for managing logfiles"""

    @staticmethod
    def post():
        """
        Method for saving logfiles
        Used by the jobs when the simulation fails
        Structure required by this method to work properly:
        {
            "simulation_id": <int>,
            "update_key": <string>,
            "logfiles": <dict>
        }
        The request is authorized by ``update_key``, which must decode to ``simulation_id``.
        """
        payload_dict: dict = request.get_json(force=True)
        # a JSON body that is not an object (e.g. a list or null) would otherwise
        # crash on the key check with HTTP 500
        if not isinstance(payload_dict, dict) or {"simulation_id", "update_key", "logfiles"} != set(payload_dict):
            return yaptide_response(message="Incomplete JSON data", code=400)

        sim_id = payload_dict["simulation_id"]
        simulation = fetch_simulation_by_sim_id(sim_id=sim_id)

        if not simulation:
            return yaptide_response(message="Simulation does not exist", code=400)

        decoded_token = decode_auth_token(payload_dict["update_key"], payload_key_to_return="simulation_id")
        if decoded_token != sim_id:
            return yaptide_response(message="Invalid update key", code=400)

        logfiles = LogfilesModel(simulation_id=simulation.id)
        logfiles.data = payload_dict["logfiles"]
        add_object_to_db(logfiles)

        return yaptide_response(message="Log files saved", code=202)

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # required=True: get() reads param_dict['job_id'] directly, so a missing
        # job_id must fail validation (HTTP 400) instead of raising KeyError (HTTP 500)
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning logfiles of a job

        Args:
            user: the authenticated user, injected by ``requires_auth``.

        Returns:
            A ``yaptide_response``: 400 on invalid parameters, the code returned by
            ``check_if_job_is_owned_and_exist`` on ownership/existence errors, 404 when
            no logfiles are stored for the simulation, otherwise 200 with ``{"logfiles": ...}``.
        """
        # validate against this resource's own schema (previously ResultsResource's was used)
        schema = LogfilesResource.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_simulation_by_job_id(job_id=job_id)

        logfile = fetch_logfiles_by_sim_id(sim_id=simulation.id)
        if not logfile:
            return yaptide_response(message="Logfiles are unavailable", code=404)

        logging.debug("Returning logfiles from database")

        return yaptide_response(message="Logfiles", code=200, content={"logfiles": logfile.data})