Coverage for yaptide/celery/utils/manage_tasks.py: 42%

43 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-04 00:31 +0000

1import logging 

2 

3from celery import chain, chord, group 

4from celery.result import AsyncResult 

5 

6from yaptide.celery.tasks import merge_results, run_single_simulation, set_merging_queued_state 

7from yaptide.celery.simulation_worker import celery_app 

8from yaptide.utils.enums import EntityState 

9 

10 

def run_job(files_dict: dict,
            update_key: str,
            simulation_id: int,
            ntasks: int,
            celery_ids: list,
            sim_type: str = 'shieldhit') -> str:
    """Builds the simulation workflow and submits it asynchronously.

    The workflow is a chord: a group of ``ntasks`` single-simulation tasks,
    followed by a chain of ``set_merging_queued_state`` and ``merge_results``.

    Args:
        files_dict: simulation input, keys: filenames, values: file contents.
        update_key: passed unchanged to every single-simulation task.
        simulation_id: passed unchanged to every single-simulation task.
        ntasks: number of single-simulation tasks to create.
        celery_ids: Celery task ids; entry ``i`` is assigned to task ``i``,
            so at least ``ntasks`` entries are required.
        sim_type: passed unchanged to every single-simulation task.

    Returns:
        The id of the result object returned by submitting the chord.
    """
    logging.debug("Starting run_simulation task for %d tasks", ntasks)
    logging.debug("Simulation id: %d", simulation_id)
    logging.debug("Update key: %s", update_key)

    # One signature per task, each pinned to its pre-assigned Celery id.
    simulation_signatures = []
    for task_index in range(ntasks):
        signature = run_single_simulation.s(
            files_dict=files_dict,
            task_id=task_index,
            update_key=update_key,
            simulation_id=simulation_id,
            sim_type=sim_type,
        )
        simulation_signatures.append(signature.set(task_id=celery_ids[task_index]))
    header = group(simulation_signatures)

    # The simulation_worker setup directs all tasks from yaptide.celery.tasks
    # to the simulations queue. For tests to work, a signature used as the
    # second part of a chord needs its queue specified explicitly.
    queue_merging = set_merging_queued_state.s().set(queue="simulations")
    merge = merge_results.s().set(queue="simulations")
    body = chain(queue_merging, merge)

    submitted: AsyncResult = chord(header, body).delay()
    return submitted.id

39 

40 

def get_task_status(job_id: str, state_key: str) -> dict:
    """Gets the status of a single Celery task, using YAPTIDE state naming.

    Args:
        job_id: Celery id of the task to inspect.
        state_key: key under which the translated state is stored in the
            returned dict.

    Returns:
        A dict with the translated state under ``state_key``. When the state
        is FAILED it also holds ``"message"`` (the string form of the task's
        info) and, if the info is a dict containing it, ``"end_time"``.
    """
    job = AsyncResult(id=job_id, app=celery_app)
    job_state: str = translate_celery_state_naming(job.state)

    # we still need to convert string to enum and operate later on Enum
    result = {state_key: job_state}
    if job_state == EntityState.FAILED.value:
        # Read the info once so the message and end_time come from the same
        # snapshot of the result backend.
        info = job.info
        result["message"] = str(info)
        # For a task that raised, Celery stores the exception instance as the
        # info; a membership test on it would raise TypeError, so only look
        # for end_time when the info really is a dict.
        if isinstance(info, dict) and "end_time" in info:
            result["end_time"] = info["end_time"]
    return result

53 

54 

def get_job_status(merge_id: str, celery_ids: list[str]) -> dict:
    """Collects the state of a whole simulation; results are not included.

    A simulation may consist of multiple tasks, so every one of them is
    checked in addition to the merge step.

    Args:
        merge_id: Celery id whose status is reported under ``"merge"``.
        celery_ids: Celery ids whose statuses are reported under ``"tasks"``,
            in the same order.

    Returns:
        A dict with ``"merge"`` (a status dict keyed by ``"job_state"``) and
        ``"tasks"`` (a list of status dicts keyed by ``"task_state"``).
    """
    merge_status = get_task_status(merge_id, "job_state")

    task_statuses = []
    for task_celery_id in celery_ids:
        task_statuses.append(get_task_status(task_celery_id, "task_state"))

    return {"merge": merge_status, "tasks": task_statuses}

66 

67 

def get_job_results(job_id: str) -> dict:
    """Returns simulation results stored for the given Celery job.

    Args:
        job_id: Celery id of the job whose results are requested.

    Returns:
        The value stored under ``"result"`` in the job's info, or an empty
        dict when the info is not a dict or has no ``"result"`` entry.
    """
    job = AsyncResult(id=job_id, app=celery_app)
    # Read the info once: the membership test and the lookup must see the
    # same snapshot of the result backend.
    info = job.info
    # The info is None while nothing has been stored for the job and is an
    # exception instance when the job raised; neither supports ``in``.
    if not isinstance(info, dict) or "result" not in info:
        return {}
    return info["result"]

74 

75 

def translate_celery_state_naming(job_state: str) -> str:
    """Translates a Celery state name to the one used in YAPTIDE.

    Mapping applied:
        RECEIVED, RETRY    -> EntityState.PENDING
        PROGRESS, STARTED  -> EntityState.RUNNING
        FAILURE            -> EntityState.FAILED
        REVOKED            -> EntityState.CANCELED
        SUCCESS            -> EntityState.COMPLETED

    Args:
        job_state: state name as reported by Celery.

    Returns:
        The ``value`` of the matching ``EntityState`` member, or ``job_state``
        unchanged when no mapping applies.
    """
    if job_state in ("RECEIVED", "RETRY"):
        translated = EntityState.PENDING
    elif job_state in ("PROGRESS", "STARTED"):
        translated = EntityState.RUNNING
    elif job_state == "FAILURE":
        translated = EntityState.FAILED
    elif job_state == "REVOKED":
        translated = EntityState.CANCELED
    elif job_state == "SUCCESS":
        translated = EntityState.COMPLETED
    else:
        # Every other state already has the same name on both sides.
        return job_state
    return translated.value