Coverage for yaptide/routes/celery_routes.py: 41%

145 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-03-31 19:18 +0000

1import logging 

2from collections import Counter 

3from datetime import datetime 

4 

5from flask import request 

6from flask_restful import Resource 

7from marshmallow import Schema, fields 

8from uuid import uuid4 

9 

10from yaptide.celery.simulation_worker import celery_app 

11from yaptide.celery.utils.manage_tasks import (get_job_results, run_job) 

12from yaptide.persistence.db_methods import (add_object_to_db, fetch_celery_simulation_by_job_id, 

13 fetch_celery_tasks_by_sim_id, fetch_estimators_by_sim_id, 

14 fetch_pages_by_estimator_id, make_commit_to_db, update_simulation_state, 

15 update_task_state) 

16from yaptide.persistence.models import (CelerySimulationModel, CeleryTaskModel, EstimatorModel, InputModel, PageModel, 

17 UserModel) 

18from yaptide.routes.utils.decorators import requires_auth 

19from yaptide.routes.utils.response_templates import (error_validation_response, yaptide_response) 

20from yaptide.routes.utils.utils import check_if_job_is_owned_and_exist, determine_input_type, make_input_dict 

21from yaptide.routes.utils.tokens import encode_simulation_auth_token 

22from yaptide.utils.enums import EntityState, PlatformType 

23from yaptide.utils.helper_tasks import terminate_unfinished_tasks 

24 

25 

class JobsDirect(Resource):
    """Class responsible for simulations run directly with celery"""

    @staticmethod
    @requires_auth()
    def post(user: UserModel):
        """Submit simulation job to celery

        The request body must be a JSON object containing at least the keys
        ``sim_type``, ``ntasks`` and ``input_type``; ``title`` is optional.

        Responds with code 202 and the generated ``job_id`` once the job has been
        handed over to celery, or with code 400 when the payload is missing,
        is not a JSON object, lacks a required key or has a non-integer ``ntasks``.
        """
        payload_dict: dict = request.get_json(force=True)
        if not payload_dict:
            return yaptide_response(message="No JSON in body", code=400)
        if not isinstance(payload_dict, dict):
            # a non-empty JSON array or scalar would break the key lookups below
            return yaptide_response(message="JSON payload must be an object", code=400)

        required_keys = {"sim_type", "ntasks", "input_type"}
        missing_keys = required_keys.difference(payload_dict)
        if missing_keys:
            return yaptide_response(message=f"Missing keys in JSON payload: {missing_keys}", code=400)

        ntasks = payload_dict["ntasks"]
        if not isinstance(ntasks, int):
            # range() below would raise TypeError for e.g. "4" or 4.0
            return yaptide_response(message="ntasks must be an integer", code=400)

        input_type = determine_input_type(payload_dict)
        if input_type is None:
            return error_validation_response()

        # create a new simulation in the database, not waiting for the job to finish
        job_id = datetime.now().strftime('%Y%m%d-%H%M%S-') + str(uuid4()) + PlatformType.DIRECT.value
        simulation = CelerySimulationModel(user_id=user.id,
                                           job_id=job_id,
                                           sim_type=payload_dict["sim_type"],
                                           input_type=input_type,
                                           title=payload_dict.get("title", ''))
        add_object_to_db(simulation)
        update_key = encode_simulation_auth_token(simulation.id)
        logging.info("Simulation %d created and inserted into DB", simulation.id)
        # NOTE(review): this writes the simulation auth token to the debug log
        logging.debug("Update key set to %s", update_key)

        input_dict = make_input_dict(payload_dict=payload_dict, input_type=input_type)

        # create tasks in the database in the default PENDING state;
        # all of them are added first and committed together
        celery_ids = [str(uuid4()) for _ in range(ntasks)]
        for task_id, celery_id in enumerate(celery_ids):
            task = CeleryTaskModel(simulation_id=simulation.id, task_id=task_id, celery_id=celery_id)
            add_object_to_db(task, make_commit=False)
        make_commit_to_db()

        # submit the asynchronous job to celery
        simulation.merge_id = run_job(input_dict["input_files"], update_key, simulation.id, ntasks, celery_ids,
                                      payload_dict["sim_type"])

        input_model = InputModel(simulation_id=simulation.id)
        input_model.data = input_dict
        add_object_to_db(input_model)
        if simulation.update_state({"job_state": EntityState.PENDING.value}):
            make_commit_to_db()

        return yaptide_response(message="Task started", code=202, content={'job_id': simulation.job_id})

    class APIParametersSchema(Schema):
        """Class specifies API parameters for GET and DELETE request"""

        # required: both handlers index the loaded parameters by this key,
        # so a missing value must be reported as a validation error
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning job state and the status of its tasks

        Expects the ``job_id`` query parameter. Responds with code 400 on invalid
        parameters, with the code provided by the ownership check when the job
        does not pass it, and with code 200 otherwise. The 200 response content
        holds ``job_state`` and ``job_tasks_status``.
        """
        # validate request parameters and handle errors
        schema = JobsDirect.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        # get job_id from request parameters and check if user owns this job
        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        # find appropriate simulation in the database
        simulation = fetch_celery_simulation_by_job_id(job_id=job_id)
        tasks = fetch_celery_tasks_by_sim_id(sim_id=simulation.id)
        job_tasks_status = [task.get_status_dict() for task in tasks]

        # COMPLETED and FAILED are reported as stored, without re-deriving the state
        if simulation.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value):
            return yaptide_response(message=f"Job state: {simulation.job_state}",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                        "job_tasks_status": job_tasks_status,
                                    })

        # derive the job state from the task states; when no rule matches,
        # the state stored in the database is kept
        job_info = {"job_state": simulation.job_state}
        task_count = len(job_tasks_status)
        status_counter = Counter(task["task_state"] for task in job_tasks_status)
        if status_counter[EntityState.PENDING.value] == task_count:
            job_info["job_state"] = EntityState.PENDING.value
        elif status_counter[EntityState.FAILED.value] == task_count:
            job_info["job_state"] = EntityState.FAILED.value
        elif status_counter[EntityState.RUNNING.value] > 0:
            job_info["job_state"] = EntityState.RUNNING.value

        # store the derived job state for the simulation
        update_simulation_state(simulation=simulation, update_dict=job_info)

        job_info["job_tasks_status"] = job_tasks_status

        return yaptide_response(message=f"Job state: {job_info['job_state']}", code=200, content=job_info)

    @staticmethod
    @requires_auth()
    def delete(user: UserModel):
        """Method canceling simulation and returning status of this action

        Expects the ``job_id`` query parameter. A job which is already in the
        COMPLETED, FAILED, CANCELED or UNKNOWN state is left untouched and its
        state is reported back with code 200.
        """
        schema = JobsDirect.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return error_validation_response(content=errors)
        params_dict: dict = schema.load(request.args)

        job_id = params_dict['job_id']

        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_celery_simulation_by_job_id(job_id=job_id)

        not_cancellable_states = (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value,
                                  EntityState.UNKNOWN.value)
        if simulation.job_state in not_cancellable_states:
            return yaptide_response(message=f"Cannot cancel job which is in {simulation.job_state} state",
                                    code=200,
                                    content={
                                        "job_state": simulation.job_state,
                                    })

        tasks = fetch_celery_tasks_by_sim_id(sim_id=simulation.id)
        revocable_states = (EntityState.PENDING.value, EntityState.RUNNING.value, EntityState.UNKNOWN.value)
        celery_ids = [task.celery_id for task in tasks if task.task_state in revocable_states]

        # The merge_id is canceled first because merge task starts after run simulation tasks are finished/canceled.
        # We don't want it to run accidentally.
        celery_app.control.revoke(simulation.merge_id, terminate=True, signal="SIGINT")
        celery_app.control.revoke(celery_ids, terminate=True, signal="SIGINT")
        update_simulation_state(simulation=simulation, update_dict={"job_state": EntityState.CANCELED.value})

        # only PENDING and RUNNING tasks are marked as canceled here;
        # tasks in the UNKNOWN state are revoked above but keep their stored state
        for task in tasks:
            if task.task_state in (EntityState.PENDING.value, EntityState.RUNNING.value):
                update_task_state(task=task, update_dict={"task_state": EntityState.CANCELED.value})

        terminate_unfinished_tasks.delay(simulation_id=simulation.id)
        return yaptide_response(message="Cancelled sucessfully", code=200)

176 

177 

class ResultsDirect(Resource):
    """Class responsible for returning simulation results"""

    class APIParametersSchema(Schema):
        """Class specifies API parameters"""

        # required: get() indexes the loaded parameters by this key,
        # so a missing value must be reported as a validation error
        job_id = fields.String(required=True)

    @staticmethod
    @requires_auth()
    def get(user: UserModel):
        """Method returning simulation results

        Expects the ``job_id`` query parameter. Estimators already stored in the
        database for the simulation are returned directly. Otherwise the results
        are obtained with ``get_job_results``, saved to the database and returned.
        Responds with code 404 when the obtained results contain no ``estimators`` key.
        """
        schema = ResultsDirect.APIParametersSchema()
        errors: dict[str, list[str]] = schema.validate(request.args)
        if errors:
            return yaptide_response(message="Wrong parameters", code=400, content=errors)
        param_dict: dict = schema.load(request.args)

        job_id = param_dict['job_id']
        is_owned, error_message, res_code = check_if_job_is_owned_and_exist(job_id=job_id, user=user)
        if not is_owned:
            return yaptide_response(message=error_message, code=res_code)

        simulation = fetch_celery_simulation_by_job_id(job_id=job_id)

        estimators: list[EstimatorModel] = fetch_estimators_by_sim_id(sim_id=simulation.id)
        if estimators:
            logging.debug("Returning results from database")
            result_estimators = [{
                "metadata": estimator.data,
                "name": estimator.name,
                "pages": [page.data for page in fetch_pages_by_estimator_id(est_id=estimator.id)]
            } for estimator in estimators]
            return yaptide_response(message=f"Results for job: {job_id}",
                                    code=200,
                                    content={"estimators": result_estimators})

        result: dict = get_job_results(job_id=job_id)
        if "estimators" not in result:
            logging.debug("Results for job %s are unavailable", job_id)
            return yaptide_response(message="Results are unavailable", code=404, content=result)

        # save the freshly obtained results, so the next request is served from the database
        for estimator_dict in result["estimators"]:
            estimator = EstimatorModel(name=estimator_dict["name"], simulation_id=simulation.id)
            estimator.data = estimator_dict["metadata"]
            # presumably committed right away so that estimator.id is available for the pages - TODO confirm
            add_object_to_db(estimator)
            for page_dict in estimator_dict["pages"]:
                page = PageModel(estimator_id=estimator.id,
                                 page_number=int(page_dict["metadata"]["page_number"]),
                                 page_dimension=int(page_dict['dimensions']),
                                 page_name=str(page_dict["metadata"]["name"]))
                page.data = page_dict
                add_object_to_db(page, False)
            # pages of one estimator are committed together
            make_commit_to_db()

        logging.debug("Returning results from Celery")
        return yaptide_response(message=f"Results for job: {job_id}, results from Celery", code=200, content=result)