Coverage for yaptide/routes/celery_routes.py: 42%
156 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-22 07:31 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-22 07:31 +0000
1import logging
2from collections import Counter
3from datetime import datetime
5from flask import request
6from flask_restful import Resource
7from marshmallow import Schema, fields
8from uuid import uuid4
10from yaptide.celery.simulation_worker import celery_app
11from yaptide.celery.tasks import convert_input_files
12from yaptide.celery.utils.manage_tasks import (get_job_results, run_job)
13from yaptide.persistence.db_methods import (add_object_to_db, fetch_celery_simulation_by_job_id,
14 fetch_celery_tasks_by_sim_id, fetch_estimators_by_sim_id,
15 fetch_pages_by_estimator_id, make_commit_to_db, update_simulation_state,
16 update_task_state)
17from yaptide.persistence.models import (CelerySimulationModel, CeleryTaskModel, EstimatorModel, InputModel, PageModel,
18 UserModel)
19from yaptide.routes.utils.decorators import requires_auth
20from yaptide.routes.utils.response_templates import (error_validation_response, yaptide_response)
21from yaptide.routes.utils.utils import check_if_job_is_owned_and_exist, determine_input_type, make_input_dict
22from yaptide.routes.utils.tokens import encode_simulation_auth_token
23from yaptide.utils.enums import EntityState, PlatformType
24from yaptide.utils.helper_tasks import terminate_unfinished_tasks
class JobsDirect(Resource):
    """Resource handling simulations run directly with celery.

    POST submits a new simulation job, GET reports the state of a job and of its tasks,
    DELETE cancels a job that has not finished yet.
    """

    @staticmethod
    @requires_auth()
    def post(user: UserModel):
        """Submit a simulation job to celery.

        The JSON body must contain the keys ``sim_type``, ``ntasks`` and ``input_type``;
        ``title`` is optional and defaults to an empty string. ``ntasks`` must be a
        positive integer.

        Returns a 202 response carrying the ``job_id`` of the created simulation,
        or a 400 response when the body is missing, incomplete or invalid.
        """
        payload_dict: dict = request.get_json(force=True)
        if not payload_dict:
            return yaptide_response(message="No JSON in body", code=400)

        required_keys = {"sim_type", "ntasks", "input_type"}
        missing_keys = required_keys.difference(payload_dict.keys())
        if missing_keys:
            return yaptide_response(message=f"Missing keys in JSON payload: {missing_keys}", code=400)

        # Validate ntasks before anything is written to the database: it is passed to range()
        # below, so a string or float would otherwise raise TypeError after the simulation
        # row has already been inserted. bool is rejected explicitly because it subclasses int.
        ntasks = payload_dict["ntasks"]
        if isinstance(ntasks, bool) or not isinstance(ntasks, int) or ntasks < 1:
            return yaptide_response(message="ntasks must be a positive integer", code=400)

        input_type = determine_input_type(payload_dict)
        if input_type is None:
            return error_validation_response()

        # create a new simulation in the database, not waiting for the job to finish
        job_id = datetime.now().strftime('%Y%m%d-%H%M%S-') + str(uuid4()) + PlatformType.DIRECT.value
        simulation = CelerySimulationModel(user_id=user.id,
                                           job_id=job_id,
                                           sim_type=payload_dict["sim_type"],
                                           input_type=input_type,
                                           title=payload_dict.get("title", ''))
        add_object_to_db(simulation)
        update_key = encode_simulation_auth_token(simulation.id)
        logging.info("Simulation %d created and inserted into DB", simulation.id)
        logging.debug("Update key set to %s", update_key)

        input_dict = make_input_dict(payload_dict=payload_dict, input_type=input_type)

        # create tasks in the database in the default PENDING state;
        # all of them are added first and committed together
        celery_ids = [str(uuid4()) for _ in range(ntasks)]
        for task_index, celery_id in enumerate(celery_ids):
            task = CeleryTaskModel(simulation_id=simulation.id, task_id=task_index, celery_id=celery_id)
            add_object_to_db(task, make_commit=False)
        make_commit_to_db()

        # submit the asynchronous job to celery
        simulation.merge_id = run_job(input_dict["input_files"], update_key, simulation.id, ntasks, celery_ids,
                                      payload_dict["sim_type"])

        input_model = InputModel(simulation_id=simulation.id)
        input_model.data = input_dict
        add_object_to_db(input_model)
        if simulation.update_state({"job_state": EntityState.PENDING.value}):
            make_commit_to_db()

        return yaptide_response(message="Task started", code=202, content={'job_id': simulation.job_id})

    class APIParametersSchema(Schema):
        """Class specifies API parameters for GET and DELETE request"""

        # Required: both handlers index the loaded parameters with 'job_id', so a request
        # without it must fail schema validation instead of raising KeyError later.
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Return the state of a job together with the status of its tasks.

        Expects the ``job_id`` query parameter. Responds with 400 when the parameters
        do not validate, with the code supplied by the ownership check when that check
        fails, and with 200 and the job information otherwise.
        """
        # validate request parameters and handle errors
        schema = JobsDirect.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        # get job_id from request parameters and check if user owns this job
        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        # find appropriate simulation in the database
        simulation = fetch_celery_simulation_by_job_id(job_id=job_id)
        tasks = fetch_celery_tasks_by_sim_id(sim_id=simulation.id)
        job_tasks_status = [task.get_status_dict() for task in tasks]

        # COMPLETED and FAILED jobs are reported as stored, without re-deriving the state
        if simulation.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value):
            return yaptide_response(message=f"Job state: {simulation.job_state}",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                        "job_tasks_status": job_tasks_status,
                                    })

        # derive the job state from the task states; when no rule matches,
        # the state stored on the simulation is kept
        job_info = {"job_state": simulation.job_state}
        status_counter = Counter(task["task_state"] for task in job_tasks_status)
        task_count = len(job_tasks_status)
        if status_counter[EntityState.PENDING.value] == task_count:
            job_info["job_state"] = EntityState.PENDING.value
        elif status_counter[EntityState.FAILED.value] == task_count:
            job_info["job_state"] = EntityState.FAILED.value
        elif status_counter[EntityState.RUNNING.value] > 0:
            job_info["job_state"] = EntityState.RUNNING.value

        # update the simulation with the job state determined above
        update_simulation_state(simulation=simulation, update_dict=job_info)

        job_info["job_tasks_status"] = job_tasks_status

        return yaptide_response(message=f"Job state: {job_info['job_state']}", code=200, content=job_info)

    @staticmethod
    @requires_auth()
    def delete(user: UserModel):
        """Cancel a simulation and return the status of this action.

        Expects the ``job_id`` query parameter. A job that is already COMPLETED, FAILED,
        CANCELED or UNKNOWN is not touched; the response then only reports its state.
        """
        schema = JobsDirect.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return error_validation_response(content=errors)
        params_dict: dict = schema.load(request.args)

        job_id = params_dict['job_id']

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_celery_simulation_by_job_id(job_id=job_id)

        if simulation.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value,
                                    EntityState.UNKNOWN.value):
            return yaptide_response(message=f"Cannot cancel job which is in {simulation.job_state} state",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                    })

        tasks = fetch_celery_tasks_by_sim_id(sim_id=simulation.id)
        # tasks in UNKNOWN state are revoked as well, although only PENDING and RUNNING
        # tasks are marked as CANCELED below
        celery_ids = [
            task.celery_id for task in tasks
            if task.task_state in [EntityState.PENDING.value, EntityState.RUNNING.value, EntityState.UNKNOWN.value]
        ]

        # The merge_id is canceled first because merge task starts after run simulation tasks are finished/canceled.
        # We don't want it to run accidentally.
        celery_app.control.revoke(simulation.merge_id, terminate=True, signal="SIGINT")
        celery_app.control.revoke(celery_ids, terminate=True, signal="SIGINT")
        update_simulation_state(simulation=simulation, update_dict={"job_state": EntityState.CANCELED.value})
        for task in tasks:
            if task.task_state in [EntityState.PENDING.value, EntityState.RUNNING.value]:
                update_task_state(task=task, update_dict={"task_state": EntityState.CANCELED.value})

        terminate_unfinished_tasks.delay(simulation_id=simulation.id)
        return yaptide_response(message="Cancelled sucessfully", code=200)
class ResultsDirect(Resource):
    """Class responsible for returning simulation results"""

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # Required: the handler indexes the loaded parameters with 'job_id', so a request
        # without it must fail schema validation instead of raising KeyError later.
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Return the results (estimators with their pages) of a job.

        Expects the ``job_id`` query parameter. Estimators already stored in the database
        are returned from there. Otherwise the results are requested with
        ``get_job_results``, stored in the database and returned; when they contain no
        ``estimators`` key a 404 response is sent.
        """
        schema = ResultsDirect.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_celery_simulation_by_job_id(job_id=job_id)

        estimators: list[EstimatorModel] = fetch_estimators_by_sim_id(sim_id=simulation.id)
        if estimators:
            logging.debug("Returning results from database")
            result_estimators = []
            for estimator in estimators:
                pages: list[PageModel] = fetch_pages_by_estimator_id(est_id=estimator.id)
                result_estimators.append({
                    "metadata": estimator.data,
                    "name": estimator.name,
                    "pages": [page.data for page in pages],
                })
            return yaptide_response(message=f"Results for job: {job_id}",
                                    code=200,
                                    content={"estimators": result_estimators})

        result: dict = get_job_results(job_id=job_id)
        if "estimators" not in result:
            logging.debug("Results for job %s are unavailable", job_id)
            return yaptide_response(message="Results are unavailable", code=404, content=result)

        # store the fetched results so that later requests take the database branch above
        for estimator_dict in result["estimators"]:
            estimator = EstimatorModel(name=estimator_dict["name"], simulation_id=simulation.id)
            estimator.data = estimator_dict["metadata"]
            # the estimator is added on its own first because its id is used for the pages below
            add_object_to_db(estimator)
            for page_dict in estimator_dict["pages"]:
                page = PageModel(estimator_id=estimator.id, page_number=int(page_dict["metadata"]["page_number"]))
                page.data = page_dict
                add_object_to_db(page, make_commit=False)
            # pages of one estimator are committed together
            make_commit_to_db()

        logging.debug("Returning results from Celery")
        return yaptide_response(message=f"Results for job: {job_id}, results from Celery", code=200, content=result)
class ConvertResource(Resource):
    """Resource returning input files converted from the JSON sent by the frontend."""

    @staticmethod
    @requires_auth()
    def post(_: UserModel):
        """Convert the JSON body into input files and return them in the response.

        The conversion is dispatched as the ``convert_input_files`` celery task and the
        handler waits for its result. An empty or missing JSON body yields a 400 response.
        """
        body: dict = request.get_json(force=True)
        if body:
            # Rework in later PRs to match pattern from jobs endpoint
            converted: dict = convert_input_files.delay(payload_dict=body).wait()
            return yaptide_response(message="Converted Input Files", code=200, content=converted)

        return yaptide_response(message="No JSON in body", code=400)