Coverage for yaptide/routes/common_sim_routes.py: 61%

238 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-06-10 10:09 +0000

1import logging 

2from collections import Counter 

3from datetime import datetime 

4from typing import List 

5 

6from flask import request, current_app as app 

7from flask_restful import Resource 

8from marshmallow import Schema, fields 

9 

10from yaptide.persistence.db_methods import ( 

11 add_object_to_db, fetch_estimator_by_sim_id_and_est_name, fetch_estimator_by_sim_id_and_file_name, 

12 fetch_estimator_id_by_sim_id_and_est_name, fetch_estimators_by_sim_id, fetch_input_by_sim_id, 

13 fetch_logfiles_by_sim_id, fetch_page_by_est_id_and_page_number, fetch_pages_by_est_id_and_page_numbers, 

14 fetch_pages_by_estimator_id, fetch_simulation_by_job_id, fetch_simulation_by_sim_id, fetch_simulation_id_by_job_id, 

15 fetch_tasks_by_sim_id, make_commit_to_db, update_simulation_state) 

16from yaptide.persistence.models import (EstimatorModel, LogfilesModel, PageModel, UserModel) 

17from yaptide.routes.utils.decorators import requires_auth 

18from yaptide.routes.utils.response_templates import yaptide_response 

19from yaptide.routes.utils.utils import check_if_job_is_owned_and_exist 

20from yaptide.routes.utils.tokens import decode_auth_token 

21from yaptide.utils.enums import EntityState, InputType 

22 

23 

class JobsResource(Resource):
    """Class responsible for managing common jobs"""

    class APIParametersSchema(Schema):
        """Class specifies API parameters for GET and DELETE request"""

        # required=True: handlers read param_dict['job_id'] directly, so a missing value must be
        # reported as a validation error (HTTP 400) instead of raising KeyError (HTTP 500)
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Return the state of a job owned by ``user`` together with the statuses of its tasks.

        Query parameters:
            job_id: identifier of the job (required).

        Jobs whose stored state is UNKNOWN, COMPLETED, FAILED, MERGING_QUEUED or MERGING_RUNNING
        are reported as stored. For any other job the state is re-derived from the task states
        and persisted with ``update_simulation_state`` before responding.
        """
        schema = JobsResource.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        # get job_id from request parameters and check if user owns this job
        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_simulation_by_job_id(job_id=job_id)
        if simulation.job_state == EntityState.UNKNOWN.value:
            return yaptide_response(message="Job state is unknown",
                                    code=200,
                                    content={"job_state": simulation.job_state})

        tasks = fetch_tasks_by_sim_id(sim_id=simulation.id)
        job_tasks_status = sorted((task.get_status_dict() for task in tasks), key=lambda x: x["task_id"])

        if simulation.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value,
                                    EntityState.MERGING_QUEUED.value, EntityState.MERGING_RUNNING.value):
            return yaptide_response(message=f"Job state: {simulation.job_state}",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                        "job_tasks_status": job_tasks_status,
                                    })

        # Aggregate the task states into a job state; when no rule matches, the stored
        # job state is kept. Note: with zero tasks the first (all-PENDING) rule matches.
        job_info = {"job_state": simulation.job_state}
        status_counter = Counter(task["task_state"] for task in job_tasks_status)
        task_count = len(job_tasks_status)
        if status_counter[EntityState.PENDING.value] == task_count:
            job_info["job_state"] = EntityState.PENDING.value
        elif status_counter[EntityState.FAILED.value] == task_count:
            job_info["job_state"] = EntityState.FAILED.value
        elif status_counter[EntityState.RUNNING.value] > 0:
            job_info["job_state"] = EntityState.RUNNING.value
        elif job_id.endswith("BATCH") and status_counter[EntityState.COMPLETED.value] == task_count:
            # batch jobs still need their task results merged once every task has completed
            job_info["job_state"] = EntityState.MERGING_QUEUED.value

        update_simulation_state(simulation=simulation, update_dict=job_info)

        job_info["job_tasks_status"] = job_tasks_status

        return yaptide_response(message=f"Job state: {job_info['job_state']}", code=200, content=job_info)

    @staticmethod
    def post():
        """Handle requests for updating simulation information in the db.

        Expected JSON object (at least):
            sim_id: id of the simulation to update,
            update_key: token that must decode (key ``simulation_id``) to ``sim_id``,
            log: optional; when truthy it is stored as a new ``LogfilesModel`` entry.
        The whole payload is passed to ``update_simulation_state``.
        """
        payload_dict = request.get_json(force=True)
        # A non-object JSON body or missing keys previously surfaced as TypeError/KeyError (HTTP 500)
        if not isinstance(payload_dict, dict) or not {"sim_id", "update_key"}.issubset(payload_dict):
            return yaptide_response(message="Incomplete JSON data", code=400)

        sim_id: int = payload_dict["sim_id"]
        app.logger.info(f"sim_id {sim_id}")
        simulation = fetch_simulation_by_sim_id(sim_id=sim_id)

        if not simulation:
            app.logger.info(f"sim_id {sim_id} simulation not found ")
            # NOTE(review): 501 (Not Implemented) is unusual for a missing resource; kept for
            # compatibility with existing clients — consider 404/400 like the other endpoints
            return yaptide_response(message=f"Simulation {sim_id} does not exist", code=501)

        decoded_token = decode_auth_token(payload_dict["update_key"], payload_key_to_return="simulation_id")
        if decoded_token != sim_id:
            return yaptide_response(message="Invalid update key", code=400)

        update_simulation_state(simulation, payload_dict)
        if payload_dict.get("log"):
            logfiles = LogfilesModel(simulation_id=simulation.id)
            logfiles.data = payload_dict["log"]
            add_object_to_db(logfiles)

        return yaptide_response(message="Task updated", code=202)

108 

109 

def get_single_estimator(sim_id: int, estimator_name: str):
    """Build the response for one named estimator of a simulation.

    Args:
        sim_id: database id of the simulation.
        estimator_name: name of the estimator to look up.

    Returns:
        A ``yaptide_response`` with code 404 when the estimator does not exist, otherwise
        code 200 with ``metadata``, ``name`` and the ``data`` of every page of the estimator.
    """
    found = fetch_estimator_by_sim_id_and_est_name(sim_id=sim_id, est_name=estimator_name)
    if found:
        page_payloads = [single_page.data for single_page in fetch_pages_by_estimator_id(est_id=found.id)]
        body = {"metadata": found.data, "name": found.name, "pages": page_payloads}
        return yaptide_response(message=f"Estimator '{estimator_name}' for simulation: {sim_id}",
                                code=200,
                                content=body)
    return yaptide_response(message="Estimator not found", code=404)

122 

123 

def get_all_estimators(sim_id: int):
    """Build the response containing every estimator stored for a simulation.

    Args:
        sim_id: database id of the simulation.

    Returns:
        A ``yaptide_response`` with code 404 when no estimators are stored, otherwise code 200
        and ``{"estimators": [...]}``, where each item holds the estimator's ``metadata``,
        ``name`` and the ``data`` of each of its pages.
    """
    estimators = fetch_estimators_by_sim_id(sim_id=sim_id)
    # truthiness also covers a None result (if the fetch helper ever returns one),
    # which len() would turn into a TypeError
    if not estimators:
        return yaptide_response(message="Results are unavailable", code=404)

    logging.debug("Returning results from database")
    result_estimators = [{
        "metadata": estimator.data,
        "name": estimator.name,
        "pages": [page.data for page in estimator.pages],
    } for estimator in estimators]
    return yaptide_response(message=f"Results for simulation: {sim_id}",
                            code=200,
                            content={"estimators": result_estimators})

142 

143 

def prepare_create_or_update_estimator_in_db(sim_id: int, name: str, estimator_dict: dict):
    """Stage an estimator of simulation ``sim_id`` in the DB session without committing.

    The estimator is looked up by its file name, ``estimator_dict["name"]``. When none exists,
    a new ``EstimatorModel`` is built with ``name`` as its name. In both cases its ``data`` is
    overwritten with ``estimator_dict["metadata"]`` and it is added to the session.

    Returns:
        The staged ``EstimatorModel`` instance.
    """
    existing = fetch_estimator_by_sim_id_and_file_name(sim_id=sim_id, file_name=estimator_dict["name"])
    model = existing if existing else EstimatorModel(name=name,
                                                     file_name=estimator_dict["name"],
                                                     simulation_id=sim_id)
    model.data = estimator_dict["metadata"]
    add_object_to_db(model, make_commit=False)
    return model

152 

153 

def prepare_create_or_update_pages_in_db(sim_id: int, estimator_dict):
    """Stage the pages of one estimator in the DB session without committing.

    The estimator is looked up by ``sim_id`` and its file name ``estimator_dict["name"]``, so it
    must already be in the database (``ResultsResource.post`` commits estimators before calling
    this). For every entry in ``estimator_dict["pages"]`` the page with the same page number is
    reused, or a new ``PageModel`` is created and added to the session. The page's ``data`` is
    always replaced with the incoming page dict.

    If no matching estimator exists, a warning is logged and no pages are staged.
    """
    estimator = fetch_estimator_by_sim_id_and_file_name(sim_id=sim_id, file_name=estimator_dict["name"])
    if not estimator:
        # Can happen e.g. when the payload carries an estimator that was not matched to any
        # editor output; previously this crashed with AttributeError on `estimator.id`.
        logging.warning("Estimator with file name '%s' not found for simulation %s, skipping its pages",
                        estimator_dict["name"], sim_id)
        return

    for page_dict in estimator_dict["pages"]:
        page_number = int(page_dict["metadata"]["page_number"])
        page = fetch_page_by_est_id_and_page_number(est_id=estimator.id, page_number=page_number)
        page_existed = bool(page)
        if not page_existed:
            page = PageModel(page_number=page_number,
                             estimator_id=estimator.id,
                             page_dimension=int(page_dict['dimensions']),
                             page_name=str(page_dict["metadata"]["name"]))
        # we always update the data
        page.data = page_dict
        if not page_existed:
            # if page was created, we add it to the session
            add_object_to_db(page, make_commit=False)

171 

172 

def parse_page_numbers(param: str) -> List[int]:
    """Parse a page specification such as ``'1-3,5'`` into a sorted list of unique page numbers.

    The specification is a comma-separated list of single numbers (``'5'``) and inclusive
    ranges (``'1-3'``). Whitespace around the numbers is ignored.

    Raises:
        ValueError: if an item is not an integer or a ``start-end`` range of integers, or if a
            range is reversed (``start > end``), which previously yielded no pages silently.

    >>> parse_page_numbers('5,1-3,2')
    [1, 2, 3, 5]
    """
    pages = set()
    for part in param.split(','):
        start_text, separator, end_text = part.partition('-')
        start = int(start_text)
        if not separator:
            pages.add(start)
            continue
        # int() rejects leftovers such as '2-3' from '1-2-3', matching the old unpacking error
        end = int(end_text)
        if start > end:
            raise ValueError(f"Invalid page range '{part.strip()}': start is greater than end")
        pages.update(range(start, end + 1))
    return sorted(pages)

183 

184 

class ResultsResource(Resource):
    """Class responsible for managing results"""

    @staticmethod
    def post():
        """
        Method for saving results
        Used by the jobs at the end of simulation
        Structure required by this method to work properly:
        {
            "simulation_id": <int>,
            "update_key": <string>,
            "estimators": <list of dicts, each with "name", "metadata" and "pages">
        }
        Estimators are committed first, then their pages, and finally the simulation
        is marked as COMPLETED.
        """
        payload_dict = request.get_json(force=True)
        # A JSON body that is not an object (e.g. `null` or a list) cannot carry the required
        # keys; checking the type first avoids an AttributeError (HTTP 500) on `.keys()`.
        if not isinstance(payload_dict, dict) or {"simulation_id", "update_key", "estimators"} != set(payload_dict):
            return yaptide_response(message="Incomplete JSON data", code=400)

        sim_id = payload_dict["simulation_id"]
        simulation = fetch_simulation_by_sim_id(sim_id=sim_id)

        if not simulation:
            return yaptide_response(message="Simulation does not exist", code=400)

        decoded_token = decode_auth_token(payload_dict["update_key"], payload_key_to_return="simulation_id")
        if decoded_token != sim_id:
            return yaptide_response(message="Invalid update key", code=400)

        if simulation.input_type == InputType.EDITOR.value:
            outputs = simulation.inputs[0].data["input_json"]["scoringManager"]["outputs"]
            sorted_estimator_names = sorted(output["name"] for output in outputs)
            for output in outputs:
                name = output["name"]
                # This code assumes the payload estimators are sorted alphabetically by name,
                # so the position of `name` in sorted_estimator_names is its payload index.
                estimator_dict = payload_dict["estimators"][sorted_estimator_names.index(name)]
                prepare_create_or_update_estimator_in_db(sim_id=simulation.id, name=name, estimator_dict=estimator_dict)
        elif simulation.input_type == InputType.FILES.value:
            for estimator_dict in payload_dict["estimators"]:
                prepare_create_or_update_estimator_in_db(sim_id=simulation.id,
                                                         name=estimator_dict["name"],
                                                         estimator_dict=estimator_dict)

        # commit estimators (pages below look them up by file name and need their ids)
        make_commit_to_db()

        for estimator_dict in payload_dict["estimators"]:
            prepare_create_or_update_pages_in_db(sim_id=simulation.id, estimator_dict=estimator_dict)

        # commit pages
        make_commit_to_db()

        logging.debug("Marking simulation as completed")
        update_dict = {"job_state": EntityState.COMPLETED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}
        update_simulation_state(simulation=simulation, update_dict=update_dict)

        # NOTE(review): no task state is modified here despite the message below
        logging.debug("Marking simulation tasks as completed")

        return yaptide_response(message="Results saved", code=202)

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # required=True: get() reads param_dict['job_id'] directly, so a missing value must be
        # reported as a validation error (HTTP 400) instead of raising KeyError (HTTP 500)
        job_id = fields.String(required=True)
        estimator_name = fields.String(load_default=None)
        page_number = fields.Integer(load_default=None)
        page_numbers = fields.String(load_default=None)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning job status and results.
        If `estimator_name` parameter is provided,
        the response will include results only for that specific estimator,
        otherwise it will return all estimators for the given job.
        If `page_number` or `page_numbers` are provided, the response will include only specific pages;
        `page_number` takes precedence when both are given. `page_numbers` accepts a comma-separated
        list of numbers and inclusive ranges, e.g. '1-3,5'.

        Responds with 404 when the estimator or the requested page does not exist and with 400
        when `page_numbers` cannot be parsed.
        """
        schema = ResultsResource.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        job_id = param_dict['job_id']
        estimator_name = param_dict['estimator_name']
        page_number = param_dict.get('page_number')
        page_numbers = param_dict.get('page_numbers')

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation_id = fetch_simulation_id_by_job_id(job_id=job_id)
        if not simulation_id:
            return yaptide_response(message="Simulation does not exist", code=404)

        # without an estimator name, return every estimator of the simulation
        if estimator_name is None:
            return get_all_estimators(sim_id=simulation_id)

        if page_number is None and page_numbers is None:
            return get_single_estimator(sim_id=simulation_id, estimator_name=estimator_name)

        estimator_id = fetch_estimator_id_by_sim_id_and_est_name(sim_id=simulation_id, est_name=estimator_name)
        if estimator_id is None:
            # previously led to an AttributeError on `page.data` (HTTP 500)
            return yaptide_response(message="Estimator not found", code=404)

        if page_number is not None:
            page = fetch_page_by_est_id_and_page_number(est_id=estimator_id, page_number=page_number)
            if not page:
                return yaptide_response(message="Page not found", code=404)
            result = {"page": page.data}
            return yaptide_response(message="Page retrieved successfully", code=200, content=result)

        # only page_numbers can be set here: the case with both parameters missing returned above
        try:
            parsed_page_numbers = parse_page_numbers(page_numbers)
        except ValueError as err:
            return yaptide_response(message="Wrong parameters", code=400, content={"page_numbers": [str(err)]})
        pages = fetch_pages_by_est_id_and_page_numbers(est_id=estimator_id, page_numbers=parsed_page_numbers)
        result = {"pages": [page.data for page in pages]}
        return yaptide_response(message="Pages retrieved successfully", code=200, content=result)

302 

303 

class InputsResource(Resource):
    """Class responsible for returning simulation input"""

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # required=True: get() reads param_dict['job_id'] directly, so a missing value must be
        # reported as a validation error (HTTP 400) instead of raising KeyError (HTTP 500)
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Return the stored input of a simulation owned by ``user``.

        Query parameters:
            job_id: identifier of the job (required).

        Responds with 404 when no input is stored for the simulation.
        """
        schema = InputsResource.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)
        job_id = param_dict['job_id']

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_simulation_by_job_id(job_id=job_id)

        input_model = fetch_input_by_sim_id(sim_id=simulation.id)
        if not input_model:
            return yaptide_response(message="Input of simulation is unavailable", code=404)

        return yaptide_response(message="Input of simulation", code=200, content={"input": input_model.data})

334 

335 

class LogfilesResource(Resource):
    """Class responsible for managing logfiles"""

    @staticmethod
    def post():
        """
        Method for saving logfiles
        Used by the jobs when the simulation fails
        Structure required by this method to work properly:
        {
            "simulation_id": <int>,
            "update_key": <string>,
            "logfiles": <dict>
        }
        """
        payload_dict = request.get_json(force=True)
        # A JSON body that is not an object (e.g. `null` or a list) cannot carry the required
        # keys; checking the type first avoids an AttributeError (HTTP 500) on `.keys()`.
        if not isinstance(payload_dict, dict) or {"simulation_id", "update_key", "logfiles"} != set(payload_dict):
            return yaptide_response(message="Incomplete JSON data", code=400)

        sim_id = payload_dict["simulation_id"]
        simulation = fetch_simulation_by_sim_id(sim_id=sim_id)

        if not simulation:
            return yaptide_response(message="Simulation does not exist", code=400)

        decoded_token = decode_auth_token(payload_dict["update_key"], payload_key_to_return="simulation_id")
        if decoded_token != sim_id:
            return yaptide_response(message="Invalid update key", code=400)

        logfiles = LogfilesModel(simulation_id=simulation.id)
        logfiles.data = payload_dict["logfiles"]
        add_object_to_db(logfiles)

        return yaptide_response(message="Log files saved", code=202)

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # required=True: get() reads param_dict['job_id'] directly, so a missing value must be
        # reported as a validation error (HTTP 400) instead of raising KeyError (HTTP 500)
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Return the stored logfiles of a simulation owned by ``user``.

        Query parameters:
            job_id: identifier of the job (required).

        Responds with 404 when no logfiles are stored for the simulation.
        """
        # validate with this resource's own schema (previously ResultsResource's was used by mistake)
        schema = LogfilesResource.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_simulation_by_job_id(job_id=job_id)

        logfile = fetch_logfiles_by_sim_id(sim_id=simulation.id)
        if not logfile:
            return yaptide_response(message="Logfiles are unavailable", code=404)

        logging.debug("Returning logfiles from database")

        return yaptide_response(message="Logfiles", code=200, content={"logfiles": logfile.data})