Coverage for yaptide/routes/batch_routes.py: 23%
122 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-07-01 12:55 +0000
1import uuid
2from datetime import datetime
4from flask import request
5from flask_restful import Resource
6from marshmallow import Schema, fields
8from yaptide.batch.batch_methods import delete_job, get_job_status, submit_job
9from yaptide.persistence.db_methods import (add_object_to_db,
10 delete_object_from_db,
11 fetch_all_clusters,
12 fetch_batch_simulation_by_job_id,
13 fetch_batch_tasks_by_sim_id,
14 fetch_cluster_by_id,
15 make_commit_to_db,
16 update_simulation_state,
17 update_task_state)
18from yaptide.persistence.models import ( # skipcq: FLK-E101
19 BatchSimulationModel, BatchTaskModel, ClusterModel, InputModel,
20 KeycloakUserModel)
21from yaptide.routes.utils.decorators import requires_auth
22from yaptide.routes.utils.response_templates import (error_validation_response,
23 error_internal_response,
24 yaptide_response)
25from yaptide.routes.utils.utils import check_if_job_is_owned_and_exist, determine_input_type, make_input_dict
26from yaptide.utils.enums import EntityState, PlatformType
class JobsBatch(Resource):
    """Class responsible for jobs via direct slurm connection"""

    @staticmethod
    @requires_auth()
    def post(user: KeycloakUserModel):
        """Method handling running shieldhit with batch

        Validates the JSON payload, picks a cluster, records a ``BatchSimulationModel``,
        submits the job via ``submit_job`` and stores one ``BatchTaskModel`` per task plus
        the converted input.

        Payload keys: ``sim_type``, ``ntasks`` (positive integer) and ``input_type`` are
        required; ``title`` and ``batch_options.cluster_name`` are optional.

        Returns:
            403 for non-Keycloak users; 400 (or a validation error response) for a malformed
            payload or when no cluster exists; 500 when the submission result lacks the job
            metadata; otherwise 202 with the submission result extended by ``job_id``.
        """
        if not isinstance(user, KeycloakUserModel):
            return yaptide_response(message="User is not allowed to use this endpoint", code=403)

        payload_dict: dict = request.get_json(force=True)
        if not payload_dict:
            return yaptide_response(message="No JSON in body", code=400)
        # A JSON array/scalar body would otherwise crash on ``.keys()`` below (HTTP 500).
        if not isinstance(payload_dict, dict):
            return yaptide_response(message="JSON body must be an object", code=400)

        required_keys = {"sim_type", "ntasks", "input_type"}
        missing_keys = required_keys.difference(payload_dict.keys())
        if missing_keys:
            return yaptide_response(message=f"Missing keys in JSON payload: {missing_keys}", code=400)

        # ``ntasks`` feeds range() further down, which only runs after the job has been
        # submitted to the cluster; reject bad values before anything is submitted or stored.
        # bool is excluded explicitly because it is a subclass of int.
        ntasks = payload_dict["ntasks"]
        if isinstance(ntasks, bool) or not isinstance(ntasks, int) or ntasks < 1:
            return yaptide_response(message="ntasks must be a positive integer", code=400)

        input_type = determine_input_type(payload_dict)
        if input_type is None:
            return error_validation_response()

        clusters = fetch_all_clusters()
        if not clusters:
            return error_validation_response({"message": "No clusters are available"})

        # Use the requested cluster when a cluster with that name exists, otherwise the first one.
        # NOTE(review): an unknown cluster_name silently falls back to clusters[0] - confirm intended.
        cluster = clusters[0]
        batch_options = payload_dict.get("batch_options")
        if isinstance(batch_options, dict) and "cluster_name" in batch_options:
            cluster_name = batch_options["cluster_name"]
            cluster = next((candidate for candidate in clusters if candidate.cluster_name == cluster_name),
                           clusters[0])

        # Built before the simulation row is created, so a conversion failure leaves no DB trace.
        input_dict = make_input_dict(payload_dict=payload_dict, input_type=input_type)

        # create a new simulation in the database, not waiting for the job to finish
        job_id = datetime.now().strftime('%Y%m%d-%H%M%S-') + str(uuid.uuid4()) + PlatformType.BATCH.value
        simulation = BatchSimulationModel(user_id=user.id,
                                          cluster_id=cluster.id,
                                          job_id=job_id,
                                          sim_type=payload_dict["sim_type"],
                                          input_type=input_type,
                                          title=payload_dict.get("title", ''))
        update_key = str(uuid.uuid4())
        simulation.set_update_key(update_key)
        add_object_to_db(simulation)

        try:
            result = submit_job(payload_dict=payload_dict, files_dict=input_dict["input_files"], user=user,
                                cluster=cluster, sim_id=simulation.id, update_key=update_key)
        except Exception:
            # Mirror the cleanup done for an incomplete result below: do not leave an orphaned
            # simulation row behind. The original exception propagates unchanged.
            delete_object_from_db(simulation)
            raise

        expected_result_keys = {"job_dir", "array_id", "collect_id"}
        if not expected_result_keys.issubset(result.keys()):
            delete_object_from_db(simulation)
            return yaptide_response(
                message="Job submission failed",
                code=500,
                content=result
            )
        # Move the cluster-side identifiers onto the simulation; the remaining result is returned.
        simulation.job_dir = result.pop("job_dir", None)
        simulation.array_id = result.pop("array_id", None)
        simulation.collect_id = result.pop("collect_id", None)
        result["job_id"] = simulation.job_id

        # Task ids are 1-based strings; tasks are added without an immediate commit.
        for task_number in range(1, ntasks + 1):
            task = BatchTaskModel(simulation_id=simulation.id, task_id=str(task_number))
            add_object_to_db(task, False)

        input_model = InputModel(simulation_id=simulation.id)
        input_model.data = input_dict
        add_object_to_db(input_model)
        if simulation.update_state({"job_state": EntityState.PENDING.value}):
            make_commit_to_db()

        return yaptide_response(
            message="Job submitted",
            code=202,
            content=result
        )

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # required=True: without it a request lacking ``job_id`` passes validation and the
        # handlers then fail with KeyError (HTTP 500) instead of returning a validation error.
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: KeycloakUserModel):
        """Method getting job's status

        Requires the ``job_id`` query parameter and ownership of the job. When the simulation is
        already COMPLETED or FAILED the stored state is returned without contacting the cluster;
        otherwise the state is queried via ``get_job_status``, persisted, and returned (without
        ``end_time``) together with the per-task statuses.
        """
        if not isinstance(user, KeycloakUserModel):
            return yaptide_response(message="User is not allowed to use this endpoint", code=403)

        schema = JobsBatch.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return error_validation_response(content=errors)
        params_dict: dict = schema.load(request.args)

        job_id: str = params_dict["job_id"]

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)
        simulation = fetch_batch_simulation_by_job_id(job_id=job_id)

        tasks = fetch_batch_tasks_by_sim_id(sim_id=simulation.id)

        job_tasks_status = [task.get_status_dict() for task in tasks]

        # Terminal states are answered from the database alone.
        if simulation.job_state in (EntityState.COMPLETED.value,
                                    EntityState.FAILED.value):
            return yaptide_response(message=f"Job state: {simulation.job_state}",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                        "job_tasks_status": job_tasks_status,
                                    })

        cluster = fetch_cluster_by_id(cluster_id=simulation.cluster_id)

        job_info = get_job_status(simulation=simulation, user=user, cluster=cluster)
        update_simulation_state(simulation=simulation, update_dict=job_info)

        # end_time is persisted above but not part of the response.
        job_info.pop("end_time", None)
        job_info["job_tasks_status"] = job_tasks_status

        return yaptide_response(
            message="",
            code=200,
            content=job_info
        )

    @staticmethod
    @requires_auth()
    def delete(user: KeycloakUserModel):
        """Method canceling job

        Requires the ``job_id`` query parameter and ownership of the job. Jobs in COMPLETED,
        FAILED, CANCELED or UNKNOWN state are left alone (200 with the current state).
        Otherwise ``delete_job`` is called: a non-200 status yields an internal error response,
        success marks the simulation and all of its tasks as CANCELED.
        """
        if not isinstance(user, KeycloakUserModel):
            return yaptide_response(message="User is not allowed to use this endpoint", code=403)

        schema = JobsBatch.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return error_validation_response(content=errors)
        params_dict: dict = schema.load(request.args)

        job_id: str = params_dict["job_id"]

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_batch_simulation_by_job_id(job_id=job_id)

        if simulation.job_state in (EntityState.COMPLETED.value,
                                    EntityState.FAILED.value,
                                    EntityState.CANCELED.value,
                                    EntityState.UNKNOWN.value):
            return yaptide_response(message=f"Cannot cancel job which is in {simulation.job_state} state",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                    })

        cluster = fetch_cluster_by_id(cluster_id=simulation.cluster_id)

        result, status_code = delete_job(simulation=simulation, user=user, cluster=cluster)
        if status_code != 200:
            return error_internal_response(content=result)

        update_simulation_state(simulation=simulation, update_dict={"job_state": EntityState.CANCELED.value})

        tasks = fetch_batch_tasks_by_sim_id(sim_id=simulation.id)

        for task in tasks:
            update_task_state(task=task, update_dict={"task_state": EntityState.CANCELED.value})

        return yaptide_response(
            message="",
            code=status_code,
            content=result
        )
class Clusters(Resource):
    """Resource listing the clusters that batch jobs can be submitted to."""

    @staticmethod
    @requires_auth()
    def get(user: KeycloakUserModel):
        """Return the name of every cluster known to the database.

        Only Keycloak-authenticated users may call this endpoint; anyone else gets a 403.
        The response content has the form ``{'clusters': [{'cluster_name': ...}, ...]}``,
        with one entry per cluster returned by ``fetch_all_clusters``.
        """
        if isinstance(user, KeycloakUserModel):
            cluster_entries = [{'cluster_name': each.cluster_name} for each in fetch_all_clusters()]
            return yaptide_response(message='Available clusters', code=200, content={'clusters': cluster_entries})
        return yaptide_response(message="User is not allowed to use this endpoint", code=403)