Coverage for yaptide/routes/celery_routes.py: 42%

156 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-22 07:31 +0000

1import logging 

2from collections import Counter 

3from datetime import datetime 

4 

5from flask import request 

6from flask_restful import Resource 

7from marshmallow import Schema, fields 

8from uuid import uuid4 

9 

10from yaptide.celery.simulation_worker import celery_app 

11from yaptide.celery.tasks import convert_input_files 

12from yaptide.celery.utils.manage_tasks import (get_job_results, run_job) 

13from yaptide.persistence.db_methods import (add_object_to_db, fetch_celery_simulation_by_job_id, 

14 fetch_celery_tasks_by_sim_id, fetch_estimators_by_sim_id, 

15 fetch_pages_by_estimator_id, make_commit_to_db, update_simulation_state, 

16 update_task_state) 

17from yaptide.persistence.models import (CelerySimulationModel, CeleryTaskModel, EstimatorModel, InputModel, PageModel, 

18 UserModel) 

19from yaptide.routes.utils.decorators import requires_auth 

20from yaptide.routes.utils.response_templates import (error_validation_response, yaptide_response) 

21from yaptide.routes.utils.utils import check_if_job_is_owned_and_exist, determine_input_type, make_input_dict 

22from yaptide.routes.utils.tokens import encode_simulation_auth_token 

23from yaptide.utils.enums import EntityState, PlatformType 

24from yaptide.utils.helper_tasks import terminate_unfinished_tasks 

25 

26 

class JobsDirect(Resource):
    """Class responsible for simulations run directly with celery"""

    @staticmethod
    @requires_auth()
    def post(user: UserModel):
        """Submit simulation job to celery

        The JSON body must be an object holding ``sim_type``, ``ntasks`` and ``input_type``;
        ``title`` is optional and defaults to an empty string. The simulation, one task row per
        requested task and the input data are stored in the database and the job is handed to celery.

        Responds with code 202 and the new ``job_id`` on success, with code 400 when the body is empty,
        is not a JSON object, lacks a required key or carries an ``ntasks`` that is not a positive
        integer, and with the validation error response when the input type cannot be determined.
        """
        payload_dict: dict = request.get_json(force=True)
        if not payload_dict:
            return yaptide_response(message="No JSON in body", code=400)
        # a non-empty JSON array or scalar would otherwise crash on the key lookups below
        if not isinstance(payload_dict, dict):
            return yaptide_response(message="JSON body must be an object", code=400)

        required_keys = {"sim_type", "ntasks", "input_type"}
        missing_keys = required_keys.difference(payload_dict)
        if missing_keys:
            return yaptide_response(message=f"Missing keys in JSON payload: {missing_keys}", code=400)

        # validate before anything is written to the database, so that a bad value cannot leave
        # a simulation without tasks behind (bool is rejected explicitly because it is an int subclass)
        ntasks = payload_dict["ntasks"]
        if isinstance(ntasks, bool) or not isinstance(ntasks, int) or ntasks < 1:
            return yaptide_response(message="ntasks must be a positive integer", code=400)

        input_type = determine_input_type(payload_dict)
        if input_type is None:
            return error_validation_response()

        # create a new simulation in the database, not waiting for the job to finish
        job_id = datetime.now().strftime('%Y%m%d-%H%M%S-') + str(uuid4()) + PlatformType.DIRECT.value
        simulation = CelerySimulationModel(user_id=user.id,
                                           job_id=job_id,
                                           sim_type=payload_dict["sim_type"],
                                           input_type=input_type,
                                           title=payload_dict.get("title", ''))
        add_object_to_db(simulation)
        update_key = encode_simulation_auth_token(simulation.id)
        logging.info("Simulation %d created and inserted into DB", simulation.id)
        logging.debug("Update key set to %s", update_key)

        input_dict = make_input_dict(payload_dict=payload_dict, input_type=input_type)

        # create tasks in the database in the default PENDING state, committing them all at once
        celery_ids = [str(uuid4()) for _ in range(ntasks)]
        for task_id, celery_id in enumerate(celery_ids):
            task = CeleryTaskModel(simulation_id=simulation.id, task_id=task_id, celery_id=celery_id)
            add_object_to_db(task, make_commit=False)
        make_commit_to_db()

        # submit the asynchronous job to celery
        simulation.merge_id = run_job(input_dict["input_files"], update_key, simulation.id, ntasks, celery_ids,
                                      payload_dict["sim_type"])

        input_model = InputModel(simulation_id=simulation.id)
        input_model.data = input_dict
        add_object_to_db(input_model)
        if simulation.update_state({"job_state": EntityState.PENDING.value}):
            make_commit_to_db()

        return yaptide_response(message="Task started", code=202, content={'job_id': simulation.job_id})

    class APIParametersSchema(Schema):
        """Class specifies API parameters for GET and DELETE request"""

        # required: both handlers read param_dict['job_id'] unconditionally, so a request without it
        # has to be rejected during validation instead of failing later with a KeyError
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning job status and results

        Expects the ``job_id`` query parameter. Responds with code 400 when the parameters do not
        validate and with the code reported by ``check_if_job_is_owned_and_exist`` when that check fails.
        Otherwise responds with code 200 and a content holding ``job_state`` and ``job_tasks_status``.
        """
        # validate request parameters and handle errors
        schema = JobsDirect.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        # get job_id from request parameters and check if user owns this job
        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        # find appropriate simulation in the database
        simulation = fetch_celery_simulation_by_job_id(job_id=job_id)
        tasks = fetch_celery_tasks_by_sim_id(sim_id=simulation.id)
        job_tasks_status = [task.get_status_dict() for task in tasks]

        # COMPLETED and FAILED jobs are reported exactly as stored, without recomputing the state
        if simulation.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value):
            return yaptide_response(message=f"Job state: {simulation.job_state}",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                        "job_tasks_status": job_tasks_status,
                                    })

        # derive the job state from the task states; the stored state is kept when no rule matches
        job_info = {"job_state": simulation.job_state}
        status_counter = Counter(task_status["task_state"] for task_status in job_tasks_status)
        task_count = len(job_tasks_status)
        if status_counter[EntityState.PENDING.value] == task_count:
            job_info["job_state"] = EntityState.PENDING.value
        elif status_counter[EntityState.FAILED.value] == task_count:
            job_info["job_state"] = EntityState.FAILED.value
        elif status_counter[EntityState.RUNNING.value] > 0:
            job_info["job_state"] = EntityState.RUNNING.value

        # store the derived job state on the simulation
        update_simulation_state(simulation=simulation, update_dict=job_info)

        job_info["job_tasks_status"] = job_tasks_status

        return yaptide_response(message=f"Job state: {job_info['job_state']}", code=200, content=job_info)

    @staticmethod
    @requires_auth()
    def delete(user: UserModel):
        """Method canceling simulation and returning status of this action

        Expects the ``job_id`` query parameter. A job in COMPLETED, FAILED, CANCELED or UNKNOWN state
        is left untouched and its state is returned with code 200. Otherwise the merge task and the
        tasks in PENDING, RUNNING or UNKNOWN state are revoked, the simulation and its PENDING and
        RUNNING tasks are marked as CANCELED and ``terminate_unfinished_tasks`` is scheduled.
        """
        schema = JobsDirect.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return error_validation_response(content=errors)
        params_dict: dict = schema.load(request.args)

        job_id = params_dict['job_id']

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_celery_simulation_by_job_id(job_id=job_id)

        if simulation.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value,
                                    EntityState.UNKNOWN.value):
            return yaptide_response(message=f"Cannot cancel job which is in {simulation.job_state} state",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                    })

        tasks = fetch_celery_tasks_by_sim_id(sim_id=simulation.id)
        celery_ids = [
            task.celery_id for task in tasks
            if task.task_state in [EntityState.PENDING.value, EntityState.RUNNING.value, EntityState.UNKNOWN.value]
        ]

        # The merge_id is canceled first because merge task starts after run simulation tasks are finished/canceled.
        # We don't want it to run accidentally.
        celery_app.control.revoke(simulation.merge_id, terminate=True, signal="SIGINT")
        celery_app.control.revoke(celery_ids, terminate=True, signal="SIGINT")
        update_simulation_state(simulation=simulation, update_dict={"job_state": EntityState.CANCELED.value})
        # NOTE(review): tasks in UNKNOWN state are revoked above but are not marked as CANCELED here;
        # confirm that this difference is intended
        for task in tasks:
            if task.task_state in [EntityState.PENDING.value, EntityState.RUNNING.value]:
                update_task_state(task=task, update_dict={"task_state": EntityState.CANCELED.value})

        terminate_unfinished_tasks.delay(simulation_id=simulation.id)
        return yaptide_response(message="Cancelled sucessfully", code=200)

177 

178 

class ResultsDirect(Resource):
    """Class responsible for returning simulation results"""

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # required: the handler reads param_dict['job_id'] unconditionally, so a request without it
        # has to be rejected during validation instead of failing later with a KeyError
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning job status and results

        Expects the ``job_id`` query parameter. Estimators already stored for the simulation are
        returned straight from the database. Otherwise the results are requested through
        ``get_job_results``; when they contain estimators these are stored in the database and
        returned, and when they do not the response has code 404.
        """
        schema = ResultsDirect.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_celery_simulation_by_job_id(job_id=job_id)

        estimators: list[EstimatorModel] = fetch_estimators_by_sim_id(sim_id=simulation.id)
        if estimators:
            logging.debug("Returning results from database")
            result_estimators = [{
                "metadata": estimator.data,
                "name": estimator.name,
                "pages": [page.data for page in fetch_pages_by_estimator_id(est_id=estimator.id)]
            } for estimator in estimators]
            return yaptide_response(message=f"Results for job: {job_id}",
                                    code=200,
                                    content={"estimators": result_estimators})

        result: dict = get_job_results(job_id=job_id)
        if "estimators" not in result:
            logging.debug("Results for job %s are unavailable", job_id)
            return yaptide_response(message="Results are unavailable", code=404, content=result)

        # store the freshly fetched results so that later requests are served from the database
        for estimator_dict in result["estimators"]:
            estimator = EstimatorModel(name=estimator_dict["name"], simulation_id=simulation.id)
            estimator.data = estimator_dict["metadata"]
            # added on its own first because the pages below reference estimator.id
            add_object_to_db(estimator)
            for page_dict in estimator_dict["pages"]:
                page = PageModel(estimator_id=estimator.id, page_number=int(page_dict["metadata"]["page_number"]))
                page.data = page_dict
                add_object_to_db(page, False)
            # single commit for all pages of this estimator
            make_commit_to_db()

        logging.debug("Returning results from Celery")
        return yaptide_response(message=f"Results for job: {job_id}, results from Celery", code=200, content=result)

237 

238 

class ConvertResource(Resource):
    """Class responsible for returning input_model files converted from front JSON"""

    @staticmethod
    @requires_auth()
    def post(_: UserModel):
        """Method handling input_model files conversion

        Sends the JSON body to the ``convert_input_files`` celery task, waits for the task to finish
        and returns its result with code 200. An empty body is answered with code 400.
        """
        body: dict = request.get_json(force=True)
        if body:
            # Rework in later PRs to match pattern from jobs endpoint
            # the request is blocked here until the celery task has produced its result
            converted: dict = convert_input_files.delay(payload_dict=body).wait()
            return yaptide_response(message="Converted Input Files", code=200, content=converted)

        return yaptide_response(message="No JSON in body", code=400)