Coverage for yaptide/celery/utils/manage_tasks.py: 33%

58 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-07-01 12:55 +0000

1import logging 

2 

3from celery import chord, group 

4from celery.result import AsyncResult 

5 

6from yaptide.celery.tasks import merge_results, run_single_simulation 

7from yaptide.celery.worker import celery_app 

8from yaptide.utils.enums import EntityState 

9 

10 

def run_job(files_dict: dict, update_key: str, simulation_id: int, ntasks: int, sim_type: str = 'shieldhit') -> str:
    """Dispatch a simulation as a Celery chord and return the id of the dispatched workflow.

    One ``run_single_simulation`` signature is built for every task index in
    ``range(ntasks)``. All of them share the same inputs. The whole group is
    chained into ``merge_results`` via ``chord``.

    Args:
        files_dict: simulation input; keys are filenames, values are file contents.
        update_key: forwarded unchanged to every simulation task.
        simulation_id: forwarded unchanged to every simulation task.
        ntasks: number of simulation task signatures to create.
        sim_type: simulator type forwarded to every task (default ``'shieldhit'``).

    Returns:
        The ``id`` of the AsyncResult returned by ``chord(...).delay()``.
    """
    logging.debug("Starting run_simulation task for %d tasks", ntasks)
    logging.debug("Simulation id: %d", simulation_id)
    logging.debug("Update key: %s", update_key)

    # Each signature differs only by its task_id.
    task_signatures = [
        run_single_simulation.s(
            files_dict=files_dict,  # simulation input, keys: filenames, values: file contents
            task_id=task_index,
            update_key=update_key,
            simulation_id=simulation_id,
            sim_type=sim_type,
        )
        for task_index in range(ntasks)
    ]
    parallel_tasks = group(task_signatures)

    dispatched: AsyncResult = chord(parallel_tasks, merge_results.s()).delay()
    return dispatched.id

30 

31 

def get_task_status(job_id: str, state_key: str) -> dict:
    """Return the YAPTIDE-style state of a single Celery task.

    Args:
        job_id: Celery id of the task to inspect.
        state_key: key under which the translated state is stored in the
            returned dict (this module passes "job_state" or "task_state").

    Returns:
        ``{state_key: <translated state>}``. The dict also contains:
        ``"message"`` (``str(job.info)``) when the task failed, and
        ``"end_time"`` when ``job.info`` is a dict carrying that key.
    """
    job = AsyncResult(id=job_id, app=celery_app)
    job_state: str = translate_celery_state_naming(job.state)
    # Read the payload once. Per Celery's AsyncResult docs, ``info`` is the
    # return value / meta, or the exception instance if the task raised. It is
    # not necessarily a dict.
    info = job.info

    # we still need to convert string to enum and operate later on Enum
    result = {state_key: job_state}
    if job_state == EntityState.FAILED.value:
        result["message"] = str(info)
    # A membership test on None or on an exception instance raises TypeError,
    # so only look for "end_time" when the payload really is a dict.
    if isinstance(info, dict) and "end_time" in info:
        result["end_time"] = info["end_time"]
    return result

44 

45 

def get_job_status(merge_id: str, celery_ids: list[str]) -> dict:
    """Report the state of a simulation without fetching its results.

    A simulation consists of one merge task plus several worker tasks, so
    every one of them is queried.

    Args:
        merge_id: Celery id of the merge task.
        celery_ids: Celery ids of the individual simulation tasks.

    Returns:
        ``{"merge": <status under "job_state">, "tasks": [<status under "task_state">, ...]}``.
    """
    merge_status = get_task_status(merge_id, "job_state")
    task_statuses = [get_task_status(celery_id, "task_state") for celery_id in celery_ids]
    return {"merge": merge_status, "tasks": task_statuses}

57 

58 

def cancel_job(merge_id: str, celery_ids: list[str]) -> dict:
    """Request cancellation of a simulation: its merge task and every worker task.

    Args:
        merge_id: Celery id of the merge task.
        celery_ids: Celery ids of the individual simulation tasks.

    Returns:
        ``{"merge": <outcome>, "tasks": [<outcome>, ...]}``. Each outcome holds
        the state under "job_state" (merge) or "task_state" (tasks), plus a
        human-readable "message".
    """
    # States from which no revoke is attempted.
    finished_states = (
        EntityState.CANCELED.value,
        EntityState.COMPLETED.value,
        EntityState.FAILED.value,
    )

    def cancel_task(job_id: str, state_key: str) -> dict:
        """Try to revoke a single Celery task and describe the outcome.

        The reported state becomes CANCELED as soon as ``revoke()`` returns
        without raising; the task's state is not re-read afterwards.
        """
        current_state: str = translate_celery_state_naming(AsyncResult(id=job_id, app=celery_app).state)

        if current_state in finished_states:
            logging.warning("Cannot cancel job %s which is already %s", job_id, current_state)
            return {state_key: current_state, "message": f"Job already {current_state}"}

        try:
            # Per Celery's revoke docs, terminate=True also signals a worker
            # already executing the task (SIGINT here).
            celery_app.control.revoke(job_id, terminate=True, signal="SIGINT")
        except Exception as err:  # skipcq: PYL-W0703
            logging.error("Cannot cancel job %s, due to %s", job_id, err)
            return {
                state_key: current_state,
                "message": f"Cannot cancel job {job_id}, leaving at current state {current_state}",
            }
        return {state_key: EntityState.CANCELED.value, "message": f"Job {job_id} canceled"}

    merge_outcome = cancel_task(merge_id, "job_state")
    task_outcomes = [cancel_task(celery_id, "task_state") for celery_id in celery_ids]
    return {"merge": merge_outcome, "tasks": task_outcomes}

86 

87 

def get_job_results(job_id: str) -> dict:
    """Return the ``"result"`` entry of a job's payload.

    Args:
        job_id: Celery id of the job to read.

    Returns:
        ``job.info["result"]`` when ``job.info`` is a dict containing a
        ``"result"`` key, otherwise an empty dict.
    """
    info = AsyncResult(id=job_id, app=celery_app).info
    # job.info is not always a dict (per Celery docs it is the exception
    # instance when the task raised, and it may be None before anything has
    # been stored). "in" on those raises TypeError, so any non-dict payload
    # is treated as "no results".
    if not isinstance(info, dict) or "result" not in info:
        return {}
    return info["result"]

94 

95 

def translate_celery_state_naming(job_state: str) -> str:
    """Map a Celery state name onto the name YAPTIDE uses.

    RECEIVED/RETRY -> PENDING, PROGRESS/STARTED -> RUNNING, FAILURE -> FAILED,
    REVOKED -> CANCELED, SUCCESS -> COMPLETED. Any other state name is
    returned unchanged.
    """
    celery_to_yaptide = {
        "RECEIVED": EntityState.PENDING.value,
        "RETRY": EntityState.PENDING.value,
        "PROGRESS": EntityState.RUNNING.value,
        "STARTED": EntityState.RUNNING.value,
        "FAILURE": EntityState.FAILED.value,
        "REVOKED": EntityState.CANCELED.value,
        "SUCCESS": EntityState.COMPLETED.value,
    }
    # Others are the same
    return celery_to_yaptide.get(job_state, job_state)