Coverage for yaptide/routes/batch_routes.py: 25%

114 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-04 00:31 +0000

1import uuid 

2from datetime import datetime 

3 

4from flask import request 

5from flask_restful import Resource 

6from marshmallow import Schema, fields 

7 

8from yaptide.batch.batch_methods import delete_job, get_job_status, submit_job 

9from yaptide.persistence.db_methods import (add_object_to_db, fetch_all_clusters, fetch_batch_simulation_by_job_id, 

10 fetch_batch_tasks_by_sim_id, fetch_cluster_by_id, make_commit_to_db, 

11 update_simulation_state, update_task_state) 

12from yaptide.persistence.models import ( # skipcq: FLK-E101 

13 BatchSimulationModel, BatchTaskModel, ClusterModel, InputModel, KeycloakUserModel) 

14from yaptide.routes.utils.tokens import encode_simulation_auth_token 

15from yaptide.routes.utils.decorators import requires_auth 

16from yaptide.routes.utils.response_templates import (error_validation_response, error_internal_response, 

17 yaptide_response) 

18from yaptide.routes.utils.utils import check_if_job_is_owned_and_exist, determine_input_type, make_input_dict 

19from yaptide.utils.enums import EntityState, PlatformType 

20 

21 

class JobsBatch(Resource):
    """Class responsible for jobs via direct slurm connection"""

    @staticmethod
    @requires_auth()
    def post(user: KeycloakUserModel):
        """Create a batch simulation and queue its submission.

        The JSON body must contain ``sim_type``, ``ntasks`` and ``input_type``;
        ``title`` and ``batch_options.cluster_name`` are optional.

        Returns:
            403 when the caller is not a Keycloak user,
            400 when the body is empty, required keys are missing or ``ntasks`` is not an integer,
            a validation error response when the input type cannot be determined or no cluster exists,
            202 with ``{'job_id': ...}`` once the job has been queued.
        """
        if not isinstance(user, KeycloakUserModel):
            return yaptide_response(message="User is not allowed to use this endpoint", code=403)

        payload_dict: dict = request.get_json(force=True)
        if not payload_dict:
            return yaptide_response(message="No JSON in body", code=400)

        required_keys = {"sim_type", "ntasks", "input_type"}
        diff = required_keys.difference(payload_dict)
        if diff:
            return yaptide_response(message=f"Missing keys in JSON payload: {diff}", code=400)

        # Reject a malformed task count up front: range() below would otherwise raise
        # only after the simulation row has been stored and the job already queued.
        ntasks = payload_dict["ntasks"]
        if not isinstance(ntasks, int):
            return yaptide_response(message="ntasks must be an integer", code=400)

        input_type = determine_input_type(payload_dict)
        if input_type is None:
            return error_validation_response()

        clusters = fetch_all_clusters()
        if not clusters:
            return error_validation_response({"message": "No clusters are available"})

        # Use the requested cluster when its name matches; otherwise fall back to the first one.
        cluster: ClusterModel = clusters[0]
        if "batch_options" in payload_dict and "cluster_name" in payload_dict["batch_options"]:
            cluster_name = payload_dict["batch_options"]["cluster_name"]
            cluster = next((candidate for candidate in clusters if candidate.cluster_name == cluster_name),
                           clusters[0])

        # create a new simulation in the database, not waiting for the job to finish
        job_id = datetime.now().strftime('%Y%m%d-%H%M%S-') + str(uuid.uuid4()) + PlatformType.BATCH.value
        # skipcq: PYL-E1123
        simulation = BatchSimulationModel(user_id=user.id,
                                          cluster_id=cluster.id,
                                          job_id=job_id,
                                          sim_type=payload_dict["sim_type"],
                                          input_type=input_type,
                                          title=payload_dict.get("title", ''))
        add_object_to_db(simulation)
        update_key = encode_simulation_auth_token(simulation.id)

        input_dict = make_input_dict(payload_dict=payload_dict, input_type=input_type)

        # NOTE(review): the job is queued before the task rows and the input are stored below;
        # confirm the worker never needs them immediately.
        submit_job.delay(payload_dict=payload_dict,
                         files_dict=input_dict["input_files"],
                         userId=user.id,
                         clusterId=cluster.id,
                         sim_id=simulation.id,
                         update_key=update_key)

        # Task ids are 1-based strings.
        for i in range(ntasks):
            task = BatchTaskModel(simulation_id=simulation.id, task_id=str(i + 1))
            add_object_to_db(task, False)

        input_model = InputModel(simulation_id=simulation.id)
        input_model.data = input_dict
        add_object_to_db(input_model)
        if simulation.update_state({"job_state": EntityState.PENDING.value}):
            make_commit_to_db()

        return yaptide_response(message="Job waiting for submission", code=202, content={'job_id': simulation.job_id})

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # Required: both get() and delete() read params_dict["job_id"] unconditionally,
        # so a missing value must be reported by validation instead of raising KeyError.
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: KeycloakUserModel):
        """Method getting job's status

        Expects ``job_id`` as a query parameter.

        Returns:
            403 when the caller is not a Keycloak user,
            a validation error response when the query parameters are invalid,
            the code reported by ``check_if_job_is_owned_and_exist`` when the ownership check fails,
            200 with the job state and the per-task status list otherwise.
        """
        if not isinstance(user, KeycloakUserModel):
            return yaptide_response(message="User is not allowed to use this endpoint", code=403)

        schema = JobsBatch.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return error_validation_response(content=errors)
        params_dict: dict = schema.load(request.args)

        job_id: str = params_dict["job_id"]

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)
        simulation = fetch_batch_simulation_by_job_id(job_id=job_id)

        tasks = fetch_batch_tasks_by_sim_id(sim_id=simulation.id)

        job_tasks_status = [task.get_status_dict() for task in tasks]

        # Completed and failed jobs are answered from the database without querying the cluster.
        if simulation.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value):
            return yaptide_response(message=f"Job state: {simulation.job_state}",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                        "job_tasks_status": job_tasks_status,
                                    })

        cluster = fetch_cluster_by_id(cluster_id=simulation.cluster_id)

        job_info = get_job_status(simulation=simulation, user=user, cluster=cluster)
        update_simulation_state(simulation=simulation, update_dict=job_info)

        # end_time is passed to the state update above but is not included in the response.
        job_info.pop("end_time", None)
        job_info["job_tasks_status"] = job_tasks_status

        return yaptide_response(message="", code=200, content=job_info)

    @staticmethod
    @requires_auth()
    def delete(user: KeycloakUserModel):
        """Method canceling job

        Expects ``job_id`` as a query parameter.

        Returns:
            403 when the caller is not a Keycloak user,
            a validation error response when the query parameters are invalid,
            the code reported by ``check_if_job_is_owned_and_exist`` when the ownership check fails,
            200 with the current state when the job is already completed, failed, canceled or unknown,
            an internal error response when ``delete_job`` does not report 200,
            otherwise the result and status code returned by ``delete_job``.
        """
        if not isinstance(user, KeycloakUserModel):
            return yaptide_response(message="User is not allowed to use this endpoint", code=403)

        schema = JobsBatch.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return error_validation_response(content=errors)
        params_dict: dict = schema.load(request.args)

        job_id: str = params_dict["job_id"]

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_batch_simulation_by_job_id(job_id=job_id)

        # Jobs in these states are not sent to the cluster for cancellation.
        if simulation.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value,
                                    EntityState.UNKNOWN.value):
            return yaptide_response(message=f"Cannot cancel job which is in {simulation.job_state} state",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                    })

        cluster = fetch_cluster_by_id(cluster_id=simulation.cluster_id)

        result, status_code = delete_job(simulation=simulation, user=user, cluster=cluster)
        if status_code != 200:
            return error_internal_response(content=result)

        # Mark the simulation and each of its tasks as canceled.
        update_simulation_state(simulation=simulation, update_dict={"job_state": EntityState.CANCELED.value})

        tasks = fetch_batch_tasks_by_sim_id(sim_id=simulation.id)

        for task in tasks:
            update_task_state(task=task, update_dict={"task_state": EntityState.CANCELED.value})

        return yaptide_response(message="", code=status_code, content=result)

180 

181 

class Clusters(Resource):
    """Resource listing the clusters known to the application"""

    @staticmethod
    @requires_auth()
    def get(user: KeycloakUserModel):
        """Return the name of every cluster yielded by ``fetch_all_clusters``.

        Callers that are not Keycloak users receive a 403 response; everyone else
        gets a 200 response whose content is ``{'clusters': [{'cluster_name': ...}, ...]}``.
        """
        if not isinstance(user, KeycloakUserModel):
            return yaptide_response(message="User is not allowed to use this endpoint", code=403)

        # Only the cluster name is included in each entry.
        cluster_entries = []
        for available_cluster in fetch_all_clusters():
            cluster_entries.append({'cluster_name': available_cluster.cluster_name})

        payload = {'clusters': cluster_entries}
        return yaptide_response(message='Available clusters', code=200, content=payload)