Coverage for yaptide/routes/celery_routes.py: 40%
151 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-07-01 12:55 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-07-01 12:55 +0000
1import logging
2import uuid
3from collections import Counter
4from datetime import datetime
6from flask import request
7from flask_restful import Resource
8from marshmallow import Schema, fields
10from yaptide.celery.tasks import convert_input_files
11from yaptide.celery.utils.manage_tasks import (cancel_job, get_job_results,
12 run_job)
13from yaptide.persistence.db_methods import (add_object_to_db,
14 fetch_celery_simulation_by_job_id,
15 fetch_celery_tasks_by_sim_id,
16 fetch_estimators_by_sim_id,
17 fetch_pages_by_estimator_id,
18 make_commit_to_db,
19 update_simulation_state,
20 update_task_state)
21from yaptide.persistence.models import (CelerySimulationModel, CeleryTaskModel,
22 EstimatorModel, InputModel, PageModel,
23 UserModel)
24from yaptide.routes.utils.decorators import requires_auth
25from yaptide.routes.utils.response_templates import (error_internal_response,
26 error_validation_response,
27 yaptide_response)
28from yaptide.routes.utils.utils import check_if_job_is_owned_and_exist, determine_input_type, make_input_dict
29from yaptide.utils.enums import EntityState, PlatformType
class JobsDirect(Resource):
    """Class responsible for simulations run directly with celery"""

    @staticmethod
    @requires_auth()
    def post(user: UserModel):
        """Submit simulation job to celery

        The JSON body has to contain at least the `sim_type`, `ntasks` and `input_type` keys,
        and `ntasks` has to be a positive integer. All payload checks done here happen before
        the simulation is stored in the database. On success a 202 response carrying
        the generated `job_id` is returned.
        """
        payload_dict: dict = request.get_json(force=True)
        if not payload_dict:
            return yaptide_response(message="No JSON in body", code=400)

        required_keys = {"sim_type", "ntasks", "input_type"}
        diff = required_keys.difference(payload_dict.keys())
        if diff:
            return yaptide_response(message=f"Missing keys in JSON payload: {diff}", code=400)

        input_type = determine_input_type(payload_dict)
        if input_type is None:
            return error_validation_response()

        # ntasks is later used both to submit the job and as the argument of range(),
        # so reject anything which is not a positive integer before touching the database
        # (bool is a subclass of int, hence the explicit exclusion)
        ntasks = payload_dict["ntasks"]
        if isinstance(ntasks, bool) or not isinstance(ntasks, int) or ntasks < 1:
            return yaptide_response(message="ntasks must be a positive integer", code=400)

        # create a new simulation in the database, not waiting for the job to finish
        job_id = datetime.now().strftime('%Y%m%d-%H%M%S-') + str(uuid.uuid4()) + PlatformType.DIRECT.value
        simulation = CelerySimulationModel(user_id=user.id,
                                           job_id=job_id,
                                           sim_type=payload_dict["sim_type"],
                                           input_type=input_type,
                                           title=payload_dict.get("title", ''))
        update_key = str(uuid.uuid4())
        simulation.set_update_key(update_key)
        add_object_to_db(simulation)
        logging.info("Simulation %d created and inserted into DB", simulation.id)
        logging.debug("Update key set to %s", update_key)

        input_dict = make_input_dict(payload_dict=payload_dict, input_type=input_type)

        # submit the asynchronous job to celery
        simulation.merge_id = run_job(input_dict["input_files"], update_key, simulation.id, ntasks,
                                      payload_dict["sim_type"])

        # create tasks in the database in the default PENDING state
        for i in range(ntasks):
            task = CeleryTaskModel(simulation_id=simulation.id, task_id=i)
            add_object_to_db(task, make_commit=False)

        input_model = InputModel(simulation_id=simulation.id)
        input_model.data = input_dict
        add_object_to_db(input_model)
        if simulation.update_state({"job_state": EntityState.PENDING.value}):
            make_commit_to_db()

        return yaptide_response(message="Task started", code=202, content={'job_id': simulation.job_id})

    class APIParametersSchema(Schema):
        """Class specifies API parameters for GET and DELETE request"""

        # required, so that a request without job_id fails schema validation
        # instead of raising KeyError when the handlers read param_dict['job_id']
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning job status and statuses of its tasks

        For a job already marked as COMPLETED or FAILED the stored state is returned as is.
        Otherwise the job state is re-derived from the states of its tasks before responding.
        """
        # validate request parameters and handle errors
        schema = JobsDirect.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        # get job_id from request parameters and check if user owns this job
        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        # find appropriate simulation in the database
        simulation = fetch_celery_simulation_by_job_id(job_id=job_id)

        tasks = fetch_celery_tasks_by_sim_id(sim_id=simulation.id)

        job_tasks_status = [task.get_status_dict() for task in tasks]

        if simulation.job_state in (EntityState.COMPLETED.value,
                                    EntityState.FAILED.value):
            return yaptide_response(message=f"Job state: {simulation.job_state}",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                        "job_tasks_status": job_tasks_status,
                                    })

        job_info = {
            "job_state": simulation.job_state
        }
        # derive the job state from task states: all PENDING -> PENDING, all FAILED -> FAILED,
        # at least one RUNNING -> RUNNING, otherwise keep the state stored in the simulation
        status_counter = Counter(task["task_state"] for task in job_tasks_status)
        if status_counter[EntityState.PENDING.value] == len(job_tasks_status):
            job_info["job_state"] = EntityState.PENDING.value
        elif status_counter[EntityState.FAILED.value] == len(job_tasks_status):
            job_info["job_state"] = EntityState.FAILED.value
        elif status_counter[EntityState.RUNNING.value] > 0:
            job_info["job_state"] = EntityState.RUNNING.value

        # pass the (possibly re-derived) job state on to the database layer
        update_simulation_state(simulation=simulation, update_dict=job_info)

        job_info["job_tasks_status"] = job_tasks_status

        return yaptide_response(message=f"Job state: {job_info['job_state']}", code=200, content=job_info)

    @staticmethod
    @requires_auth()
    def delete(user: UserModel):
        """Method canceling simulation and returning status of this action

        Jobs in COMPLETED, FAILED, CANCELED or UNKNOWN state are not canceled; their current
        state is reported back with code 200 instead.
        """
        schema = JobsDirect.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return error_validation_response(content=errors)
        params_dict: dict = schema.load(request.args)

        job_id = params_dict['job_id']

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(
            job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_celery_simulation_by_job_id(job_id=job_id)

        if simulation.job_state in (EntityState.COMPLETED.value,
                                    EntityState.FAILED.value,
                                    EntityState.CANCELED.value,
                                    EntityState.UNKNOWN.value):
            return yaptide_response(message=f"Cannot cancel job which is in {simulation.job_state} state",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                    })

        tasks = fetch_celery_tasks_by_sim_id(sim_id=simulation.id)

        celery_ids = [task.celery_id for task in tasks]

        result: dict = cancel_job(merge_id=simulation.merge_id, celery_ids=celery_ids)

        # a result containing the "merge" key is treated as a successful cancellation;
        # result["tasks"] is indexed in the same order as the fetched tasks
        if "merge" in result:
            update_simulation_state(simulation=simulation, update_dict=result["merge"])
            for i, task in enumerate(tasks):
                update_task_state(task=task, update_dict=result["tasks"][i])

            return yaptide_response(message="", code=200, content=result)

        return error_internal_response()
class ResultsDirect(Resource):
    """Class responsible for returning simulation results"""

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # required, so that a request without job_id fails schema validation
        # instead of raising KeyError when the handler reads param_dict['job_id']
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning job results

        Estimators already stored in the database are returned from there. Otherwise the results
        are requested with get_job_results, saved as estimators with their pages and returned.
        When the fetched result has no "estimators" key, a 404 response is returned.
        """
        schema = ResultsDirect.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_celery_simulation_by_job_id(job_id=job_id)

        estimators: list[EstimatorModel] = fetch_estimators_by_sim_id(sim_id=simulation.id)
        if estimators:
            logging.debug("Returning results from database")
            result_estimators = []
            for estimator in estimators:
                pages: list[PageModel] = fetch_pages_by_estimator_id(est_id=estimator.id)
                estimator_dict = {
                    "metadata": estimator.data,
                    "name": estimator.name,
                    "pages": [page.data for page in pages]
                }
                result_estimators.append(estimator_dict)
            return yaptide_response(message=f"Results for job: {job_id}",
                                    code=200, content={"estimators": result_estimators})

        result: dict = get_job_results(job_id=job_id)
        if "estimators" not in result:
            logging.debug("Results for job %s are unavailable", job_id)
            return yaptide_response(message="Results are unavailable", code=404, content=result)

        # store fetched results, so the next request is served from the database
        for estimator_dict in result["estimators"]:
            estimator = EstimatorModel(name=estimator_dict["name"], simulation_id=simulation.id)
            estimator.data = estimator_dict["metadata"]
            # estimator is added first, because its id is used by the pages below
            add_object_to_db(estimator)
            for page_dict in estimator_dict["pages"]:
                page = PageModel(estimator_id=estimator.id,
                                 page_number=int(page_dict["metadata"]["page_number"]))
                page.data = page_dict
                add_object_to_db(page, False)
            # pages were added without committing, commit them together
            make_commit_to_db()

        logging.debug("Returning results from Celery")
        return yaptide_response(message=f"Results for job: {job_id}, results from Celery", code=200, content=result)
class ConvertResource(Resource):
    """Resource returning input_model files converted from the JSON sent by the front-end"""

    @staticmethod
    @requires_auth()
    def post(_: UserModel):
        """Convert the JSON body into input_model files and return them

        The conversion is delegated to the convert_input_files task and this method
        waits for its result. An empty body is answered with code 400.
        """
        body: dict = request.get_json(force=True)
        if body:
            # Rework in later PRs to match pattern from jobs endpoint
            converted: dict = convert_input_files.delay(payload_dict=body).wait()
            return yaptide_response(message="Converted Input Files", code=200, content=converted)

        return yaptide_response(message="No JSON in body", code=400)