22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
class JobsBatch(Resource):
    """Class responsible for jobs via direct slurm connection.

    ``post`` registers a new batch simulation and enqueues its submission,
    ``get`` reports the state of an existing job and ``delete`` cancels it.
    """

    @staticmethod
    @requires_auth()
    def post(user: KeycloakUserModel):
        """Register a new batch simulation and enqueue its submission.

        The JSON body must be an object containing ``sim_type``, ``ntasks``
        (a positive integer) and ``input_type``. ``title`` and
        ``batch_options.cluster_name`` are optional; when the requested cluster
        name is absent or matches no known cluster, the first available
        cluster is used.

        Returns a 202 response carrying the generated ``job_id`` on success,
        a 403 response when ``user`` is not a ``KeycloakUserModel``, a 400
        response for a missing/malformed body, and a validation error response
        when the input type cannot be determined or no cluster is available.
        """
        if not isinstance(user, KeycloakUserModel):
            return yaptide_response(message="User is not allowed to use this endpoint", code=403)

        payload_dict = request.get_json(force=True)
        if not payload_dict:
            return yaptide_response(message="No JSON in body", code=400)
        if not isinstance(payload_dict, dict):
            # A non-empty JSON array or scalar is truthy but has no keys to validate.
            return yaptide_response(message="JSON body must be an object", code=400)

        required_keys = {"sim_type", "ntasks", "input_type"}
        missing_keys = required_keys.difference(payload_dict)
        if missing_keys:
            return yaptide_response(message=f"Missing keys in JSON payload: {missing_keys}", code=400)

        input_type = determine_input_type(payload_dict)
        if input_type is None:
            return error_validation_response()

        # Validate before anything is persisted or enqueued: ntasks drives the
        # number of task rows created below. bool is an int subclass, so it is
        # rejected explicitly.
        ntasks = payload_dict["ntasks"]
        if isinstance(ntasks, bool) or not isinstance(ntasks, int) or ntasks < 1:
            return yaptide_response(message="ntasks must be a positive integer", code=400)

        clusters = fetch_all_clusters()
        if len(clusters) < 1:
            return error_validation_response({"message": "No clusters are available"})

        # Default to the first cluster; prefer the one named in batch_options if it exists.
        cluster = clusters[0]
        batch_options = payload_dict.get("batch_options")
        if isinstance(batch_options, dict) and "cluster_name" in batch_options:
            cluster_name = batch_options["cluster_name"]
            cluster = next((candidate for candidate in clusters if candidate.cluster_name == cluster_name), cluster)

        # create a new simulation in the database, not waiting for the job to finish
        job_id = datetime.now().strftime('%Y%m%d-%H%M%S-') + str(uuid.uuid4()) + PlatformType.BATCH.value
        # skipcq: PYL-E1123
        simulation = BatchSimulationModel(user_id=user.id,
                                          cluster_id=cluster.id,
                                          job_id=job_id,
                                          sim_type=payload_dict["sim_type"],
                                          input_type=input_type,
                                          title=payload_dict.get("title", ''))
        add_object_to_db(simulation)
        update_key = encode_simulation_auth_token(simulation.id)

        input_dict = make_input_dict(payload_dict=payload_dict, input_type=input_type)

        # NOTE(review): the job is enqueued before the task rows and the input
        # model are stored below - confirm the worker does not rely on them
        # being present when it starts.
        submit_job.delay(payload_dict=payload_dict,
                         files_dict=input_dict["input_files"],
                         userId=user.id,
                         clusterId=cluster.id,
                         sim_id=simulation.id,
                         update_key=update_key)

        # Task ids are 1-based strings.
        for task_number in range(1, ntasks + 1):
            task = BatchTaskModel(simulation_id=simulation.id, task_id=str(task_number))
            # presumably the second argument defers the commit - TODO confirm
            add_object_to_db(task, False)

        input_model = InputModel(simulation_id=simulation.id)
        input_model.data = input_dict
        add_object_to_db(input_model)

        if simulation.update_state({"job_state": EntityState.PENDING.value}):
            make_commit_to_db()

        return yaptide_response(message="Job waiting for submission", code=202, content={'job_id': simulation.job_id})

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        job_id = fields.String()

    @staticmethod
    def _load_job_id():
        """Validate the request's query arguments and extract ``job_id``.

        Returns a ``(job_id, error_response)`` pair in which exactly one
        element is ``None``.
        """
        schema = JobsBatch.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return None, error_validation_response(content=errors)

        params_dict: dict = schema.load(request.args)
        job_id = params_dict.get("job_id")
        if job_id is None:
            # job_id is not marked as required in the schema, so a request
            # without it passes validation and has to be rejected here.
            return None, error_validation_response(content={"job_id": ["Missing data for required field."]})
        return job_id, None

    @staticmethod
    @requires_auth()
    def get(user: KeycloakUserModel):
        """Return the state of the job identified by the ``job_id`` query argument.

        For a job already in the COMPLETED or FAILED state the stored state is
        returned as is. Otherwise the current status is obtained through
        ``get_job_status``, stored via ``update_simulation_state`` and returned
        (without ``end_time``). Both variants include ``job_tasks_status``.
        """
        if not isinstance(user, KeycloakUserModel):
            return yaptide_response(message="User is not allowed to use this endpoint", code=403)

        job_id, error_response = JobsBatch._load_job_id()
        if error_response is not None:
            return error_response

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_batch_simulation_by_job_id(job_id=job_id)
        tasks = fetch_batch_tasks_by_sim_id(sim_id=simulation.id)
        job_tasks_status = [task.get_status_dict() for task in tasks]

        # Finished jobs are answered from the database without querying the cluster.
        if simulation.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value):
            return yaptide_response(message=f"Job state: {simulation.job_state}",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                        "job_tasks_status": job_tasks_status,
                                    })

        cluster = fetch_cluster_by_id(cluster_id=simulation.cluster_id)
        job_info = get_job_status(simulation=simulation, user=user, cluster=cluster)
        update_simulation_state(simulation=simulation, update_dict=job_info)

        # end_time is passed to the state update above but left out of the response.
        job_info.pop("end_time", None)
        job_info["job_tasks_status"] = job_tasks_status
        return yaptide_response(message="", code=200, content=job_info)

    @staticmethod
    @requires_auth()
    def delete(user: KeycloakUserModel):
        """Cancel the job identified by the ``job_id`` query argument.

        A job in the COMPLETED, FAILED, CANCELED or UNKNOWN state is not
        cancelled; a 200 response explaining why is returned instead. Otherwise
        ``delete_job`` is called and, when it reports status 200, the
        simulation and all of its tasks are marked as CANCELED.
        """
        if not isinstance(user, KeycloakUserModel):
            return yaptide_response(message="User is not allowed to use this endpoint", code=403)

        job_id, error_response = JobsBatch._load_job_id()
        if error_response is not None:
            return error_response

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_batch_simulation_by_job_id(job_id=job_id)

        non_cancellable_states = (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value,
                                  EntityState.UNKNOWN.value)
        if simulation.job_state in non_cancellable_states:
            return yaptide_response(message=f"Cannot cancel job which is in {simulation.job_state} state",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                    })

        cluster = fetch_cluster_by_id(cluster_id=simulation.cluster_id)
        result, status_code = delete_job(simulation=simulation, user=user, cluster=cluster)
        if status_code != 200:
            return error_internal_response(content=result)

        update_simulation_state(simulation=simulation, update_dict={"job_state": EntityState.CANCELED.value})
        tasks = fetch_batch_tasks_by_sim_id(sim_id=simulation.id)
        for task in tasks:
            update_task_state(task=task, update_dict={"task_state": EntityState.CANCELED.value})

        return yaptide_response(message="", code=status_code, content=result)
|