Coverage for yaptide/routes/batch_routes.py: 25%
114 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-04 00:31 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-04 00:31 +0000
1import uuid
2from datetime import datetime
4from flask import request
5from flask_restful import Resource
6from marshmallow import Schema, fields
8from yaptide.batch.batch_methods import delete_job, get_job_status, submit_job
9from yaptide.persistence.db_methods import (add_object_to_db, fetch_all_clusters, fetch_batch_simulation_by_job_id,
10 fetch_batch_tasks_by_sim_id, fetch_cluster_by_id, make_commit_to_db,
11 update_simulation_state, update_task_state)
12from yaptide.persistence.models import ( # skipcq: FLK-E101
13 BatchSimulationModel, BatchTaskModel, ClusterModel, InputModel, KeycloakUserModel)
14from yaptide.routes.utils.tokens import encode_simulation_auth_token
15from yaptide.routes.utils.decorators import requires_auth
16from yaptide.routes.utils.response_templates import (error_validation_response, error_internal_response,
17 yaptide_response)
18from yaptide.routes.utils.utils import check_if_job_is_owned_and_exist, determine_input_type, make_input_dict
19from yaptide.utils.enums import EntityState, PlatformType
class JobsBatch(Resource):
    """Endpoints for simulations run as batch jobs via a direct slurm connection.

    ``post`` creates and enqueues a job, ``get`` reports its state and ``delete``
    cancels it. Every endpoint is restricted to ``KeycloakUserModel`` users.
    """

    @staticmethod
    @requires_auth()
    def post(user: KeycloakUserModel):
        """Create a batch simulation and enqueue its submission to a cluster.

        The JSON body must be an object containing ``sim_type``, ``ntasks`` and
        ``input_type``. Optional keys: ``title`` and ``batch_options.cluster_name``
        (an unknown cluster name falls back to the first cluster returned by
        ``fetch_all_clusters``).

        Returns:
            202 with ``job_id`` once the simulation, its tasks and its input are stored
            and submission is queued; 403 for non-Keycloak users; 400 or a validation
            error response for a malformed payload or when no cluster exists.
        """
        if not isinstance(user, KeycloakUserModel):
            return yaptide_response(message="User is not allowed to use this endpoint", code=403)

        payload_dict: dict = request.get_json(force=True)
        if not payload_dict:
            return yaptide_response(message="No JSON in body", code=400)
        # A JSON array or scalar would otherwise crash on the key lookups below.
        if not isinstance(payload_dict, dict):
            return yaptide_response(message="JSON payload must be an object", code=400)

        missing_keys = {"sim_type", "ntasks", "input_type"}.difference(payload_dict)
        if missing_keys:
            return yaptide_response(message=f"Missing keys in JSON payload: {missing_keys}", code=400)

        ntasks = payload_dict["ntasks"]
        # Validate before anything is persisted or queued: a non-integer used to fail in
        # range() only after the simulation had already been stored and submitted.
        # bool is a subclass of int, so it is rejected explicitly.
        if isinstance(ntasks, bool) or not isinstance(ntasks, int) or ntasks < 1:
            return yaptide_response(message="ntasks must be a positive integer", code=400)

        input_type = determine_input_type(payload_dict)
        if input_type is None:
            return error_validation_response()

        clusters = fetch_all_clusters()
        if len(clusters) < 1:
            return error_validation_response({"message": "No clusters are available"})

        cluster = clusters[0]
        batch_options = payload_dict.get("batch_options")
        if isinstance(batch_options, dict) and "cluster_name" in batch_options:
            requested_name = batch_options["cluster_name"]
            # An unknown cluster name silently falls back to the first cluster.
            cluster = next((candidate for candidate in clusters if candidate.cluster_name == requested_name),
                           clusters[0])

        # Build the input before creating DB rows so a failure here leaves no orphan simulation.
        input_dict = make_input_dict(payload_dict=payload_dict, input_type=input_type)

        # create a new simulation in the database, not waiting for the job to finish
        job_id = datetime.now().strftime('%Y%m%d-%H%M%S-') + str(uuid.uuid4()) + PlatformType.BATCH.value
        # skipcq: PYL-E1123
        simulation = BatchSimulationModel(user_id=user.id,
                                          cluster_id=cluster.id,
                                          job_id=job_id,
                                          sim_type=payload_dict["sim_type"],
                                          input_type=input_type,
                                          title=payload_dict.get("title", ''))
        add_object_to_db(simulation)

        # Task ids are 1-based strings: "1" .. str(ntasks).
        for task_number in range(1, ntasks + 1):
            task = BatchTaskModel(simulation_id=simulation.id, task_id=str(task_number))
            # False presumably defers the commit until the input model is added — confirm in db_methods.
            add_object_to_db(task, False)

        input_model = InputModel(simulation_id=simulation.id)
        input_model.data = input_dict
        add_object_to_db(input_model)
        if simulation.update_state({"job_state": EntityState.PENDING.value}):
            make_commit_to_db()

        # Enqueue only after tasks, input and the PENDING state are saved. The worker runs
        # concurrently; queuing earlier could let it start before the task rows existed and
        # have its state changes overwritten by the PENDING update above.
        update_key = encode_simulation_auth_token(simulation.id)
        submit_job.delay(payload_dict=payload_dict,
                         files_dict=input_dict["input_files"],
                         userId=user.id,
                         clusterId=cluster.id,
                         sim_id=simulation.id,
                         update_key=update_key)

        return yaptide_response(message="Job waiting for submission", code=202, content={'job_id': simulation.job_id})

    class APIParametersSchema(Schema):
        """Query-string parameters accepted by ``get`` and ``delete``."""

        # required=True: without it a request lacking ``job_id`` passed validation and
        # then crashed with KeyError when the value was read.
        job_id = fields.String(required=True)

    @staticmethod
    def _fetch_owned_batch_simulation(user: KeycloakUserModel) -> tuple:
        """Parse ``job_id`` from the query string and load the caller's batch simulation.

        Returns:
            ``(simulation, None)`` on success, or ``(None, response)`` where ``response``
            is the error reply the calling endpoint should return unchanged.
        """
        schema = JobsBatch.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return None, error_validation_response(content=errors)
        job_id: str = schema.load(request.args)["job_id"]

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return None, yaptide_response(message=error_message, code=res_code)

        simulation = fetch_batch_simulation_by_job_id(job_id=job_id)
        # The lookup may return None (e.g. the job_id belongs to a non-batch simulation);
        # answer with an error instead of crashing on attribute access.
        if simulation is None:
            return None, yaptide_response(message="Job with provided ID is not a batch job", code=404)
        return simulation, None

    @staticmethod
    @requires_auth()
    def get(user: KeycloakUserModel):
        """Return the state of a batch job together with the status of each of its tasks.

        For COMPLETED or FAILED jobs the stored state is returned without contacting the
        cluster. Otherwise the cluster is queried via ``get_job_status``, the result is
        persisted, and it is returned without its ``end_time`` entry.
        """
        if not isinstance(user, KeycloakUserModel):
            return yaptide_response(message="User is not allowed to use this endpoint", code=403)

        simulation, error_response = JobsBatch._fetch_owned_batch_simulation(user)
        if error_response is not None:
            return error_response

        tasks = fetch_batch_tasks_by_sim_id(sim_id=simulation.id)
        job_tasks_status = [task.get_status_dict() for task in tasks]

        if simulation.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value):
            return yaptide_response(message=f"Job state: {simulation.job_state}",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                        "job_tasks_status": job_tasks_status,
                                    })

        cluster = fetch_cluster_by_id(cluster_id=simulation.cluster_id)
        job_info = get_job_status(simulation=simulation, user=user, cluster=cluster)
        update_simulation_state(simulation=simulation, update_dict=job_info)

        # end_time is persisted above but not exposed in the response.
        job_info.pop("end_time", None)
        job_info["job_tasks_status"] = job_tasks_status

        return yaptide_response(message="", code=200, content=job_info)

    @staticmethod
    @requires_auth()
    def delete(user: KeycloakUserModel):
        """Cancel a batch job and mark the job and all of its tasks as CANCELED.

        Jobs already in COMPLETED, FAILED, CANCELED or UNKNOWN state are left untouched
        (HTTP 200 with the current state). A non-200 result from ``delete_job`` is
        reported as an internal error.
        """
        if not isinstance(user, KeycloakUserModel):
            return yaptide_response(message="User is not allowed to use this endpoint", code=403)

        simulation, error_response = JobsBatch._fetch_owned_batch_simulation(user)
        if error_response is not None:
            return error_response

        if simulation.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value,
                                    EntityState.UNKNOWN.value):
            return yaptide_response(message=f"Cannot cancel job which is in {simulation.job_state} state",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                    })

        cluster = fetch_cluster_by_id(cluster_id=simulation.cluster_id)

        result, status_code = delete_job(simulation=simulation, user=user, cluster=cluster)
        if status_code != 200:
            return error_internal_response(content=result)

        update_simulation_state(simulation=simulation, update_dict={"job_state": EntityState.CANCELED.value})

        tasks = fetch_batch_tasks_by_sim_id(sim_id=simulation.id)
        for task in tasks:
            update_task_state(task=task, update_dict={"task_state": EntityState.CANCELED.value})

        return yaptide_response(message="", code=status_code, content=result)
class Clusters(Resource):
    """Expose the names of the clusters that batch jobs can be submitted to."""

    @staticmethod
    @requires_auth()
    def get(user: KeycloakUserModel):
        """List every cluster returned by ``fetch_all_clusters``.

        Returns:
            200 with ``{'clusters': [{'cluster_name': ...}, ...]}`` for Keycloak users,
            403 for any other kind of authenticated user.
        """
        if isinstance(user, KeycloakUserModel):
            cluster_entries = []
            for known_cluster in fetch_all_clusters():
                cluster_entries.append({'cluster_name': known_cluster.cluster_name})
            payload = {'clusters': cluster_entries}
            return yaptide_response(message='Available clusters', code=200, content=payload)
        return yaptide_response(message="User is not allowed to use this endpoint", code=403)