Coverage for yaptide/routes/celery_routes.py: 42%
146 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-08-12 06:23 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-08-12 06:23 +0000
1import logging
2from collections import Counter
3from datetime import datetime
5from flask import request
6from flask_restful import Resource
7from marshmallow import Schema, fields
8from uuid import uuid4
10from yaptide.celery.simulation_worker import celery_app
11from yaptide.celery.utils.manage_tasks import (get_job_results, run_job)
12from yaptide.persistence.db_methods import (add_object_to_db, fetch_celery_simulation_by_job_id,
13 fetch_celery_tasks_by_sim_id, fetch_estimators_by_sim_id,
14 fetch_pages_by_estimator_id, make_commit_to_db, update_simulation_state,
15 update_task_state)
16from yaptide.persistence.models import (CelerySimulationModel, CeleryTaskModel, EstimatorModel, InputModel, PageModel,
17 UserModel)
18from yaptide.routes.utils.decorators import requires_auth
19from yaptide.routes.utils.response_templates import (error_validation_response, yaptide_response)
20from yaptide.routes.utils.utils import check_if_job_is_owned_and_exist, determine_input_type, make_input_dict
21from yaptide.routes.utils.tokens import encode_simulation_auth_token
22from yaptide.utils.enums import EntityState, PlatformType
23from yaptide.utils.helper_tasks import terminate_unfinished_tasks
class JobsDirect(Resource):
    """Class responsible for simulations run directly with celery"""

    @staticmethod
    @requires_auth()
    def post(user: UserModel):
        """Submit simulation job to celery

        The JSON body must contain `sim_type`, `ntasks` and `input_type`; `title` is optional.
        The simulation, one task row per requested task and the input are stored in the
        database, and the job is handed over to `run_job`.

        Responses:
            400 - body is empty, a required key is missing, or `ntasks` is not a positive integer
            `error_validation_response()` - `determine_input_type` returned None
            202 - job submitted; content carries the generated `job_id`

        NOTE(review): `user` is presumably injected by `requires_auth` - confirm against the decorator.
        """
        payload_dict: dict = request.get_json(force=True)
        if not payload_dict:
            return yaptide_response(message="No JSON in body", code=400)

        required_keys = {"sim_type", "ntasks", "input_type"}
        missing_keys = required_keys.difference(payload_dict.keys())
        if missing_keys:
            return yaptide_response(message=f"Missing keys in JSON payload: {missing_keys}", code=400)

        input_type = determine_input_type(payload_dict)
        if input_type is None:
            return error_validation_response()

        # ntasks is used below both as a range() bound and as a divisor, so reject anything that is
        # not a positive integer here, before any row is written to the database
        ntasks = payload_dict["ntasks"]
        if not isinstance(ntasks, int) or ntasks < 1:
            return yaptide_response(message="ntasks must be a positive integer", code=400)

        # create a new simulation in the database, not waiting for the job to finish
        job_id = datetime.now().strftime('%Y%m%d-%H%M%S-') + str(uuid4()) + PlatformType.DIRECT.value
        simulation = CelerySimulationModel(user_id=user.id,
                                           job_id=job_id,
                                           sim_type=payload_dict["sim_type"],
                                           input_type=input_type,
                                           title=payload_dict.get("title", ''))
        add_object_to_db(simulation)
        update_key = encode_simulation_auth_token(simulation.id)
        logging.info("Simulation %d created and inserted into DB", simulation.id)
        # NOTE(review): this writes the simulation auth token to the debug log - confirm that is acceptable
        logging.debug("Update key set to %s", update_key)

        input_dict = make_input_dict(payload_dict=payload_dict, input_type=input_type)
        # create tasks in the database in the default PENDING state
        celery_ids = [str(uuid4()) for _ in range(ntasks)]
        # integer division: every task gets the same share, any remainder is not assigned to a task
        requested_primaries = input_dict["number_of_all_primaries"] // ntasks
        for task_id, celery_id in enumerate(celery_ids):
            task = CeleryTaskModel(simulation_id=simulation.id,
                                   task_id=task_id,
                                   celery_id=celery_id,
                                   requested_primaries=requested_primaries)
            # commit once after the loop instead of once per task
            add_object_to_db(task, make_commit=False)
        make_commit_to_db()

        # submit the asynchronous job to celery
        simulation.merge_id = run_job(input_dict["input_files"], update_key, simulation.id, ntasks, celery_ids,
                                      payload_dict["sim_type"])

        input_model = InputModel(simulation_id=simulation.id)
        input_model.data = input_dict
        add_object_to_db(input_model)
        if simulation.update_state({"job_state": EntityState.PENDING.value}):
            make_commit_to_db()

        return yaptide_response(message="Task started", code=202, content={'job_id': simulation.job_id})

    class APIParametersSchema(Schema):
        """Class specifies API parameters for GET and DELETE request"""

        # required: both handlers index the loaded parameters with 'job_id', so a request without it
        # has to be rejected by schema validation instead of failing later with a KeyError
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning job state and the status of its tasks

        Expects a `job_id` query parameter.

        Responses:
            400 - query parameters did not pass schema validation
            code and message from `check_if_job_is_owned_and_exist` - when that check fails
            200 - content carries `job_state` and `job_tasks_status`
        """
        # validate request parameters and handle errors
        schema = JobsDirect.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        # get job_id from request parameters and check if user owns this job
        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        # find appropriate simulation in the database
        simulation = fetch_celery_simulation_by_job_id(job_id=job_id)

        tasks = fetch_celery_tasks_by_sim_id(sim_id=simulation.id)

        job_tasks_status = [task.get_status_dict() for task in tasks]

        # COMPLETED and FAILED are reported as stored, without re-deriving the state from the tasks
        if simulation.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value):
            return yaptide_response(message=f"Job state: {simulation.job_state}",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                        "job_tasks_status": job_tasks_status,
                                    })

        # derive the job state from the task states; when no rule below matches, the stored state is kept
        job_info = {"job_state": simulation.job_state}
        task_count = len(job_tasks_status)
        status_counter = Counter(task["task_state"] for task in job_tasks_status)
        if status_counter[EntityState.PENDING.value] == task_count:
            job_info["job_state"] = EntityState.PENDING.value
        elif status_counter[EntityState.FAILED.value] == task_count:
            job_info["job_state"] = EntityState.FAILED.value
        elif status_counter[EntityState.RUNNING.value] > 0:
            job_info["job_state"] = EntityState.RUNNING.value

        # hand the derived job state over to the simulation record
        update_simulation_state(simulation=simulation, update_dict=job_info)

        job_info["job_tasks_status"] = job_tasks_status

        return yaptide_response(message=f"Job state: {job_info['job_state']}", code=200, content=job_info)

    @staticmethod
    @requires_auth()
    def delete(user: UserModel):
        """Method canceling simulation and returning status of this action

        Expects a `job_id` query parameter.

        Responses:
            `error_validation_response(content=errors)` - query parameters did not pass schema validation
            code and message from `check_if_job_is_owned_and_exist` - when that check fails
            200 - either the job was canceled, or it was already in a COMPLETED, FAILED, CANCELED
                  or UNKNOWN state and nothing was done (content then carries the `job_state`)
        """
        schema = JobsDirect.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return error_validation_response(content=errors)
        params_dict: dict = schema.load(request.args)

        job_id = params_dict['job_id']

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_celery_simulation_by_job_id(job_id=job_id)

        if simulation.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value,
                                    EntityState.UNKNOWN.value):
            return yaptide_response(message=f"Cannot cancel job which is in {simulation.job_state} state",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                    })

        tasks = fetch_celery_tasks_by_sim_id(sim_id=simulation.id)
        # tasks in an UNKNOWN state are revoked as well, although only PENDING and RUNNING ones
        # are marked as CANCELED further down
        celery_ids = [
            task.celery_id for task in tasks
            if task.task_state in [EntityState.PENDING.value, EntityState.RUNNING.value, EntityState.UNKNOWN.value]
        ]

        # The merge_id is canceled first because merge task starts after run simulation tasks are finished/canceled.
        # We don't want it to run accidentally.
        celery_app.control.revoke(simulation.merge_id, terminate=True, signal="SIGINT")
        celery_app.control.revoke(celery_ids, terminate=True, signal="SIGINT")
        update_simulation_state(simulation=simulation, update_dict={"job_state": EntityState.CANCELED.value})
        for task in tasks:
            if task.task_state in [EntityState.PENDING.value, EntityState.RUNNING.value]:
                update_task_state(task=task, update_dict={"task_state": EntityState.CANCELED.value})

        terminate_unfinished_tasks.delay(simulation_id=simulation.id)
        return yaptide_response(message="Cancelled sucessfully", code=200)
class ResultsDirect(Resource):
    """Class responsible for returning simulation results"""

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # required: the handler indexes the loaded parameters with 'job_id', so a request without it
        # has to be rejected by schema validation instead of failing later with a KeyError
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning simulation results

        Expects a `job_id` query parameter. Estimators already stored in the database for the
        simulation are returned directly. Otherwise the results are requested with
        `get_job_results`, stored in the database as estimators and pages, and then returned.

        Responses:
            400 - query parameters did not pass schema validation
            code and message from `check_if_job_is_owned_and_exist` - when that check fails
            404 - `get_job_results` returned a dict without an "estimators" key
            200 - content is `{"estimators": [...]}` when read from the database, or the whole
                  dict returned by `get_job_results` otherwise
        """
        schema = ResultsDirect.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_celery_simulation_by_job_id(job_id=job_id)

        estimators: list[EstimatorModel] = fetch_estimators_by_sim_id(sim_id=simulation.id)
        if estimators:
            logging.debug("Returning results from database")
            result_estimators = []
            for estimator in estimators:
                pages: list[PageModel] = fetch_pages_by_estimator_id(est_id=estimator.id)
                result_estimators.append({
                    "metadata": estimator.data,
                    "name": estimator.name,
                    "pages": [page.data for page in pages],
                })
            return yaptide_response(message=f"Results for job: {job_id}",
                                    code=200,
                                    content={"estimators": result_estimators})

        # nothing stored for this simulation yet - ask for the results of the job
        result: dict = get_job_results(job_id=job_id)
        if "estimators" not in result:
            logging.debug("Results for job %s are unavailable", job_id)
            return yaptide_response(message="Results are unavailable", code=404, content=result)

        # store the fetched results so that later requests are answered from the database
        for estimator_dict in result["estimators"]:
            estimator = EstimatorModel(name=estimator_dict["name"], simulation_id=simulation.id)
            estimator.data = estimator_dict["metadata"]
            # added before the pages are built, because they reference estimator.id
            add_object_to_db(estimator)
            for page_dict in estimator_dict["pages"]:
                page = PageModel(estimator_id=estimator.id,
                                 page_number=int(page_dict["metadata"]["page_number"]),
                                 page_dimension=int(page_dict['dimensions']),
                                 page_name=str(page_dict["metadata"]["name"]))
                page.data = page_dict
                # commit once per estimator instead of once per page
                add_object_to_db(page, make_commit=False)
            make_commit_to_db()

        logging.debug("Returning results from Celery")
        return yaptide_response(message=f"Results for job: {job_id}, results from Celery", code=200, content=result)