Coverage for yaptide/routes/batch_routes.py: 25%
115 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-08-12 06:23 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-08-12 06:23 +0000
1import uuid
2from datetime import datetime
4from flask import request
5from flask_restful import Resource
6from marshmallow import Schema, fields
8from yaptide.batch.batch_methods import delete_job, get_job_status, submit_job
9from yaptide.persistence.db_methods import (add_object_to_db, fetch_all_clusters, fetch_batch_simulation_by_job_id,
10 fetch_batch_tasks_by_sim_id, fetch_cluster_by_id, make_commit_to_db,
11 update_simulation_state, update_task_state)
12from yaptide.persistence.models import ( # skipcq: FLK-E101
13 BatchSimulationModel, BatchTaskModel, ClusterModel, InputModel, KeycloakUserModel)
14from yaptide.routes.utils.tokens import encode_simulation_auth_token
15from yaptide.routes.utils.decorators import requires_auth
16from yaptide.routes.utils.response_templates import (error_validation_response, error_internal_response,
17 yaptide_response)
18from yaptide.routes.utils.utils import check_if_job_is_owned_and_exist, determine_input_type, make_input_dict
19from yaptide.utils.enums import EntityState, PlatformType
class JobsBatch(Resource):
    """Class responsible for jobs via direct slurm connection"""

    @staticmethod
    def _select_cluster(clusters: list[ClusterModel], batch_options) -> ClusterModel:
        """Return the cluster named in ``batch_options["cluster_name"]``, else the first cluster.

        ``clusters`` must be non-empty. When no name is given, or the name matches no
        cluster, ``clusters[0]`` is returned (same fallback as before).
        """
        # batch_options comes from the request body; guard against non-dict values.
        if isinstance(batch_options, dict) and "cluster_name" in batch_options:
            cluster_name = batch_options["cluster_name"]
            for candidate in clusters:
                if candidate.cluster_name == cluster_name:
                    return candidate
        return clusters[0]

    @staticmethod
    @requires_auth()
    def post(user: KeycloakUserModel):
        """Create a batch simulation and queue it for submission to a cluster.

        Expected JSON payload keys: ``sim_type``, ``ntasks`` (positive integer) and
        ``input_type``; optional ``title`` and ``batch_options.cluster_name``.

        The simulation, one task row per ``ntasks`` and the input are stored in the
        database first. Only then is ``submit_job`` queued asynchronously.

        Returns:
            202 with ``{"job_id": ...}`` on success; 400/validation responses for a bad
            payload or when no cluster is available; 403 for non-Keycloak users.
        """
        if not isinstance(user, KeycloakUserModel):
            return yaptide_response(message="User is not allowed to use this endpoint", code=403)

        payload_dict: dict = request.get_json(force=True)
        if not payload_dict:
            return yaptide_response(message="No JSON in body", code=400)
        # A JSON array/scalar body would otherwise crash on .keys() below.
        if not isinstance(payload_dict, dict):
            return yaptide_response(message="JSON payload must be an object", code=400)

        required_keys = {"sim_type", "ntasks", "input_type"}
        missing_keys = required_keys - payload_dict.keys()
        if missing_keys:
            return yaptide_response(message=f"Missing keys in JSON payload: {missing_keys}", code=400)

        ntasks = payload_dict["ntasks"]
        # ntasks divides the primaries and drives range(); 0 would raise ZeroDivisionError
        # and non-integers would raise TypeError. bool is an int subclass, so reject it too.
        if isinstance(ntasks, bool) or not isinstance(ntasks, int) or ntasks < 1:
            return yaptide_response(message="'ntasks' must be a positive integer", code=400)

        input_type = determine_input_type(payload_dict)
        if input_type is None:
            return error_validation_response()

        clusters = fetch_all_clusters()
        if not clusters:
            return error_validation_response({"message": "No clusters are available"})
        cluster = JobsBatch._select_cluster(clusters, payload_dict.get("batch_options"))

        # create a new simulation in the database, not waiting for the job to finish
        job_id = datetime.now().strftime('%Y%m%d-%H%M%S-') + str(uuid.uuid4()) + PlatformType.BATCH.value
        # skipcq: PYL-E1123
        simulation = BatchSimulationModel(user_id=user.id,
                                          cluster_id=cluster.id,
                                          job_id=job_id,
                                          sim_type=payload_dict["sim_type"],
                                          input_type=input_type,
                                          title=payload_dict.get("title", ''))
        add_object_to_db(simulation)
        update_key = encode_simulation_auth_token(simulation.id)

        input_dict = make_input_dict(payload_dict=payload_dict, input_type=input_type)

        # Primaries are split evenly with integer division; any remainder is not assigned.
        requested_primaries = input_dict["number_of_all_primaries"] // ntasks
        for task_number in range(1, ntasks + 1):
            task = BatchTaskModel(simulation_id=simulation.id,
                                  task_id=str(task_number),
                                  requested_primaries=requested_primaries)
            # NOTE(review): the False flag presumably defers the commit to the next add below.
            add_object_to_db(task, False)

        input_model = InputModel(simulation_id=simulation.id)
        input_model.data = input_dict
        add_object_to_db(input_model)
        if simulation.update_state({"job_state": EntityState.PENDING.value}):
            make_commit_to_db()

        # Queue the asynchronous submission only after the tasks, the input and the PENDING
        # state are stored. Otherwise the worker could run against a half-initialised
        # simulation, or its state updates could be overwritten by the PENDING write above.
        submit_job.delay(payload_dict=payload_dict,
                         files_dict=input_dict["input_files"],
                         userId=user.id,
                         clusterId=cluster.id,
                         sim_id=simulation.id,
                         update_key=update_key)

        return yaptide_response(message="Job waiting for submission", code=202, content={'job_id': simulation.job_id})

    class APIParametersSchema(Schema):
        """Query parameters accepted by the GET and DELETE endpoints."""

        # required=True so a missing job_id is reported by schema.validate() as a
        # validation error instead of raising KeyError after schema.load().
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: KeycloakUserModel):
        """Return the state of a batch job and the status of each of its tasks.

        Query parameter: ``job_id``. A job already in COMPLETED or FAILED state is
        answered from the database. Otherwise the job's status is fetched from its
        cluster via ``get_job_status`` and written back to the simulation record.
        """
        if not isinstance(user, KeycloakUserModel):
            return yaptide_response(message="User is not allowed to use this endpoint", code=403)

        schema = JobsBatch.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return error_validation_response(content=errors)
        params_dict: dict = schema.load(request.args)

        job_id: str = params_dict["job_id"]

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)
        simulation = fetch_batch_simulation_by_job_id(job_id=job_id)

        tasks = fetch_batch_tasks_by_sim_id(sim_id=simulation.id)

        job_tasks_status = [task.get_status_dict() for task in tasks]

        # Terminal states: nothing left to query on the cluster.
        if simulation.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value):
            return yaptide_response(message=f"Job state: {simulation.job_state}",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                        "job_tasks_status": job_tasks_status,
                                    })

        cluster = fetch_cluster_by_id(cluster_id=simulation.cluster_id)

        job_info = get_job_status(simulation=simulation, user=user, cluster=cluster)
        update_simulation_state(simulation=simulation, update_dict=job_info)

        # end_time is persisted above but not returned to the client.
        job_info.pop("end_time", None)
        job_info["job_tasks_status"] = job_tasks_status

        return yaptide_response(message="", code=200, content=job_info)

    @staticmethod
    @requires_auth()
    def delete(user: KeycloakUserModel):
        """Cancel a batch job and mark it and all of its tasks as CANCELED.

        Query parameter: ``job_id``. Jobs in COMPLETED, FAILED, CANCELED or UNKNOWN state
        are not cancelled; a 200 response with the current state is returned instead.
        If ``delete_job`` reports a non-200 status, an internal-error response is returned
        and the stored states are left unchanged.
        """
        if not isinstance(user, KeycloakUserModel):
            return yaptide_response(message="User is not allowed to use this endpoint", code=403)

        schema = JobsBatch.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return error_validation_response(content=errors)
        params_dict: dict = schema.load(request.args)

        job_id: str = params_dict["job_id"]

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_batch_simulation_by_job_id(job_id=job_id)

        if simulation.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value,
                                    EntityState.UNKNOWN.value):
            return yaptide_response(message=f"Cannot cancel job which is in {simulation.job_state} state",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                    })

        cluster = fetch_cluster_by_id(cluster_id=simulation.cluster_id)

        result, status_code = delete_job(simulation=simulation, user=user, cluster=cluster)
        if status_code != 200:
            return error_internal_response(content=result)

        update_simulation_state(simulation=simulation, update_dict={"job_state": EntityState.CANCELED.value})

        tasks = fetch_batch_tasks_by_sim_id(sim_id=simulation.id)

        for task in tasks:
            update_task_state(task=task, update_dict={"task_state": EntityState.CANCELED.value})

        return yaptide_response(message="", code=status_code, content=result)
class Clusters(Resource):
    """Endpoint listing the clusters that batch jobs can be submitted to."""

    @staticmethod
    @requires_auth()
    def get(user: KeycloakUserModel):
        """Return the name of every cluster reported by ``fetch_all_clusters``.

        Callers that are not ``KeycloakUserModel`` instances receive a 403 response.
        """
        if isinstance(user, KeycloakUserModel):
            names = [{'cluster_name': c.cluster_name} for c in fetch_all_clusters()]
            return yaptide_response(message='Available clusters', code=200, content={'clusters': names})
        return yaptide_response(message="User is not allowed to use this endpoint", code=403)