Coverage for yaptide/routes/common_sim_routes.py: 59%

174 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-07-01 12:55 +0000

1import logging 

2from collections import Counter 

3from datetime import datetime 

4 

5from flask import request 

6from flask_restful import Resource 

7from marshmallow import Schema, fields 

8 

9from yaptide.batch.batch_methods import get_job_results 

10from yaptide.persistence.db_methods import ( 

11 add_object_to_db, fetch_cluster_by_id, 

12 fetch_estimator_by_sim_id_and_est_name, fetch_estimators_by_sim_id, 

13 fetch_input_by_sim_id, fetch_logfiles_by_sim_id, 

14 fetch_page_by_est_id_and_page_number, fetch_pages_by_estimator_id, 

15 fetch_simulation_by_job_id, fetch_simulation_by_sim_id, 

16 fetch_tasks_by_sim_id, make_commit_to_db, update_simulation_state) 

17from yaptide.persistence.models import (BatchSimulationModel, EstimatorModel, 

18 LogfilesModel, PageModel, UserModel) 

19from yaptide.routes.utils.decorators import requires_auth 

20from yaptide.routes.utils.response_templates import yaptide_response 

21from yaptide.routes.utils.utils import check_if_job_is_owned_and_exist 

22from yaptide.utils.enums import EntityState 

23 

24 

class JobsResource(Resource):
    """Class responsible for managing common jobs"""

    class APIParametersSchema(Schema):
        """Class specifies API parameters for GET and DELETE request"""

        # Required: handlers index param_dict['job_id'] directly, so a missing
        # value must be rejected by validation (HTTP 400) instead of surfacing
        # later as an unhandled KeyError (HTTP 500).
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning info about job

        Validates the ``job_id`` query parameter, verifies access through
        ``check_if_job_is_owned_and_exist`` and returns the job state together
        with the status dict of every task.

        For jobs that are not in a terminal state (COMPLETED / FAILED) the job
        state is re-derived from the task states and persisted:

        * every task PENDING (this includes having no tasks at all) -> PENDING
        * every task FAILED -> FAILED
        * at least one task RUNNING -> RUNNING
        * otherwise the stored state is kept unchanged.

        Args:
            user: authenticated user injected by ``requires_auth``.

        Returns:
            ``yaptide_response`` with code 400 for invalid parameters, the code
            reported by ``check_if_job_is_owned_and_exist`` when access is
            refused, and 200 otherwise.
        """
        schema = JobsResource.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        # get job_id from request parameters and check if user owns this job
        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_simulation_by_job_id(job_id=job_id)
        if simulation.job_state == EntityState.UNKNOWN.value:
            return yaptide_response(message="Job state is unknown",
                                    code=200,
                                    content={"job_state": simulation.job_state})

        tasks = fetch_tasks_by_sim_id(sim_id=simulation.id)
        job_tasks_status = [task.get_status_dict() for task in tasks]

        # Terminal states are final: report them without recomputing from tasks.
        if simulation.job_state in (EntityState.COMPLETED.value,
                                    EntityState.FAILED.value):
            return yaptide_response(message=f"Job state: {simulation.job_state}",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                        "job_tasks_status": job_tasks_status,
                                    })

        job_info = {"job_state": simulation.job_state}
        status_counter = Counter(task["task_state"] for task in job_tasks_status)
        task_count = len(job_tasks_status)
        if status_counter[EntityState.PENDING.value] == task_count:
            job_info["job_state"] = EntityState.PENDING.value
        elif status_counter[EntityState.FAILED.value] == task_count:
            job_info["job_state"] = EntityState.FAILED.value
        elif status_counter[EntityState.RUNNING.value] > 0:
            job_info["job_state"] = EntityState.RUNNING.value

        update_simulation_state(simulation=simulation, update_dict=job_info)

        # Task statuses are added only after persisting, so they are not part
        # of the update_dict written to the simulation.
        job_info["job_tasks_status"] = job_tasks_status

        return yaptide_response(message=f"Job state: {job_info['job_state']}", code=200, content=job_info)

84 

85 

class ResultsResource(Resource):
    """Class responsible for managing results"""

    @staticmethod
    def post():
        """
        Method for saving results
        Used by the jobs at the end of simulation
        Structure required by this method to work properly:
        {
            "simulation_id": <int>,
            "update_key": <string>,
            "estimators": <dict>
        }

        Existing estimators/pages (e.g. stored earlier as partial results) are
        reused; page data is always overwritten. After saving, the simulation is
        marked COMPLETED with the current UTC time as ``end_time``.

        Returns:
            ``yaptide_response`` with code 400 for malformed payloads, an unknown
            simulation or an invalid update key, and 202 once results are saved.
        """
        payload_dict = request.get_json(force=True)
        # The payload is untrusted: valid JSON that is not an object (list,
        # null, number...) would otherwise crash on .keys() with a 500.
        if not isinstance(payload_dict, dict) or \
                {"simulation_id", "update_key", "estimators"} != set(payload_dict.keys()):
            return yaptide_response(message="Incomplete JSON data", code=400)

        sim_id = payload_dict["simulation_id"]
        simulation = fetch_simulation_by_sim_id(sim_id=sim_id)

        if not simulation:
            return yaptide_response(message="Simulation does not exist", code=400)

        if not simulation.check_update_key(payload_dict["update_key"]):
            return yaptide_response(message="Invalid update key", code=400)

        for estimator_dict in payload_dict["estimators"]:
            # We foresee the possibility of the estimator being created earlier as element of partial results
            estimator = fetch_estimator_by_sim_id_and_est_name(sim_id=sim_id, est_name=estimator_dict["name"])

            if not estimator:
                estimator = EstimatorModel(name=estimator_dict["name"], simulation_id=simulation.id)
                estimator.data = estimator_dict["metadata"]
                # Added immediately (default arguments) so that estimator.id is
                # available for the pages below.
                add_object_to_db(estimator)

            for page_dict in estimator_dict["pages"]:
                page_number = int(page_dict["metadata"]["page_number"])
                page = fetch_page_by_est_id_and_page_number(est_id=estimator.id, page_number=page_number)

                page_existed = bool(page)
                if not page_existed:
                    # create new page
                    page = PageModel(page_number=page_number, estimator_id=estimator.id)
                # we always update the data
                page.data = page_dict
                if not page_existed:
                    # if page was created, we add it to the session
                    add_object_to_db(page, False)

        make_commit_to_db()
        logging.debug("Marking simulation as completed")
        update_dict = {
            "job_state": EntityState.COMPLETED.value,
            # NOTE: naive UTC timestamp; format kept as-is for stored-data compatibility.
            "end_time": datetime.utcnow().isoformat(sep=" ")
        }
        update_simulation_state(simulation=simulation, update_dict=update_dict)
        return yaptide_response(message="Results saved", code=202)

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # Required: get() indexes param_dict['job_id'] directly, so a missing
        # value must fail validation (400) instead of raising KeyError (500).
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning job status and results

        Returns the estimators (metadata, name and page data) stored in the
        database for the job identified by the ``job_id`` query parameter.
        For batch simulations without stored estimators, results are fetched
        once via ``get_job_results`` and persisted (legacy path, see below).

        Returns:
            ``yaptide_response`` with code 400 for invalid parameters, the code
            from ``check_if_job_is_owned_and_exist`` when access is refused,
            404 when results are unavailable and 200 with the estimators otherwise.
        """
        schema = ResultsResource.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_simulation_by_job_id(job_id=job_id)

        estimators = fetch_estimators_by_sim_id(sim_id=simulation.id)
        if not estimators:
            if not isinstance(simulation, BatchSimulationModel):  # also CODE TO REMOVE
                return yaptide_response(message="Results are unavailable", code=404)
            # Code below is for backward compatibility with old method of saving results
            # later on we are going to remove it because its functionality will be covered
            # by the post method
            # BEGIN CODE TO REMOVE

            cluster = fetch_cluster_by_id(cluster_id=simulation.cluster_id)

            result: dict = get_job_results(simulation=simulation, user=user, cluster=cluster)
            if "estimators" not in result:
                logging.debug("Results for job %s are unavailable", job_id)
                return yaptide_response(message="Results are unavailable", code=404, content=result)

            for estimator_dict in result["estimators"]:
                estimator = EstimatorModel(name=estimator_dict["name"], simulation_id=simulation.id)
                estimator.data = estimator_dict["metadata"]
                add_object_to_db(estimator)
                for page_dict in estimator_dict["pages"]:
                    page = PageModel(estimator_id=estimator.id,
                                     page_number=int(page_dict["metadata"]["page_number"]))
                    page.data = page_dict
                    add_object_to_db(page, False)
            make_commit_to_db()
            estimators = fetch_estimators_by_sim_id(sim_id=simulation.id)
            # END CODE TO REMOVE

        logging.debug("Returning results from database")
        result_estimators = []
        for estimator in estimators:
            pages = fetch_pages_by_estimator_id(est_id=estimator.id)
            result_estimators.append({
                "metadata": estimator.data,
                "name": estimator.name,
                "pages": [page.data for page in pages]
            })
        return yaptide_response(message=f"Results for job: {job_id}", code=200,
                                content={"estimators": result_estimators})

209 

210 

class InputsResource(Resource):
    """Class responsible for returning simulation input"""

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # Required: get() indexes param_dict['job_id'] directly, so a missing
        # value must fail validation (400) instead of raising KeyError (500).
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning simulation input

        Looks up the job given by the ``job_id`` query parameter, verifies
        access through ``check_if_job_is_owned_and_exist`` and returns the
        stored input data of its simulation.

        Returns:
            ``yaptide_response`` with code 400 for invalid parameters, the code
            from ``check_if_job_is_owned_and_exist`` when access is refused,
            404 when no input is stored and 200 with ``{"input": ...}`` otherwise.
        """
        schema = InputsResource.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)
        job_id = param_dict['job_id']

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_simulation_by_job_id(job_id=job_id)

        input_model = fetch_input_by_sim_id(sim_id=simulation.id)
        if not input_model:
            return yaptide_response(message="Input of simulation is unavailable", code=404)

        return yaptide_response(message="Input of simulation", code=200, content={"input": input_model.data})

241 

242 

class LogfilesResource(Resource):
    """Class responsible for managing logfiles"""

    @staticmethod
    def post():
        """
        Method for saving logfiles
        Used by the jobs when the simulation fails
        Structure required by this method to work properly:
        {
            "simulation_id": <int>,
            "update_key": <string>,
            "logfiles": <dict>
        }

        Returns:
            ``yaptide_response`` with code 400 for malformed payloads, an unknown
            simulation or an invalid update key, and 202 once logfiles are saved.
        """
        payload_dict = request.get_json(force=True)
        # The payload is untrusted: valid JSON that is not an object (list,
        # null, number...) would otherwise crash on .keys() with a 500.
        if not isinstance(payload_dict, dict) or \
                {"simulation_id", "update_key", "logfiles"} != set(payload_dict.keys()):
            return yaptide_response(message="Incomplete JSON data", code=400)

        sim_id = payload_dict["simulation_id"]
        simulation = fetch_simulation_by_sim_id(sim_id=sim_id)

        if not simulation:
            return yaptide_response(message="Simulation does not exist", code=400)

        if not simulation.check_update_key(payload_dict["update_key"]):
            return yaptide_response(message="Invalid update key", code=400)

        logfiles = LogfilesModel(simulation_id=simulation.id)
        logfiles.data = payload_dict["logfiles"]
        add_object_to_db(logfiles)

        return yaptide_response(message="Log files saved", code=202)

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # Required: get() indexes param_dict['job_id'] directly, so a missing
        # value must fail validation (400) instead of raising KeyError (500).
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning simulation logfiles

        Looks up the job given by the ``job_id`` query parameter, verifies
        access through ``check_if_job_is_owned_and_exist`` and returns the
        logfiles stored for its simulation.

        Returns:
            ``yaptide_response`` with code 400 for invalid parameters, the code
            from ``check_if_job_is_owned_and_exist`` when access is refused,
            404 when no logfiles are stored and 200 with ``{"logfiles": ...}`` otherwise.
        """
        # Use this resource's own schema (previously borrowed ResultsResource's).
        schema = LogfilesResource.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_simulation_by_job_id(job_id=job_id)

        logfile = fetch_logfiles_by_sim_id(sim_id=simulation.id)
        if not logfile:
            return yaptide_response(message="Logfiles are unavailable", code=404)

        logging.debug("Returning logfiles from database")

        return yaptide_response(message="Logfiles", code=200, content={"logfiles": logfile.data})

301 return yaptide_response(message="Logfiles are unavailable", code=404) 

302 

303 logging.debug("Returning logfiles from database") 

304 

305 return yaptide_response(message="Logfiles", code=200, content={"logfiles": logfile.data})