Coverage for yaptide/routes/common_sim_routes.py: 62%

191 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-22 07:31 +0000

1import logging 

2from collections import Counter 

3from datetime import datetime 

4from typing import Union 

5 

6from flask import request, current_app as app 

7from flask_restful import Resource 

8from marshmallow import Schema, fields 

9 

10from yaptide.persistence.db_methods import (add_object_to_db, fetch_estimator_by_sim_id_and_est_name, 

11 fetch_estimators_by_sim_id, fetch_input_by_sim_id, fetch_logfiles_by_sim_id, 

12 fetch_page_by_est_id_and_page_number, fetch_pages_by_estimator_id, 

13 fetch_simulation_by_job_id, fetch_simulation_by_sim_id, 

14 fetch_simulation_id_by_job_id, fetch_tasks_by_sim_id, make_commit_to_db, 

15 update_simulation_state) 

16from yaptide.persistence.models import (EstimatorModel, LogfilesModel, PageModel, UserModel) 

17from yaptide.routes.utils.decorators import requires_auth 

18from yaptide.routes.utils.response_templates import yaptide_response 

19from yaptide.routes.utils.utils import check_if_job_is_owned_and_exist 

20from yaptide.routes.utils.tokens import decode_auth_token 

21from yaptide.utils.enums import EntityState 

22 

23 

class JobsResource(Resource):
    """Class responsible for managing common jobs"""

    class APIParametersSchema(Schema):
        """Class specifies API parameters for GET and DELETE request"""

        # required: otherwise a request without job_id passes validation and
        # `param_dict['job_id']` raises KeyError (HTTP 500) instead of returning 400
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning info about job.

        Validates the ``job_id`` query parameter, checks that ``user`` owns the job and
        returns its state. Unless the state is UNKNOWN, the status dict of every task is
        included too. For jobs that are not COMPLETED/FAILED, the job state is recomputed
        from the task states and persisted via ``update_simulation_state``.

        :param user: authenticated user, injected by ``requires_auth``
        :return: ``yaptide_response`` (400 on bad parameters, ownership error code from
            ``check_if_job_is_owned_and_exist``, otherwise 200 with ``job_state`` and
            ``job_tasks_status``)
        """
        schema = JobsResource.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        # get job_id from request parameters and check if user owns this job
        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_simulation_by_job_id(job_id=job_id)
        if simulation.job_state == EntityState.UNKNOWN.value:
            return yaptide_response(message="Job state is unknown",
                                    code=200,
                                    content={"job_state": simulation.job_state})

        tasks = fetch_tasks_by_sim_id(sim_id=simulation.id)
        job_tasks_status = [task.get_status_dict() for task in tasks]

        # finished jobs: report the stored state without recomputing it
        if simulation.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value):
            return yaptide_response(message=f"Job state: {simulation.job_state}",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                        "job_tasks_status": job_tasks_status,
                                    })

        # Derive the job state from the task states:
        # all PENDING -> PENDING, all FAILED -> FAILED, any RUNNING -> RUNNING,
        # otherwise keep the stored state.
        # NOTE(review): with zero tasks the first branch matches (0 == 0), so the job
        # is reported as PENDING.
        job_info = {"job_state": simulation.job_state}
        status_counter = Counter(task["task_state"] for task in job_tasks_status)
        task_count = len(job_tasks_status)
        if status_counter[EntityState.PENDING.value] == task_count:
            job_info["job_state"] = EntityState.PENDING.value
        elif status_counter[EntityState.FAILED.value] == task_count:
            job_info["job_state"] = EntityState.FAILED.value
        elif status_counter[EntityState.RUNNING.value] > 0:
            job_info["job_state"] = EntityState.RUNNING.value

        update_simulation_state(simulation=simulation, update_dict=job_info)

        job_info["job_tasks_status"] = job_tasks_status

        return yaptide_response(message=f"Job state: {job_info['job_state']}", code=200, content=job_info)

    @staticmethod
    def post():
        """Handles requests for updating simulation information in db.

        Expects a JSON object containing at least ``sim_id``. The whole payload is passed
        to ``update_simulation_state``. If the payload has a truthy ``log`` entry, it is
        additionally stored as a ``LogfilesModel`` row for the simulation.

        :return: ``yaptide_response`` with 400 when ``sim_id`` is missing or the body is
            not a JSON object, 501 when the simulation does not exist, otherwise 202
        """
        # NOTE(review): unlike ResultsResource.post / LogfilesResource.post, this endpoint
        # does not verify an ``update_key`` - confirm whether that is intended.
        payload_dict: dict = request.get_json(force=True)
        if not isinstance(payload_dict, dict) or "sim_id" not in payload_dict:
            return yaptide_response(message="Missing sim_id in JSON data", code=400)
        sim_id: int = payload_dict["sim_id"]
        app.logger.info(f"sim_id {sim_id}")
        simulation = fetch_simulation_by_sim_id(sim_id=sim_id)

        if not simulation:
            app.logger.info(f"sim_id {sim_id} simulation not found ")
            # NOTE(review): 501 (Not Implemented) is unusual for a missing resource;
            # kept unchanged because callers may rely on it.
            return yaptide_response(message=f"Simulation {sim_id} does not exist", code=501)
        update_simulation_state(simulation, payload_dict)

        # "log" is optional; indexing it directly raised KeyError for updates without logs
        log_data = payload_dict.get("log")
        if log_data:
            logfiles = LogfilesModel(simulation_id=simulation.id)
            logfiles.data = log_data
            add_object_to_db(logfiles)

        return yaptide_response(message="Task updated", code=202)

99 

100 

def get_single_estimator(sim_id: int, estimator_name: str):
    """Build the response for one named estimator of a simulation.

    :param sim_id: database id of the simulation owning the estimator
    :param estimator_name: name of the estimator to look up
    :return: ``yaptide_response`` with 404 when no such estimator exists, otherwise 200
        with ``{"metadata": ..., "name": ..., "pages": [<page data>, ...]}``
    """
    found = fetch_estimator_by_sim_id_and_est_name(sim_id=sim_id, est_name=estimator_name)
    if not found:
        return yaptide_response(message="Estimator not found", code=404)

    estimator_pages = fetch_pages_by_estimator_id(est_id=found.id)
    page_payloads = [single_page.data for single_page in estimator_pages]
    content = {
        "metadata": found.data,
        "name": found.name,
        "pages": page_payloads,
    }
    response_message = f"Estimator '{estimator_name}' for simulation: {sim_id}"
    return yaptide_response(message=response_message, code=200, content=content)

112 

113 

def get_all_estimators(sim_id: int):
    """Retrieve all estimators (with their page data) for a given simulation ID.

    :param sim_id: database id of the simulation
    :return: ``yaptide_response`` with 404 when the simulation has no estimators,
        otherwise 200 with ``{"estimators": [{"metadata", "name", "pages"}, ...]}``
    """
    estimators = fetch_estimators_by_sim_id(sim_id=sim_id)
    # truthiness instead of len(...) == 0: same result for a list, and also safe
    # if the fetch helper ever returns None
    if not estimators:
        return yaptide_response(message="Results are unavailable", code=404)

    logging.debug("Returning results from database")
    result_estimators = [
        {
            "metadata": estimator.data,
            "name": estimator.name,
            "pages": [page.data for page in estimator.pages],
        }
        for estimator in estimators
    ]
    return yaptide_response(message=f"Results for simulation: {sim_id}",
                            code=200,
                            content={"estimators": result_estimators})

132 

133 

class ResultsResource(Resource):
    """Class responsible for managing results"""

    @staticmethod
    def post():
        """
        Method for saving results
        Used by the jobs at the end of simulation
        Structure required by this method to work properly:
        {
            "simulation_id": <int>,
            "update_key": <string>,
            "estimators": [
                {
                    "name": <string>,
                    "metadata": <any>,
                    "pages": [{"metadata": {"page_number": <int-convertible>, ...}, ...}, ...]
                },
                ...
            ]
        }
        Extra top-level keys are ignored. An estimator that already exists (e.g. created
        earlier from partial results) is reused and its metadata is not overwritten;
        page data is always overwritten. Afterwards the simulation is marked COMPLETED.
        """
        payload_dict: dict = request.get_json(force=True)
        required_keys = {"simulation_id", "update_key", "estimators"}
        # Subset check: only missing keys make the payload incomplete. A non-object JSON
        # body is rejected here instead of crashing on `.keys()`.
        if not isinstance(payload_dict, dict) or not required_keys.issubset(payload_dict):
            return yaptide_response(message="Incomplete JSON data", code=400)

        sim_id = payload_dict["simulation_id"]
        simulation = fetch_simulation_by_sim_id(sim_id=sim_id)

        if not simulation:
            return yaptide_response(message="Simulation does not exist", code=400)

        decoded_token = decode_auth_token(payload_dict["update_key"], payload_key_to_return="simulation_id")
        if decoded_token != sim_id:
            return yaptide_response(message="Invalid update key", code=400)

        for estimator_dict in payload_dict["estimators"]:
            # We foresee the possibility of the estimator being created earlier as element of partial results
            estimator = fetch_estimator_by_sim_id_and_est_name(sim_id=sim_id, est_name=estimator_dict["name"])

            if not estimator:
                estimator = EstimatorModel(name=estimator_dict["name"], simulation_id=simulation.id)
                estimator.data = estimator_dict["metadata"]
                # presumably commits, so that estimator.id is available for the pages below - verify
                add_object_to_db(estimator)

            for page_dict in estimator_dict["pages"]:
                page_number = int(page_dict["metadata"]["page_number"])
                page = fetch_page_by_est_id_and_page_number(est_id=estimator.id, page_number=page_number)
                if page:
                    # existing page: overwrite its data, persisted by the commit after the loop
                    page.data = page_dict
                else:
                    page = PageModel(page_number=page_number, estimator_id=estimator.id)
                    page.data = page_dict
                    # add to the session without committing; committed once below
                    add_object_to_db(page, False)

        make_commit_to_db()
        logging.debug("Marking simulation as completed")
        update_dict = {"job_state": EntityState.COMPLETED.value, "end_time": datetime.utcnow().isoformat(sep=" ")}
        update_simulation_state(simulation=simulation, update_dict=update_dict)

        # NOTE(review): no task rows are updated here despite this log message - confirm
        # whether task states should be set to COMPLETED as well.
        logging.debug("Marking simulation tasks as completed")

        return yaptide_response(message="Results saved", code=202)

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # required: otherwise a missing job_id surfaces as KeyError (HTTP 500) in get()
        job_id = fields.String(required=True)
        estimator_name = fields.String(load_default=None)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning job status and results.
        If `estimator_name` parameter is provided,
        the response will include results only for that specific estimator,
        otherwise it will return all estimators for the given job.
        """
        schema = ResultsResource.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        job_id = param_dict['job_id']
        estimator_name = param_dict['estimator_name']

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation_id = fetch_simulation_id_by_job_id(job_id=job_id)
        if not simulation_id:
            return yaptide_response(message="Simulation does not exist", code=404)

        # if estimator name is provided, return specific estimator
        if estimator_name:
            return get_single_estimator(sim_id=simulation_id, estimator_name=estimator_name)

        return get_all_estimators(sim_id=simulation_id)

231 

232 

class InputsResource(Resource):
    """Class responsible for returning simulation input"""

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # required: otherwise a missing job_id surfaces as KeyError (HTTP 500) in get()
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning simulation input.

        :param user: authenticated user, injected by ``requires_auth``
        :return: ``yaptide_response`` with 400 on bad parameters, the ownership error from
            ``check_if_job_is_owned_and_exist``, 404 when no input is stored, otherwise 200
            with ``{"input": <input data>}``
        """
        schema = InputsResource.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)
        job_id = param_dict['job_id']

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_simulation_by_job_id(job_id=job_id)

        input_model = fetch_input_by_sim_id(sim_id=simulation.id)
        if not input_model:
            return yaptide_response(message="Input of simulation is unavailable", code=404)

        return yaptide_response(message="Input of simulation", code=200, content={"input": input_model.data})

263 

264 

class LogfilesResource(Resource):
    """Class responsible for managing logfiles"""

    @staticmethod
    def post():
        """
        Method for saving logfiles
        Used by the jobs when the simulation fails
        Structure required by this method to work properly:
        {
            "simulation_id": <int>,
            "update_key": <string>,
            "logfiles": <dict>
        }
        Extra top-level keys are ignored.
        """
        payload_dict: dict = request.get_json(force=True)
        required_keys = {"simulation_id", "update_key", "logfiles"}
        # Subset check: only missing keys make the payload incomplete. A non-object JSON
        # body is rejected here instead of crashing on `.keys()`.
        if not isinstance(payload_dict, dict) or not required_keys.issubset(payload_dict):
            return yaptide_response(message="Incomplete JSON data", code=400)

        sim_id = payload_dict["simulation_id"]
        simulation = fetch_simulation_by_sim_id(sim_id=sim_id)

        if not simulation:
            return yaptide_response(message="Simulation does not exist", code=400)

        decoded_token = decode_auth_token(payload_dict["update_key"], payload_key_to_return="simulation_id")
        if decoded_token != sim_id:
            return yaptide_response(message="Invalid update key", code=400)

        logfiles = LogfilesModel(simulation_id=simulation.id)
        logfiles.data = payload_dict["logfiles"]
        add_object_to_db(logfiles)

        return yaptide_response(message="Log files saved", code=202)

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # required: otherwise a missing job_id surfaces as KeyError (HTTP 500) in get()
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning the logfiles stored for a job's simulation.

        :param user: authenticated user, injected by ``requires_auth``
        :return: ``yaptide_response`` with 400 on bad parameters, the ownership error from
            ``check_if_job_is_owned_and_exist``, 404 when no logfiles are stored, otherwise
            200 with ``{"logfiles": <logfiles data>}``
        """
        # Use this resource's own schema; it previously validated against
        # ResultsResource.APIParametersSchema by mistake.
        schema = LogfilesResource.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_simulation_by_job_id(job_id=job_id)

        logfile = fetch_logfiles_by_sim_id(sim_id=simulation.id)
        if not logfile:
            return yaptide_response(message="Logfiles are unavailable", code=404)

        logging.debug("Returning logfiles from database")

        return yaptide_response(message="Logfiles", code=200, content={"logfiles": logfile.data})