Coverage for yaptide/routes/celery_routes.py: 41%
145 statements
coverage.py v7.6.10, created at 2025-03-31 19:18 +0000
import logging
from collections import Counter
from datetime import datetime

from flask import request
from flask_restful import Resource
from marshmallow import Schema, fields
from uuid import uuid4

from yaptide.celery.simulation_worker import celery_app
from yaptide.celery.utils.manage_tasks import (get_job_results, run_job)
from yaptide.persistence.db_methods import (add_object_to_db, fetch_celery_simulation_by_job_id,
                                            fetch_celery_tasks_by_sim_id, fetch_estimators_by_sim_id,
                                            fetch_pages_by_estimator_id, make_commit_to_db, update_simulation_state,
                                            update_task_state)
from yaptide.persistence.models import (CelerySimulationModel, CeleryTaskModel, EstimatorModel, InputModel, PageModel,
                                        UserModel)
from yaptide.routes.utils.decorators import requires_auth
from yaptide.routes.utils.response_templates import (error_validation_response, yaptide_response)
from yaptide.routes.utils.utils import check_if_job_is_owned_and_exist, determine_input_type, make_input_dict
from yaptide.routes.utils.tokens import encode_simulation_auth_token
from yaptide.utils.enums import EntityState, PlatformType
from yaptide.utils.helper_tasks import terminate_unfinished_tasks

class JobsDirect(Resource):
    """Class responsible for simulations run directly with celery"""

    @staticmethod
    @requires_auth()
    def post(user: UserModel):
        """Submit simulation job to celery"""
        payload_dict: dict = request.get_json(force=True)
        if not payload_dict:
            return yaptide_response(message="No JSON in body", code=400)

        required_keys = {"sim_type", "ntasks", "input_type"}

        if required_keys != required_keys.intersection(set(payload_dict.keys())):
            diff = required_keys.difference(set(payload_dict.keys()))
            return yaptide_response(message=f"Missing keys in JSON payload: {diff}", code=400)

        input_type = determine_input_type(payload_dict)

        if input_type is None:
            return error_validation_response()

        # create a new simulation in the database, not waiting for the job to finish
        job_id = datetime.now().strftime('%Y%m%d-%H%M%S-') + str(uuid4()) + PlatformType.DIRECT.value
        simulation = CelerySimulationModel(user_id=user.id,
                                           job_id=job_id,
                                           sim_type=payload_dict["sim_type"],
                                           input_type=input_type,
                                           title=payload_dict.get("title", ''))
        add_object_to_db(simulation)
        update_key = encode_simulation_auth_token(simulation.id)
        logging.info("Simulation %d created and inserted into DB", simulation.id)
        logging.debug("Update key set to %s", update_key)

        input_dict = make_input_dict(payload_dict=payload_dict, input_type=input_type)
        # create tasks in the database in the default PENDING state
        celery_ids = [str(uuid4()) for _ in range(payload_dict["ntasks"])]
        for i in range(payload_dict["ntasks"]):
            task = CeleryTaskModel(simulation_id=simulation.id, task_id=i, celery_id=celery_ids[i])
            add_object_to_db(task, make_commit=False)
        make_commit_to_db()

        # submit the asynchronous job to celery
        simulation.merge_id = run_job(input_dict["input_files"], update_key, simulation.id, payload_dict["ntasks"],
                                      celery_ids, payload_dict["sim_type"])

        input_model = InputModel(simulation_id=simulation.id)
        input_model.data = input_dict
        add_object_to_db(input_model)
        if simulation.update_state({"job_state": EntityState.PENDING.value}):
            make_commit_to_db()

        return yaptide_response(message="Task started", code=202, content={'job_id': simulation.job_id})
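
    # Illustrative sketch of a payload accepted by post() above: "sim_type", "ntasks" and
    # "input_type" are required, "title" is optional. The remaining keys and values are
    # assumptions, not the exact structure consumed by determine_input_type() and
    # make_input_dict(), which are defined elsewhere:
    #
    #   {
    #       "sim_type": "shieldhit",
    #       "ntasks": 2,
    #       "input_type": "files",
    #       "title": "example run",
    #       "input_files": {"beam.dat": "...", "geo.dat": "...", "detect.dat": "...", "mat.dat": "..."}
    #   }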

    class APIParametersSchema(Schema):
        """Class specifies API parameters for GET and DELETE requests"""

        job_id = fields.String()

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning job status and results"""
        # validate request parameters and handle errors
        schema = JobsDirect.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        # get job_id from request parameters and check if user owns this job
        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        # find appropriate simulation in the database
        simulation = fetch_celery_simulation_by_job_id(job_id=job_id)

        tasks = fetch_celery_tasks_by_sim_id(sim_id=simulation.id)

        job_tasks_status = [task.get_status_dict() for task in tasks]

        if simulation.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value):
            return yaptide_response(message=f"Job state: {simulation.job_state}",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                        "job_tasks_status": job_tasks_status,
                                    })

        job_info = {"job_state": simulation.job_state}
        status_counter = Counter([task["task_state"] for task in job_tasks_status])
        if status_counter[EntityState.PENDING.value] == len(job_tasks_status):
            job_info["job_state"] = EntityState.PENDING.value
        elif status_counter[EntityState.FAILED.value] == len(job_tasks_status):
            job_info["job_state"] = EntityState.FAILED.value
        elif status_counter[EntityState.RUNNING.value] > 0:
            job_info["job_state"] = EntityState.RUNNING.value

        # persist the aggregated job state in the database
        update_simulation_state(simulation=simulation, update_dict=job_info)

        job_info["job_tasks_status"] = job_tasks_status

        return yaptide_response(message=f"Job state: {job_info['job_state']}", code=200, content=job_info)
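
    # Illustrative sketch of the content assembled by get() above. The top-level field names
    # come from this method; the per-task entries are produced by CeleryTaskModel.get_status_dict()
    # and the state strings follow EntityState, so both are shown here only as assumptions:
    #
    #   {
    #       "job_state": "RUNNING",
    #       "job_tasks_status": [
    #           {"task_state": "RUNNING", ...},
    #           {"task_state": "PENDING", ...}
    #       ]
    #   }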

    @staticmethod
    @requires_auth()
    def delete(user: UserModel):
        """Method canceling simulation and returning status of this action"""
        schema = JobsDirect.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return error_validation_response(content=errors)
        params_dict: dict = schema.load(request.args)

        job_id = params_dict['job_id']

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_celery_simulation_by_job_id(job_id=job_id)

        if simulation.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value,
                                    EntityState.UNKNOWN.value):
            return yaptide_response(message=f"Cannot cancel job which is in {simulation.job_state} state",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                    })

        tasks = fetch_celery_tasks_by_sim_id(sim_id=simulation.id)
        celery_ids = [
            task.celery_id for task in tasks
            if task.task_state in [EntityState.PENDING.value, EntityState.RUNNING.value, EntityState.UNKNOWN.value]
        ]

        # Revoke the merge task first: it is scheduled to start once the simulation tasks finish
        # or are canceled, and we don't want it to run accidentally.
        celery_app.control.revoke(simulation.merge_id, terminate=True, signal="SIGINT")
        celery_app.control.revoke(celery_ids, terminate=True, signal="SIGINT")
        update_simulation_state(simulation=simulation, update_dict={"job_state": EntityState.CANCELED.value})
        for task in tasks:
            if task.task_state in [EntityState.PENDING.value, EntityState.RUNNING.value]:
                update_task_state(task=task, update_dict={"task_state": EntityState.CANCELED.value})

        terminate_unfinished_tasks.delay(simulation_id=simulation.id)
        return yaptide_response(message="Cancelled successfully", code=200)
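

# A minimal client-side sketch of the JobsDirect lifecycle (submit, poll, cancel). This is an
# illustrative example, not part of the API module: the base URL and the "/jobs/direct" route
# path are assumptions (route registration lives elsewhere in yaptide), `session` is assumed to
# be a requests.Session that already satisfies requires_auth(), and the exact JSON envelope
# produced by yaptide_response is also an assumption.
def _example_jobs_direct_client(session, base_url: str, payload: dict) -> None:
    """Submit a direct-run job, check its status once, then cancel it (illustrative only)."""
    # POST submits the job; a 202 response carries the generated job_id
    submitted = session.post(f"{base_url}/jobs/direct", json=payload)
    job_id = submitted.json()["job_id"]

    # GET reports the aggregated job state together with per-task statuses
    status = session.get(f"{base_url}/jobs/direct", params={"job_id": job_id})
    print(status.json()["job_state"])

    # DELETE revokes the pending/running celery tasks and marks the simulation as canceled
    session.delete(f"{base_url}/jobs/direct", params={"job_id": job_id})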


class ResultsDirect(Resource):
    """Class responsible for returning simulation results"""

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        job_id = fields.String()

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning job status and results"""
        schema = ResultsDirect.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_celery_simulation_by_job_id(job_id=job_id)

        estimators: list[EstimatorModel] = fetch_estimators_by_sim_id(sim_id=simulation.id)
        if len(estimators) > 0:
            logging.debug("Returning results from database")
            result_estimators = []
            for estimator in estimators:
                pages: list[PageModel] = fetch_pages_by_estimator_id(est_id=estimator.id)
                estimator_dict = {
                    "metadata": estimator.data,
                    "name": estimator.name,
                    "pages": [page.data for page in pages]
                }
                result_estimators.append(estimator_dict)
            return yaptide_response(message=f"Results for job: {job_id}",
                                    code=200,
                                    content={"estimators": result_estimators})

        result: dict = get_job_results(job_id=job_id)
        if "estimators" not in result:
            logging.debug("Results for job %s are unavailable", job_id)
            return yaptide_response(message="Results are unavailable", code=404, content=result)

        for estimator_dict in result["estimators"]:
            estimator = EstimatorModel(name=estimator_dict["name"], simulation_id=simulation.id)
            estimator.data = estimator_dict["metadata"]
            add_object_to_db(estimator)
            for page_dict in estimator_dict["pages"]:
                page = PageModel(estimator_id=estimator.id,
                                 page_number=int(page_dict["metadata"]["page_number"]),
                                 page_dimension=int(page_dict['dimensions']),
                                 page_name=str(page_dict["metadata"]["name"]))
                page.data = page_dict
                add_object_to_db(page, False)
            make_commit_to_db()

        logging.debug("Returning results from Celery")
        return yaptide_response(message=f"Results for job: {job_id}, results from Celery", code=200, content=result)
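

# Illustrative sketch of the content returned by ResultsDirect.get(), matching the dictionaries
# assembled above. The estimator name and the page contents depend on the simulation output and
# are shown here only as hypothetical placeholders:
#
#   {
#       "estimators": [
#           {
#               "metadata": {...},        # estimator.data
#               "name": "fluence",        # hypothetical estimator name
#               "pages": [{...}, {...}]   # one dict per PageModel.data
#           }
#       ]
#   }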