Coverage for yaptide/routes/batch_routes.py: 23%

122 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-07-01 12:55 +0000

1import uuid 

2from datetime import datetime 

3 

4from flask import request 

5from flask_restful import Resource 

6from marshmallow import Schema, fields 

7 

8from yaptide.batch.batch_methods import delete_job, get_job_status, submit_job 

9from yaptide.persistence.db_methods import (add_object_to_db, 

10 delete_object_from_db, 

11 fetch_all_clusters, 

12 fetch_batch_simulation_by_job_id, 

13 fetch_batch_tasks_by_sim_id, 

14 fetch_cluster_by_id, 

15 make_commit_to_db, 

16 update_simulation_state, 

17 update_task_state) 

18from yaptide.persistence.models import ( # skipcq: FLK-E101 

19 BatchSimulationModel, BatchTaskModel, ClusterModel, InputModel, 

20 KeycloakUserModel) 

21from yaptide.routes.utils.decorators import requires_auth 

22from yaptide.routes.utils.response_templates import (error_validation_response, 

23 error_internal_response, 

24 yaptide_response) 

25from yaptide.routes.utils.utils import check_if_job_is_owned_and_exist, determine_input_type, make_input_dict 

26from yaptide.utils.enums import EntityState, PlatformType 

27 

28 

class JobsBatch(Resource):
    """Manage simulation jobs submitted to a cluster via a direct Slurm connection."""

    @staticmethod
    @requires_auth()
    def post(user: KeycloakUserModel):
        """Submit a simulation as a batch job on a cluster.

        Expects a JSON object body containing at least ``sim_type``, ``ntasks`` (positive integer) and
        ``input_type``. Optional keys: ``batch_options.cluster_name`` selects the target cluster and
        ``title`` names the simulation.

        Returns 202 with the submission result (plus ``job_id``) on success, 400 or a validation error
        response for malformed input, 403 for non-Keycloak users and 500 when job submission fails.
        """
        if not isinstance(user, KeycloakUserModel):
            return yaptide_response(message="User is not allowed to use this endpoint", code=403)

        payload_dict: dict = request.get_json(force=True)
        if not payload_dict:
            return yaptide_response(message="No JSON in body", code=400)
        # A JSON array/scalar would otherwise crash on the key lookups below.
        if not isinstance(payload_dict, dict):
            return yaptide_response(message="JSON body must be an object", code=400)

        required_keys = {"sim_type", "ntasks", "input_type"}
        missing_keys = required_keys.difference(payload_dict)
        if missing_keys:
            return yaptide_response(message=f"Missing keys in JSON payload: {missing_keys}", code=400)

        ntasks = payload_dict["ntasks"]
        # Validate before calling submit_job: ntasks is only consumed by range() further down, so a bad
        # value would otherwise fail after the job had already been submitted. bool is excluded because
        # it is a subclass of int.
        if isinstance(ntasks, bool) or not isinstance(ntasks, int) or ntasks < 1:
            return yaptide_response(message="ntasks must be a positive integer", code=400)

        input_type = determine_input_type(payload_dict)
        if input_type is None:
            return error_validation_response()

        cluster = JobsBatch._select_cluster(payload_dict)
        if cluster is None:
            return error_validation_response({"message": "No clusters are available"})

        # create a new simulation in the database, not waiting for the job to finish
        job_id = datetime.now().strftime('%Y%m%d-%H%M%S-') + str(uuid.uuid4()) + PlatformType.BATCH.value
        simulation = BatchSimulationModel(user_id=user.id,
                                          cluster_id=cluster.id,
                                          job_id=job_id,
                                          sim_type=payload_dict["sim_type"],
                                          input_type=input_type,
                                          title=payload_dict.get("title", ''))
        # The same key is stored on the simulation and handed to submit_job.
        update_key = str(uuid.uuid4())
        simulation.set_update_key(update_key)
        add_object_to_db(simulation)

        input_dict = make_input_dict(payload_dict=payload_dict, input_type=input_type)

        result = submit_job(payload_dict=payload_dict, files_dict=input_dict["input_files"], user=user,
                            cluster=cluster, sim_id=simulation.id, update_key=update_key)

        # A result lacking any of these keys is treated as a failed submission: drop the simulation row.
        if not {"job_dir", "array_id", "collect_id"}.issubset(result):
            delete_object_from_db(simulation)
            return yaptide_response(
                message="Job submission failed",
                code=500,
                content=result
            )
        simulation.job_dir = result.pop("job_dir", None)
        simulation.array_id = result.pop("array_id", None)
        simulation.collect_id = result.pop("collect_id", None)
        result["job_id"] = simulation.job_id

        # Task ids are 1-based strings: "1" .. str(ntasks).
        for task_number in range(1, ntasks + 1):
            task = BatchTaskModel(simulation_id=simulation.id, task_id=str(task_number))
            # NOTE(review): second argument presumably defers the commit; the input model below is
            # added with the default and is expected to flush these tasks too — confirm.
            add_object_to_db(task, False)

        input_model = InputModel(simulation_id=simulation.id)
        input_model.data = input_dict
        add_object_to_db(input_model)
        if simulation.update_state({"job_state": EntityState.PENDING.value}):
            make_commit_to_db()

        return yaptide_response(
            message="Job submitted",
            code=202,
            content=result
        )

    class APIParametersSchema(Schema):
        """Query parameters accepted by ``get`` and ``delete``."""

        # Required: both handlers read params["job_id"] unconditionally; without this a missing
        # parameter passed validation and crashed with KeyError.
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: KeycloakUserModel):
        """Return the state of a batch job and of its tasks.

        Jobs already in COMPLETED/FAILED state are answered from the database. Otherwise the current
        status is fetched from the cluster, passed to ``update_simulation_state`` and returned (without
        ``end_time``) together with the per-task status list.
        """
        if not isinstance(user, KeycloakUserModel):
            return yaptide_response(message="User is not allowed to use this endpoint", code=403)

        simulation, error_response = JobsBatch._fetch_owned_simulation(user)
        if error_response is not None:
            return error_response

        tasks = fetch_batch_tasks_by_sim_id(sim_id=simulation.id)
        job_tasks_status = [task.get_status_dict() for task in tasks]

        if simulation.job_state in (EntityState.COMPLETED.value,
                                    EntityState.FAILED.value):
            return yaptide_response(message=f"Job state: {simulation.job_state}",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                        "job_tasks_status": job_tasks_status,
                                    })

        cluster = fetch_cluster_by_id(cluster_id=simulation.cluster_id)

        job_info = get_job_status(simulation=simulation, user=user, cluster=cluster)
        update_simulation_state(simulation=simulation, update_dict=job_info)

        # end_time is handed to update_simulation_state above but stripped from the response.
        job_info.pop("end_time", None)
        job_info["job_tasks_status"] = job_tasks_status

        return yaptide_response(
            message="",
            code=200,
            content=job_info
        )

    @staticmethod
    @requires_auth()
    def delete(user: KeycloakUserModel):
        """Cancel a running batch job and mark the simulation and all its tasks as CANCELED.

        Jobs in COMPLETED/FAILED/CANCELED/UNKNOWN state are left untouched (200 with the current state).
        If the cluster-side cancellation does not report 200, an internal-error response is returned and
        the database is not modified.
        """
        if not isinstance(user, KeycloakUserModel):
            return yaptide_response(message="User is not allowed to use this endpoint", code=403)

        simulation, error_response = JobsBatch._fetch_owned_simulation(user)
        if error_response is not None:
            return error_response

        if simulation.job_state in (EntityState.COMPLETED.value,
                                    EntityState.FAILED.value,
                                    EntityState.CANCELED.value,
                                    EntityState.UNKNOWN.value):
            return yaptide_response(message=f"Cannot cancel job which is in {simulation.job_state} state",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                    })

        cluster = fetch_cluster_by_id(cluster_id=simulation.cluster_id)

        result, status_code = delete_job(simulation=simulation, user=user, cluster=cluster)
        if status_code != 200:
            return error_internal_response(content=result)

        update_simulation_state(simulation=simulation, update_dict={"job_state": EntityState.CANCELED.value})

        for task in fetch_batch_tasks_by_sim_id(sim_id=simulation.id):
            update_task_state(task=task, update_dict={"task_state": EntityState.CANCELED.value})

        return yaptide_response(
            message="",
            code=status_code,
            content=result
        )

    @staticmethod
    def _select_cluster(payload_dict: dict):
        """Pick the cluster for a new job; return None when no cluster is configured.

        Uses ``batch_options.cluster_name`` from the payload when it matches a known cluster. Otherwise
        (no name given, name not found, or ``batch_options`` not a dict) falls back to the first cluster
        returned by ``fetch_all_clusters``.
        """
        clusters = fetch_all_clusters()
        if not clusters:
            return None
        batch_options = payload_dict.get("batch_options")
        if isinstance(batch_options, dict) and "cluster_name" in batch_options:
            cluster_name = batch_options["cluster_name"]
            for cluster in clusters:
                if cluster.cluster_name == cluster_name:
                    return cluster
        return clusters[0]

    @staticmethod
    def _fetch_owned_simulation(user: KeycloakUserModel):
        """Validate the ``job_id`` query parameter and load the caller's batch simulation.

        Returns a ``(simulation, error_response)`` pair. Exactly one element is not None: on success the
        simulation, otherwise the response the handler should return as-is.
        """
        schema = JobsBatch.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return None, error_validation_response(content=errors)
        params_dict: dict = schema.load(request.args)
        job_id: str = params_dict["job_id"]

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return None, yaptide_response(message=error_message, code=res_code)

        simulation = fetch_batch_simulation_by_job_id(job_id=job_id)
        # NOTE(review): the ownership check may accept job ids that are not batch jobs; guard against
        # fetch_batch_simulation_by_job_id returning None rather than crashing on attribute access.
        if simulation is None:
            return None, yaptide_response(message="Batch job with provided ID does not exist", code=404)
        return simulation, None

211 

212 

class Clusters(Resource):
    """Expose the names of the clusters that batch jobs can be sent to."""

    @staticmethod
    @requires_auth()
    def get(user: KeycloakUserModel):
        """List every cluster stored in the database by name.

        Only Keycloak-authenticated users may call this endpoint; any other user type receives 403.
        The response content has the shape ``{'clusters': [{'cluster_name': ...}, ...]}``.
        """
        if isinstance(user, KeycloakUserModel):
            available = fetch_all_clusters()
            entries = [{'cluster_name': entry.cluster_name} for entry in available]
            payload = {'clusters': entries}
            return yaptide_response(message='Available clusters', code=200, content=payload)

        return yaptide_response(message="User is not allowed to use this endpoint", code=403)