Coverage for yaptide/routes/common_sim_routes.py: 61%
237 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-04 00:31 +0000
1import logging
2from collections import Counter
3from datetime import datetime
4from typing import List
6from flask import request, current_app as app
7from flask_restful import Resource
8from marshmallow import Schema, fields
10from yaptide.persistence.db_methods import (
11 add_object_to_db, fetch_estimator_by_sim_id_and_est_name, fetch_estimator_by_sim_id_and_file_name,
12 fetch_estimator_id_by_sim_id_and_est_name, fetch_estimators_by_sim_id, fetch_input_by_sim_id,
13 fetch_logfiles_by_sim_id, fetch_page_by_est_id_and_page_number, fetch_pages_by_est_id_and_page_numbers,
14 fetch_pages_by_estimator_id, fetch_simulation_by_job_id, fetch_simulation_by_sim_id, fetch_simulation_id_by_job_id,
15 fetch_tasks_by_sim_id, make_commit_to_db, update_simulation_state)
16from yaptide.persistence.models import (EstimatorModel, LogfilesModel, PageModel, UserModel)
17from yaptide.routes.utils.decorators import requires_auth
18from yaptide.routes.utils.response_templates import yaptide_response
19from yaptide.routes.utils.utils import check_if_job_is_owned_and_exist
20from yaptide.routes.utils.tokens import decode_auth_token
21from yaptide.utils.enums import EntityState, InputType
class JobsResource(Resource):
    """Class responsible for managing common jobs"""

    class APIParametersSchema(Schema):
        """Class specifies API parameters for GET and DELETE request"""

        # required=True: handlers read param_dict['job_id'] unconditionally, so a missing
        # job_id must be reported as a 400 validation error instead of a KeyError (HTTP 500)
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning info about job

        Validates the ``job_id`` query parameter, checks that ``user`` owns the job and
        returns the job state together with the status dict of every task.
        For jobs not in COMPLETED/FAILED/MERGING_QUEUED/MERGING_RUNNING/UNKNOWN state the
        job state is re-derived from task states and persisted with ``update_simulation_state``.
        """
        schema = JobsResource.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        # get job_id from request parameters and check if user owns this job
        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_simulation_by_job_id(job_id=job_id)
        if simulation.job_state == EntityState.UNKNOWN.value:
            return yaptide_response(message="Job state is unknown",
                                    code=200,
                                    content={"job_state": simulation.job_state})

        tasks = fetch_tasks_by_sim_id(sim_id=simulation.id)
        job_tasks_status = [task.get_status_dict() for task in tasks]

        # for these states the stored job_state is returned as-is (not re-derived from tasks)
        if simulation.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value,
                                    EntityState.MERGING_QUEUED.value, EntityState.MERGING_RUNNING.value):
            return yaptide_response(message=f"Job state: {simulation.job_state}",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                        "job_tasks_status": job_tasks_status,
                                    })

        # derive the aggregate job state from task states; the first matching rule wins.
        # Note: with zero tasks the PENDING rule matches (0 == 0).
        job_info = {"job_state": simulation.job_state}
        task_count = len(job_tasks_status)
        status_counter = Counter(task["task_state"] for task in job_tasks_status)
        if status_counter[EntityState.PENDING.value] == task_count:
            job_info["job_state"] = EntityState.PENDING.value
        elif status_counter[EntityState.FAILED.value] == task_count:
            job_info["job_state"] = EntityState.FAILED.value
        elif status_counter[EntityState.RUNNING.value] > 0:
            job_info["job_state"] = EntityState.RUNNING.value
        elif job_id.endswith("BATCH") and status_counter[EntityState.COMPLETED.value] == task_count:
            # presumably batch job results are merged in a separate step - verify against merge workflow
            job_info["job_state"] = EntityState.MERGING_QUEUED.value

        update_simulation_state(simulation=simulation, update_dict=job_info)

        job_info["job_tasks_status"] = job_tasks_status

        return yaptide_response(message=f"Job state: {job_info['job_state']}", code=200, content=job_info)

    @staticmethod
    def post():
        """Handles requests for updating simulation informations in db

        Expects a JSON object containing at least ``sim_id`` and ``update_key``.
        ``update_key`` must decode to the same simulation id. The whole payload is then
        passed to ``update_simulation_state``; an optional non-empty ``log`` entry is
        additionally stored as a ``LogfilesModel`` row.
        """
        payload_dict: dict = request.get_json(force=True)
        # reject malformed payloads up front instead of failing with KeyError/AttributeError (HTTP 500)
        if not isinstance(payload_dict, dict) or not {"sim_id", "update_key"}.issubset(payload_dict):
            return yaptide_response(message="Incomplete JSON data", code=400)

        sim_id: int = payload_dict["sim_id"]
        app.logger.info("sim_id %s", sim_id)
        simulation = fetch_simulation_by_sim_id(sim_id=sim_id)

        if not simulation:
            app.logger.info("sim_id %s simulation not found ", sim_id)
            # NOTE(review): 501 (Not Implemented) is kept for compatibility with existing
            # callers; 404 would describe this situation more accurately
            return yaptide_response(message=f"Simulation {sim_id} does not exist", code=501)

        decoded_token = decode_auth_token(payload_dict["update_key"], payload_key_to_return="simulation_id")
        if decoded_token != sim_id:
            return yaptide_response(message="Invalid update key", code=400)

        update_simulation_state(simulation, payload_dict)
        if payload_dict.get("log"):
            logfiles = LogfilesModel(simulation_id=simulation.id)
            logfiles.data = payload_dict["log"]
            add_object_to_db(logfiles)

        return yaptide_response(message="Task updated", code=202)
def get_single_estimator(sim_id: int, estimator_name: str):
    """Build a response holding one named estimator of a simulation.

    :param sim_id: id of the simulation owning the estimator
    :param estimator_name: name of the estimator to return
    :return: 404 response when no such estimator exists, otherwise a 200 response whose
             content has the estimator ``metadata``, ``name`` and the data of all its ``pages``
    """
    found = fetch_estimator_by_sim_id_and_est_name(sim_id=sim_id, est_name=estimator_name)
    if found:
        page_payloads = [single_page.data for single_page in fetch_pages_by_estimator_id(est_id=found.id)]
        return yaptide_response(message=f"Estimator '{estimator_name}' for simulation: {sim_id}",
                                code=200,
                                content={
                                    "metadata": found.data,
                                    "name": found.name,
                                    "pages": page_payloads
                                })
    return yaptide_response(message="Estimator not found", code=404)
def get_all_estimators(sim_id: int):
    """Build a response holding every estimator stored for a simulation.

    :param sim_id: id of the simulation
    :return: 404 response when the simulation has no estimators, otherwise a 200 response
             with ``{"estimators": [{"metadata", "name", "pages"}, ...]}``
    """
    estimators = fetch_estimators_by_sim_id(sim_id=sim_id)
    if len(estimators) > 0:
        logging.debug("Returning results from database")
        serialized = [{
            "metadata": est.data,
            "name": est.name,
            "pages": [pg.data for pg in est.pages]
        } for est in estimators]
        return yaptide_response(message=f"Results for simulation: {sim_id}",
                                code=200,
                                content={"estimators": serialized})
    return yaptide_response(message="Results are unavailable", code=404)
def prepare_create_or_update_estimator_in_db(sim_id: int, name: str, estimator_dict: dict):
    """Stage an estimator for insertion or update without committing.

    The estimator is looked up by ``estimator_dict["name"]`` (used as the file name);
    a new ``EstimatorModel`` is built when the lookup finds nothing. Its ``data`` is
    (re)set from ``estimator_dict["metadata"]`` and it is added to the session
    with ``make_commit=False``.

    :return: the staged estimator object
    """
    file_name = estimator_dict["name"]
    estimator = (fetch_estimator_by_sim_id_and_file_name(sim_id=sim_id, file_name=file_name)
                 or EstimatorModel(name=name, file_name=file_name, simulation_id=sim_id))
    estimator.data = estimator_dict["metadata"]
    add_object_to_db(estimator, make_commit=False)
    return estimator
def prepare_create_or_update_pages_in_db(sim_id: int, estimator_dict):
    """Stage the pages of one estimator for insertion or update without committing.

    The parent estimator is looked up by ``estimator_dict["name"]`` (file name); it is
    expected to exist already (the caller commits estimators first). For each entry in
    ``estimator_dict["pages"]`` an existing page with the same page number gets its
    ``data`` overwritten; otherwise a new ``PageModel`` is built and added to the session.
    """
    parent = fetch_estimator_by_sim_id_and_file_name(sim_id=sim_id, file_name=estimator_dict["name"])
    for page_dict in estimator_dict["pages"]:
        page_meta = page_dict["metadata"]
        number = int(page_meta["page_number"])
        existing = fetch_page_by_est_id_and_page_number(est_id=parent.id, page_number=number)
        if existing:
            # already tracked by the session - only refresh its data
            existing.data = page_dict
            continue
        fresh = PageModel(page_number=number,
                          estimator_id=parent.id,
                          page_dimension=int(page_dict['dimensions']),
                          page_name=str(page_meta["name"]))
        fresh.data = page_dict
        # new page: register it with the session, commit is done by the caller
        add_object_to_db(fresh, make_commit=False)
def parse_page_numbers(param: str) -> List[int]:
    """Turn a page-range string such as ``'1-3,5'`` into a sorted, de-duplicated list.

    Comma-separated tokens are either a single integer or an inclusive ``start-end``
    range. A range with ``start > end`` contributes nothing.

    :raises ValueError: when a token is not an integer or a well-formed ``a-b`` range
    """
    collected: set = set()
    for token in param.split(','):
        if '-' not in token:
            collected.add(int(token))
            continue
        low, high = (int(bound) for bound in token.split('-'))
        collected |= set(range(low, high + 1))
    return sorted(collected)
class ResultsResource(Resource):
    """Class responsible for managing results"""

    @staticmethod
    def post():
        """
        Method for saving results
        Used by the jobs at the end of simulation
        Structure required by this method to work properly:
        {
            "simulation_id": <int>,
            "update_key": <string>,
            "estimators": <list of estimator dicts, each with "name", "metadata" and "pages">
        }
        """
        payload_dict: dict = request.get_json(force=True)
        # a non-object JSON body would otherwise fail on .keys() with an AttributeError (HTTP 500)
        if not isinstance(payload_dict, dict) or {"simulation_id", "update_key", "estimators"} != set(
                payload_dict.keys()):
            return yaptide_response(message="Incomplete JSON data", code=400)

        sim_id = payload_dict["simulation_id"]
        simulation = fetch_simulation_by_sim_id(sim_id=sim_id)

        if not simulation:
            return yaptide_response(message="Simulation does not exist", code=400)

        decoded_token = decode_auth_token(payload_dict["update_key"], payload_key_to_return="simulation_id")
        if decoded_token != sim_id:
            return yaptide_response(message="Invalid update key", code=400)

        if simulation.input_type == InputType.EDITOR.value:
            outputs = simulation.inputs[0].data["input_json"]["scoringManager"]["outputs"]
            sorted_estimator_names = sorted([output["name"] for output in outputs])
            for output in outputs:
                name = output["name"]
                # assumes payload estimators are sorted alphabetically by name,
                # so indexes in sorted_estimator_names match indexes in the payload list
                estimator_dict_index = sorted_estimator_names.index(name)
                estimator_dict = payload_dict["estimators"][estimator_dict_index]
                prepare_create_or_update_estimator_in_db(sim_id=simulation.id, name=name, estimator_dict=estimator_dict)
        elif simulation.input_type == InputType.FILES.value:
            for estimator_dict in payload_dict["estimators"]:
                prepare_create_or_update_estimator_in_db(sim_id=simulation.id,
                                                         name=estimator_dict["name"],
                                                         estimator_dict=estimator_dict)

        # commit estimators first: page preparation looks the estimators up by file name
        make_commit_to_db()

        for estimator_dict in payload_dict["estimators"]:
            prepare_create_or_update_pages_in_db(sim_id=simulation.id, estimator_dict=estimator_dict)

        # commit pages
        make_commit_to_db()

        logging.debug("Marking simulation as completed")
        update_dict = {"job_state": EntityState.COMPLETED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}
        update_simulation_state(simulation=simulation, update_dict=update_dict)

        # NOTE(review): no task-level update follows this log message - confirm whether one is intended
        logging.debug("Marking simulation tasks as completed")

        return yaptide_response(message="Results saved", code=202)

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # required=True: a missing job_id must yield 400 instead of a KeyError (HTTP 500)
        job_id = fields.String(required=True)
        estimator_name = fields.String(load_default=None)
        page_number = fields.Integer(load_default=None)
        page_numbers = fields.String(load_default=None)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning job status and results.
        If `estimator_name` parameter is provided,
        the response will include results only for that specific estimator,
        otherwise it will return all estimators for the given job.
        If `page_number` or `page_numbers` are provided, the response will include only specific pages;
        `page_number` takes precedence when both are given.
        Responds 404 when the estimator or the requested page does not exist and 400 when
        `page_numbers` is not a valid range list (e.g. '1-3,5').
        """
        schema = ResultsResource.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        job_id = param_dict['job_id']
        estimator_name = param_dict.get('estimator_name')
        page_number = param_dict.get('page_number')
        page_numbers = param_dict.get('page_numbers')

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation_id = fetch_simulation_id_by_job_id(job_id=job_id)
        if not simulation_id:
            return yaptide_response(message="Simulation does not exist", code=404)

        # without an estimator name, return all estimators
        if estimator_name is None:
            return get_all_estimators(sim_id=simulation_id)

        if page_number is None and page_numbers is None:
            return get_single_estimator(sim_id=simulation_id, estimator_name=estimator_name)

        estimator_id = fetch_estimator_id_by_sim_id_and_est_name(sim_id=simulation_id, est_name=estimator_name)
        # assumes the helper returns None when no estimator matches - verify against db_methods
        if estimator_id is None:
            return yaptide_response(message="Estimator not found", code=404)

        if page_number is not None:
            page = fetch_page_by_est_id_and_page_number(est_id=estimator_id, page_number=page_number)
            if not page:
                return yaptide_response(message="Page not found", code=404)
            result = {"page": page.data}
            return yaptide_response(message="Page retrieved successfully", code=200, content=result)

        # only page_numbers can be set at this point
        try:
            parsed_page_numbers = parse_page_numbers(page_numbers)
        except ValueError:
            return yaptide_response(message="Wrong parameters",
                                    code=400,
                                    content={"page_numbers": [f"Invalid page range: {page_numbers}"]})
        pages = fetch_pages_by_est_id_and_page_numbers(est_id=estimator_id, page_numbers=parsed_page_numbers)
        result = {"pages": [page.data for page in pages]}
        return yaptide_response(message="Pages retrieved successfully", code=200, content=result)
class InputsResource(Resource):
    """Class responsible for returning simulation input"""

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # required=True: a missing job_id must yield 400 instead of a KeyError (HTTP 500)
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning simulation input

        Validates ``job_id``, checks that ``user`` owns the job and returns
        ``{"input": <input data>}``, or 404 when no input is stored for the simulation.
        """
        schema = InputsResource.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)
        job_id = param_dict['job_id']

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_simulation_by_job_id(job_id=job_id)

        input_model = fetch_input_by_sim_id(sim_id=simulation.id)
        if not input_model:
            return yaptide_response(message="Input of simulation is unavailable", code=404)

        return yaptide_response(message="Input of simulation", code=200, content={"input": input_model.data})
class LogfilesResource(Resource):
    """Class responsible for managing logfiles"""

    @staticmethod
    def post():
        """
        Method for saving logfiles
        Used by the jobs when the simulation fails
        Structure required by this method to work properly:
        {
            "simulation_id": <int>,
            "update_key": <string>,
            "logfiles": <dict>
        }
        """
        payload_dict: dict = request.get_json(force=True)
        # a non-object JSON body would otherwise fail on .keys() with an AttributeError (HTTP 500)
        if not isinstance(payload_dict, dict) or {"simulation_id", "update_key", "logfiles"} != set(
                payload_dict.keys()):
            return yaptide_response(message="Incomplete JSON data", code=400)

        sim_id = payload_dict["simulation_id"]
        simulation = fetch_simulation_by_sim_id(sim_id=sim_id)

        if not simulation:
            return yaptide_response(message="Simulation does not exist", code=400)

        decoded_token = decode_auth_token(payload_dict["update_key"], payload_key_to_return="simulation_id")
        if decoded_token != sim_id:
            return yaptide_response(message="Invalid update key", code=400)

        logfiles = LogfilesModel(simulation_id=simulation.id)
        logfiles.data = payload_dict["logfiles"]
        add_object_to_db(logfiles)

        return yaptide_response(message="Log files saved", code=202)

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # required=True: a missing job_id must yield 400 instead of a KeyError (HTTP 500)
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning logfiles of a job

        Validates ``job_id``, checks that ``user`` owns the job and returns
        ``{"logfiles": <data>}``, or 404 when no logfiles are stored for the simulation.
        """
        # validate with this resource's own schema (ResultsResource's schema was used here before)
        schema = LogfilesResource.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_simulation_by_job_id(job_id=job_id)

        logfile = fetch_logfiles_by_sim_id(sim_id=simulation.id)
        if not logfile:
            return yaptide_response(message="Logfiles are unavailable", code=404)

        logging.debug("Returning logfiles from database")

        return yaptide_response(message="Logfiles", code=200, content={"logfiles": logfile.data})