Coverage for yaptide/routes/celery_routes.py: 40%

151 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-07-01 12:55 +0000

1import logging 

2import uuid 

3from collections import Counter 

4from datetime import datetime 

5 

6from flask import request 

7from flask_restful import Resource 

8from marshmallow import Schema, fields 

9 

10from yaptide.celery.tasks import convert_input_files 

11from yaptide.celery.utils.manage_tasks import (cancel_job, get_job_results, 

12 run_job) 

13from yaptide.persistence.db_methods import (add_object_to_db, 

14 fetch_celery_simulation_by_job_id, 

15 fetch_celery_tasks_by_sim_id, 

16 fetch_estimators_by_sim_id, 

17 fetch_pages_by_estimator_id, 

18 make_commit_to_db, 

19 update_simulation_state, 

20 update_task_state) 

21from yaptide.persistence.models import (CelerySimulationModel, CeleryTaskModel, 

22 EstimatorModel, InputModel, PageModel, 

23 UserModel) 

24from yaptide.routes.utils.decorators import requires_auth 

25from yaptide.routes.utils.response_templates import (error_internal_response, 

26 error_validation_response, 

27 yaptide_response) 

28from yaptide.routes.utils.utils import check_if_job_is_owned_and_exist, determine_input_type, make_input_dict 

29from yaptide.utils.enums import EntityState, PlatformType 

30 

31 

class JobsDirect(Resource):
    """Resource for simulations run directly with celery: submit (POST), poll status (GET), cancel (DELETE)"""

    @staticmethod
    @requires_auth()
    def post(user: UserModel):
        """Submit simulation job to celery

        Expects a JSON body containing at least `sim_type`, `ntasks` and `input_type`
        (`title` is optional). Creates the simulation, its tasks and its input record in
        the database, dispatches the job through `run_job` and answers with code 202 and
        the generated `job_id`. Malformed payloads are answered with code 400.

        `user` is presumably injected by the `requires_auth` decorator - confirm there.
        """
        payload_dict: dict = request.get_json(force=True)
        if not payload_dict:
            return yaptide_response(message="No JSON in body", code=400)

        required_keys = {"sim_type", "ntasks", "input_type"}

        diff = required_keys.difference(payload_dict.keys())
        if diff:
            return yaptide_response(message=f"Missing keys in JSON payload: {diff}", code=400)

        # Validate ntasks before anything is written to the database: it is later passed to
        # range(), which would raise on a non-integer and leave an orphan simulation row behind.
        # bool is a subclass of int, so it has to be excluded explicitly.
        ntasks = payload_dict["ntasks"]
        if isinstance(ntasks, bool) or not isinstance(ntasks, int) or ntasks < 1:
            return yaptide_response(message="ntasks must be a positive integer", code=400)

        input_type = determine_input_type(payload_dict)

        if input_type is None:
            return error_validation_response()

        # create a new simulation in the database, not waiting for the job to finish
        # job_id layout: <timestamp>-<uuid4><platform suffix>
        job_id = datetime.now().strftime('%Y%m%d-%H%M%S-') + str(uuid.uuid4()) + PlatformType.DIRECT.value
        simulation = CelerySimulationModel(user_id=user.id,
                                           job_id=job_id,
                                           sim_type=payload_dict["sim_type"],
                                           input_type=input_type,
                                           title=payload_dict.get("title", ''))
        update_key = str(uuid.uuid4())
        simulation.set_update_key(update_key)
        add_object_to_db(simulation)
        logging.info("Simulation %d created and inserted into DB", simulation.id)
        logging.debug("Update key set to %s", update_key)

        input_dict = make_input_dict(payload_dict=payload_dict, input_type=input_type)

        # submit the asynchronous job to celery
        simulation.merge_id = run_job(input_dict["input_files"], update_key, simulation.id, ntasks,
                                      payload_dict["sim_type"])

        # create tasks in the database in the default PENDING state;
        # they are not committed one by one (make_commit=False)
        for i in range(ntasks):
            task = CeleryTaskModel(simulation_id=simulation.id, task_id=i)
            add_object_to_db(task, make_commit=False)

        input_model = InputModel(simulation_id=simulation.id)
        input_model.data = input_dict
        add_object_to_db(input_model)
        if simulation.update_state({"job_state": EntityState.PENDING.value}):
            make_commit_to_db()

        return yaptide_response(message="Task started", code=202, content={'job_id': simulation.job_id})

    class APIParametersSchema(Schema):
        """Class specifies API parameters for GET and DELETE request"""

        # required: the handlers index the loaded parameters with ['job_id'], so a request
        # without it must be rejected by validation instead of raising KeyError
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning job state and the status of each of its tasks

        Answers with code 400 on invalid query parameters and with the code provided by
        `check_if_job_is_owned_and_exist` when the ownership check fails. Otherwise answers
        with code 200 and a content dict holding `job_state` and `job_tasks_status`.
        """
        # validate request parameters and handle errors
        schema = JobsDirect.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        # get job_id from request parameters and check if user owns this job
        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        # find appropriate simulation in the database
        simulation = fetch_celery_simulation_by_job_id(job_id=job_id)

        tasks = fetch_celery_tasks_by_sim_id(sim_id=simulation.id)

        job_tasks_status = [task.get_status_dict() for task in tasks]

        # COMPLETED and FAILED are reported as stored, without recomputing from the tasks
        if simulation.job_state in (EntityState.COMPLETED.value,
                                    EntityState.FAILED.value):
            return yaptide_response(message=f"Job state: {simulation.job_state}",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                        "job_tasks_status": job_tasks_status,
                                    })

        # derive the job state from the task states:
        # all PENDING -> PENDING, all FAILED -> FAILED, any RUNNING -> RUNNING,
        # otherwise the state stored in the database is kept
        job_info = {
            "job_state": simulation.job_state
        }
        status_counter = Counter([task["task_state"] for task in job_tasks_status])
        if status_counter[EntityState.PENDING.value] == len(job_tasks_status):
            job_info["job_state"] = EntityState.PENDING.value
        elif status_counter[EntityState.FAILED.value] == len(job_tasks_status):
            job_info["job_state"] = EntityState.FAILED.value
        elif status_counter[EntityState.RUNNING.value] > 0:
            job_info["job_state"] = EntityState.RUNNING.value

        # pass the derived job state to the database layer
        update_simulation_state(simulation=simulation, update_dict=job_info)

        job_info["job_tasks_status"] = job_tasks_status

        return yaptide_response(message=f"Job state: {job_info['job_state']}", code=200, content=job_info)

    @staticmethod
    @requires_auth()
    def delete(user: UserModel):
        """Method canceling simulation and returning status of this action

        Jobs in COMPLETED, FAILED, CANCELED or UNKNOWN state are not canceled; the current
        state is returned with code 200. Otherwise `cancel_job` is called and, when its
        result contains the "merge" key, the simulation and task states are updated from it.
        A result without the "merge" key yields the internal error response.
        """
        schema = JobsDirect.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return error_validation_response(content=errors)
        params_dict: dict = schema.load(request.args)

        job_id = params_dict['job_id']

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(
            job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_celery_simulation_by_job_id(job_id=job_id)

        if simulation.job_state in (EntityState.COMPLETED.value,
                                    EntityState.FAILED.value,
                                    EntityState.CANCELED.value,
                                    EntityState.UNKNOWN.value):
            return yaptide_response(message=f"Cannot cancel job which is in {simulation.job_state} state",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                    })

        tasks = fetch_celery_tasks_by_sim_id(sim_id=simulation.id)

        celery_ids = [task.celery_id for task in tasks]

        result: dict = cancel_job(merge_id=simulation.merge_id, celery_ids=celery_ids)

        if "merge" in result:
            update_simulation_state(simulation=simulation, update_dict=result["merge"])
            # result["tasks"] is indexed in the same order as the fetched tasks
            for i, task in enumerate(tasks):
                update_task_state(task=task, update_dict=result["tasks"][i])

            return yaptide_response(message="", code=200, content=result)

        return error_internal_response()

185 

186 

class ResultsDirect(Resource):
    """Class responsible for returning simulation results"""

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # required: the handler indexes the loaded parameters with ['job_id'], so a request
        # without it must be rejected by validation instead of raising KeyError
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning simulation results

        Results already stored in the database are returned from there. Otherwise they are
        requested with `get_job_results`, stored as estimators and pages, and returned.
        Answers with code 400 on invalid query parameters, with code 404 when the fetched
        result has no "estimators" key, and with code 200 on success.
        """
        schema = ResultsDirect.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_celery_simulation_by_job_id(job_id=job_id)

        estimators: list[EstimatorModel] = fetch_estimators_by_sim_id(sim_id=simulation.id)
        if estimators:
            logging.debug("Returning results from database")
            result_estimators = []
            for estimator in estimators:
                pages: list[PageModel] = fetch_pages_by_estimator_id(est_id=estimator.id)
                estimator_dict = {
                    "metadata": estimator.data,
                    "name": estimator.name,
                    "pages": [page.data for page in pages]
                }
                result_estimators.append(estimator_dict)
            return yaptide_response(message=f"Results for job: {job_id}",
                                    code=200, content={"estimators": result_estimators})

        # nothing stored yet: ask for the results and persist them for subsequent requests
        result: dict = get_job_results(job_id=job_id)
        if "estimators" not in result:
            logging.debug("Results for job %s are unavailable", job_id)
            return yaptide_response(message="Results are unavailable", code=404, content=result)

        for estimator_dict in result["estimators"]:
            estimator = EstimatorModel(name=estimator_dict["name"], simulation_id=simulation.id)
            estimator.data = estimator_dict["metadata"]
            # the estimator is added first because its id is used as the pages' estimator_id
            add_object_to_db(estimator)
            for page_dict in estimator_dict["pages"]:
                page = PageModel(estimator_id=estimator.id,
                                 page_number=int(page_dict["metadata"]["page_number"]))
                page.data = page_dict
                add_object_to_db(page, False)
            # pages were added without committing; commit them together
            make_commit_to_db()

        logging.debug("Returning results from Celery")
        return yaptide_response(message=f"Results for job: {job_id}, results from Celery", code=200, content=result)

245 

246 

class ConvertResource(Resource):
    """Resource returning input_model files converted from the front-end JSON"""

    @staticmethod
    @requires_auth()
    def post(_: UserModel):
        """Convert the JSON request body into input_model files

        Answers with code 400 when the body holds no JSON. Otherwise the body is handed to
        the `convert_input_files` celery task and the task result is returned with code 200.
        """
        body: dict = request.get_json(force=True)
        if body:
            # Rework in later PRs to match pattern from jobs endpoint
            # the request waits here for the task result before responding
            converted: dict = convert_input_files.delay(payload_dict=body).wait()
            return yaptide_response(message="Converted Input Files", code=200, content=converted)

        return yaptide_response(message="No JSON in body", code=400)