Coverage for yaptide/persistence/db_methods.py: 64%

88 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-07-01 12:55 +0000

import logging
from typing import Optional, Union

from sqlalchemy.orm import with_polymorphic

from yaptide.persistence.database import db
from yaptide.persistence.models import (BatchSimulationModel, BatchTaskModel,
                                        CelerySimulationModel, CeleryTaskModel,
                                        ClusterModel, EstimatorModel,
                                        InputModel, KeycloakUserModel,
                                        LogfilesModel, PageModel,
                                        SimulationModel, TaskModel, UserModel,
                                        YaptideUserModel)

14 

15 

def add_object_to_db(obj: db.Model, make_commit: bool = True) -> None:
    """Add ``obj`` to the current session and, by default, commit.

    Args:
        obj: Model instance to stage for insertion.
        make_commit: When False, the object is only added to the session; the
            caller is expected to commit later (e.g. via ``make_commit_to_db``).
    """
    session = db.session
    session.add(obj)
    if not make_commit:
        return
    make_commit_to_db()

21 

22 

def delete_object_from_db(obj: db.Model, make_commit: bool = True) -> None:
    """Mark ``obj`` for deletion in the current session and, by default, commit.

    Args:
        obj: Model instance to delete.
        make_commit: When False, the deletion is only registered in the session;
            the caller is expected to commit later (e.g. via ``make_commit_to_db``).
    """
    session = db.session
    session.delete(obj)
    if not make_commit:
        return
    make_commit_to_db()

28 

29 

def make_commit_to_db():
    """Commit all pending changes in the current database session."""
    session = db.session
    session.commit()

33 

34 

def fetch_user_by_id(user_id: int) -> Optional[Union[KeycloakUserModel, YaptideUserModel]]:
    """Fetch a user of either subclass by its ``id``.

    The query goes through ``with_polymorphic`` over YaptideUserModel and
    KeycloakUserModel, so subclass attributes are included in the query up front.

    Returns:
        The matching user, or None if no user has ``id == user_id``.
    """
    user_poly = with_polymorphic(UserModel, [YaptideUserModel, KeycloakUserModel])
    return db.session.query(user_poly).filter_by(id=user_id).first()

40 

41 

def fetch_yaptide_user_by_username(username: str) -> Optional[YaptideUserModel]:
    """Fetch a YaptideUserModel by ``username``.

    Only YaptideUserModel rows are searched; Keycloak users are not considered.

    Returns:
        The first matching user, or None if no Yaptide user has this username.
    """
    return db.session.query(YaptideUserModel).filter_by(username=username).first()

46 

47 

def fetch_keycloak_user_by_username(username: str) -> Optional[KeycloakUserModel]:
    """Fetch a KeycloakUserModel by ``username``.

    Only KeycloakUserModel rows are searched; Yaptide users are not considered.

    Returns:
        The first matching user, or None if no Keycloak user has this username.
    """
    return db.session.query(KeycloakUserModel).filter_by(username=username).first()

52 

53 

def fetch_simulation_by_job_id(job_id: str) -> Optional[Union[BatchSimulationModel, CelerySimulationModel]]:
    """Fetch a simulation of either subclass by its ``job_id``.

    Returns:
        The first matching BatchSimulationModel or CelerySimulationModel, or
        None if no simulation has this ``job_id``.
    """
    simulation_poly = with_polymorphic(SimulationModel, [BatchSimulationModel, CelerySimulationModel])
    return db.session.query(simulation_poly).filter_by(job_id=job_id).first()

59 

60 

def fetch_celery_simulation_by_job_id(job_id: str) -> Optional[CelerySimulationModel]:
    """Fetch a celery simulation by its ``job_id``.

    Returns:
        The first matching CelerySimulationModel, or None if none has this ``job_id``.
    """
    return db.session.query(CelerySimulationModel).filter_by(job_id=job_id).first()

65 

66 

def fetch_batch_simulation_by_job_id(job_id: str) -> Optional[BatchSimulationModel]:
    """Fetch a batch simulation by its ``job_id``.

    Returns:
        The first matching BatchSimulationModel, or None if none has this ``job_id``.
    """
    return db.session.query(BatchSimulationModel).filter_by(job_id=job_id).first()

71 

72 

def fetch_simulation_by_sim_id(sim_id: int) -> Optional[Union[BatchSimulationModel, CelerySimulationModel]]:
    """Fetch a simulation of either subclass by its ``id``.

    Returns:
        The matching BatchSimulationModel or CelerySimulationModel, or None if
        no simulation has ``id == sim_id``.
    """
    simulation_poly = with_polymorphic(SimulationModel, [BatchSimulationModel, CelerySimulationModel])
    return db.session.query(simulation_poly).filter_by(id=sim_id).first()

78 

79 

def fetch_simulations_by_user_id(user_id: int) -> list[Union[BatchSimulationModel, CelerySimulationModel]]:
    """Fetch all simulations owned by ``user_id``, of either subclass.

    Because the query is polymorphic, the returned list may contain a mix of
    BatchSimulationModel and CelerySimulationModel instances. It is empty when
    the user has no simulations.
    """
    simulation_poly = with_polymorphic(SimulationModel, [BatchSimulationModel, CelerySimulationModel])
    return db.session.query(simulation_poly).filter_by(user_id=user_id).all()

85 

86 

def fetch_task_by_sim_id_and_task_id(sim_id: int, task_id: str) -> Optional[Union[BatchTaskModel, CeleryTaskModel]]:
    """Fetch one task of either subclass by simulation id and task id.

    Returns:
        The first BatchTaskModel or CeleryTaskModel with matching
        ``simulation_id`` and ``task_id``, or None if there is no such task.
    """
    task_poly = with_polymorphic(TaskModel, [BatchTaskModel, CeleryTaskModel])
    return db.session.query(task_poly).filter_by(simulation_id=sim_id, task_id=task_id).first()

92 

93 

def fetch_tasks_by_sim_id(sim_id: int) -> list[Union[BatchTaskModel, CeleryTaskModel]]:
    """Fetch all tasks of the simulation ``sim_id``, of either subclass.

    Because the query is polymorphic, the returned list may contain a mix of
    BatchTaskModel and CeleryTaskModel instances. It is empty when the
    simulation has no tasks.
    """
    task_poly = with_polymorphic(TaskModel, [BatchTaskModel, CeleryTaskModel])
    return db.session.query(task_poly).filter_by(simulation_id=sim_id).all()

99 

100 

def fetch_celery_tasks_by_sim_id(sim_id: int) -> list[CeleryTaskModel]:
    """Return all CeleryTaskModel rows whose ``simulation_id`` equals ``sim_id``.

    The list is empty when the simulation has no celery tasks.
    """
    celery_task_query = db.session.query(CeleryTaskModel)
    return celery_task_query.filter_by(simulation_id=sim_id).all()

105 

106 

def fetch_batch_tasks_by_sim_id(sim_id: int) -> list[BatchTaskModel]:
    """Return all BatchTaskModel rows whose ``simulation_id`` equals ``sim_id``.

    The list is empty when the simulation has no batch tasks.
    """
    batch_task_query = db.session.query(BatchTaskModel)
    return batch_task_query.filter_by(simulation_id=sim_id).all()

111 

112 

def fetch_estimators_by_sim_id(sim_id: int) -> list[EstimatorModel]:
    """Return all EstimatorModel rows whose ``simulation_id`` equals ``sim_id``.

    The list is empty when the simulation has no estimators.
    """
    estimator_query = db.session.query(EstimatorModel)
    return estimator_query.filter_by(simulation_id=sim_id).all()

117 

118 

def fetch_estimator_by_sim_id_and_est_name(sim_id: int, est_name: str) -> Optional[EstimatorModel]:
    """Fetch one estimator by simulation id and estimator name.

    Returns:
        The first EstimatorModel with matching ``simulation_id`` and ``name``,
        or None if there is no such estimator.
    """
    return db.session.query(EstimatorModel).filter_by(simulation_id=sim_id, name=est_name).first()

123 

124 

def fetch_pages_by_estimator_id(est_id: int) -> list[PageModel]:
    """Return all PageModel rows whose ``estimator_id`` equals ``est_id``.

    The list is empty when the estimator has no pages.
    """
    page_query = db.session.query(PageModel)
    return page_query.filter_by(estimator_id=est_id).all()

129 

130 

def fetch_page_by_est_id_and_page_number(est_id: int, page_number: int) -> Optional[PageModel]:
    """Fetch one page by estimator id and page number.

    Returns:
        The first PageModel with matching ``estimator_id`` and ``page_number``,
        or None if there is no such page.
    """
    return db.session.query(PageModel).filter_by(estimator_id=est_id, page_number=page_number).first()

135 

136 

def fetch_all_clusters() -> list[ClusterModel]:
    """Return every ClusterModel row stored in the database (possibly none)."""
    return db.session.query(ClusterModel).all()

141 

142 

def fetch_cluster_by_id(cluster_id: int) -> Optional[ClusterModel]:
    """Fetch a cluster by its ``id`` column.

    Returns:
        The matching ClusterModel, or None if no cluster has ``id == cluster_id``.
    """
    return db.session.query(ClusterModel).filter_by(id=cluster_id).first()

147 

148 

def fetch_input_by_sim_id(sim_id: int) -> Optional[InputModel]:
    """Fetch the input stored for simulation ``sim_id``.

    Returns:
        The first InputModel whose ``simulation_id`` equals ``sim_id``, or None
        if no input is stored for that simulation.
    """
    return db.session.query(InputModel).filter_by(simulation_id=sim_id).first()

153 

154 

def fetch_logfiles_by_sim_id(sim_id: int) -> Optional[LogfilesModel]:
    """Fetch the logfiles record stored for simulation ``sim_id``.

    Returns:
        The first LogfilesModel whose ``simulation_id`` equals ``sim_id``, or
        None if no logfiles are stored for that simulation.
    """
    return db.session.query(LogfilesModel).filter_by(simulation_id=sim_id).first()

159 

160 

def update_task_state(task: Union[BatchTaskModel, CeleryTaskModel], update_dict: dict) -> None:
    """Apply ``update_dict`` to ``task`` via ``task.update_state`` and commit.

    The commit goes through ``make_commit_to_db`` so that all commits in this
    module share one code path.

    Args:
        task: Task model instance to modify.
        update_dict: Passed unchanged to ``task.update_state``; the accepted keys
            are defined by the task model (not visible here).
    """
    # NOTE(review): unlike update_simulation_state, the return value of
    # update_state is not checked here; confirm TaskModel.update_state has no
    # "nothing changed" signal that should skip the commit.
    task.update_state(update_dict)
    make_commit_to_db()

165 

166 

def update_simulation_state(simulation: Union[BatchSimulationModel, CelerySimulationModel], update_dict: dict) -> None:
    """Apply ``update_dict`` to ``simulation`` and commit only if it reports an update.

    If ``simulation.update_state`` returns a falsy value, no commit is made and
    a warning is logged instead.

    Args:
        simulation: Simulation model instance to modify.
        update_dict: Passed unchanged to ``simulation.update_state``; the accepted
            keys are defined by the simulation model (not visible here).
    """
    if not simulation.update_state(update_dict):
        # Use a module-named logger (instead of the root logger) so the warning
        # can be filtered/configured per module.
        logging.getLogger(__name__).warning("Simulation state not updated, skipping commit")
        return
    make_commit_to_db()