Coverage for yaptide/persistence/db_methods.py: 71%

98 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-22 07:31 +0000

1import logging 

2from typing import Optional, Union 

3 

4from sqlalchemy.orm import with_polymorphic 

5 

6from yaptide.persistence.database import db 

7from yaptide.persistence.models import (BatchSimulationModel, BatchTaskModel, CelerySimulationModel, CeleryTaskModel, 

8 ClusterModel, EstimatorModel, InputModel, KeycloakUserModel, LogfilesModel, 

9 PageModel, SimulationModel, TaskModel, UserModel, YaptideUserModel) 

10 

11 

def add_object_to_db(obj: db.Model, make_commit: bool = True) -> None:
    """Add ``obj`` to the current database session.

    Args:
        obj: model instance to add to ``db.session``.
        make_commit: when true (the default) the session is committed right
            away via ``make_commit_to_db``; when false the caller is
            responsible for committing later.
    """
    session = db.session
    session.add(obj)
    if not make_commit:
        return
    make_commit_to_db()

17 

18 

def delete_object_from_db(obj: db.Model, make_commit: bool = True) -> None:
    """Mark ``obj`` for deletion in the current database session.

    Args:
        obj: model instance to delete via ``db.session.delete``.
        make_commit: when true (the default) the session is committed right
            away via ``make_commit_to_db``; when false the caller is
            responsible for committing later.
    """
    session = db.session
    session.delete(obj)
    if not make_commit:
        return
    make_commit_to_db()

24 

25 

def make_commit_to_db():
    """Commit the current ``db.session``."""
    session = db.session
    session.commit()

29 

30 

def fetch_user_by_id(user_id: int) -> Optional[Union[KeycloakUserModel, YaptideUserModel]]:
    """Fetch a user of either concrete subtype by its id.

    The query targets ``UserModel`` through ``with_polymorphic`` over
    ``YaptideUserModel`` and ``KeycloakUserModel``, so the subclass tables are
    included in the same query.

    Args:
        user_id: id of the user to look up.

    Returns:
        The first matching user, or ``None`` when no user has this id.
    """
    user_poly = with_polymorphic(UserModel, [YaptideUserModel, KeycloakUserModel])
    return db.session.query(user_poly).filter_by(id=user_id).first()

36 

37 

def fetch_yaptide_user_by_username(username: str) -> Optional[YaptideUserModel]:
    """Fetch a ``YaptideUserModel`` by username.

    Args:
        username: username to match exactly.

    Returns:
        The first matching user, or ``None`` when no Yaptide user has this
        username.
    """
    return db.session.query(YaptideUserModel).filter_by(username=username).first()

42 

43 

def fetch_keycloak_user_by_username(username: str) -> Optional[KeycloakUserModel]:
    """Fetch a ``KeycloakUserModel`` by username.

    Args:
        username: username to match exactly.

    Returns:
        The first matching user, or ``None`` when no Keycloak user has this
        username.
    """
    return db.session.query(KeycloakUserModel).filter_by(username=username).first()

48 

49 

def fetch_simulation_by_job_id(job_id: str) -> Optional[Union[BatchSimulationModel, CelerySimulationModel]]:
    """Fetch a simulation of either concrete subtype by its job id.

    The query targets ``SimulationModel`` through ``with_polymorphic`` over
    ``BatchSimulationModel`` and ``CelerySimulationModel``, so the subclass
    tables are included in the same query.

    Args:
        job_id: job identifier of the simulation.

    Returns:
        The first matching simulation, or ``None`` when no simulation has
        this job id.
    """
    simulation_poly = with_polymorphic(SimulationModel, [BatchSimulationModel, CelerySimulationModel])
    return db.session.query(simulation_poly).filter_by(job_id=job_id).first()

55 

56 

def fetch_simulation_id_by_job_id(job_id: str) -> Optional[int]:
    """Look up only the id of the simulation with the given job id.

    The query selects the single ``SimulationModel.id`` column from the base
    model, so it covers Celery and Batch simulations alike.

    Args:
        job_id: job identifier of the simulation.

    Returns:
        The simulation id, or ``None`` if no simulation has this job id.
    """
    row = db.session.query(SimulationModel.id).filter_by(job_id=job_id).first()
    if row is None:
        return None
    return row[0]

64 

65 

def fetch_celery_simulation_by_job_id(job_id: str) -> Optional[CelerySimulationModel]:
    """Fetch a ``CelerySimulationModel`` by its job id.

    Args:
        job_id: job identifier of the simulation.

    Returns:
        The first matching Celery simulation, or ``None`` when none has this
        job id.
    """
    return db.session.query(CelerySimulationModel).filter_by(job_id=job_id).first()

70 

71 

def fetch_batch_simulation_by_job_id(job_id: str) -> Optional[BatchSimulationModel]:
    """Fetch a ``BatchSimulationModel`` by its job id.

    Args:
        job_id: job identifier of the simulation.

    Returns:
        The first matching batch simulation, or ``None`` when none has this
        job id.
    """
    return db.session.query(BatchSimulationModel).filter_by(job_id=job_id).first()

76 

77 

def fetch_simulation_by_sim_id(sim_id: int) -> Optional[Union[BatchSimulationModel, CelerySimulationModel]]:
    """Fetch a simulation of either concrete subtype by its id.

    The query targets ``SimulationModel`` through ``with_polymorphic`` over
    ``BatchSimulationModel`` and ``CelerySimulationModel``, so the subclass
    tables are included in the same query.

    Args:
        sim_id: id of the simulation.

    Returns:
        The first matching simulation, or ``None`` when no simulation has
        this id.
    """
    simulation_poly = with_polymorphic(SimulationModel, [BatchSimulationModel, CelerySimulationModel])
    return db.session.query(simulation_poly).filter_by(id=sim_id).first()

83 

84 

def fetch_simulations_by_user_id(user_id: int) -> list[Union[BatchSimulationModel, CelerySimulationModel]]:
    """Fetch every simulation owned by the given user.

    The query targets ``SimulationModel`` through ``with_polymorphic``, so one
    list is returned whose elements may be a mix of ``BatchSimulationModel``
    and ``CelerySimulationModel`` instances.

    Args:
        user_id: id of the owning user.

    Returns:
        List of matching simulations (empty when the user has none).
    """
    simulation_poly = with_polymorphic(SimulationModel, [BatchSimulationModel, CelerySimulationModel])
    return db.session.query(simulation_poly).filter_by(user_id=user_id).all()

90 

91 

def fetch_task_by_sim_id_and_task_id(sim_id: int, task_id: str) -> Optional[Union[BatchTaskModel, CeleryTaskModel]]:
    """Fetch a single task of either concrete subtype within a simulation.

    Args:
        sim_id: id of the parent simulation.
        task_id: task identifier within that simulation.

    Returns:
        The first matching task, or ``None`` when no task matches both ids.
    """
    task_poly = with_polymorphic(TaskModel, [BatchTaskModel, CeleryTaskModel])
    return db.session.query(task_poly).filter_by(simulation_id=sim_id, task_id=task_id).first()

97 

98 

def fetch_tasks_by_sim_id(sim_id: int) -> list[Union[BatchTaskModel, CeleryTaskModel]]:
    """Fetch every task belonging to the given simulation.

    The query targets ``TaskModel`` through ``with_polymorphic``, so one list
    is returned whose elements may be ``BatchTaskModel`` or
    ``CeleryTaskModel`` instances.

    Args:
        sim_id: id of the parent simulation.

    Returns:
        List of matching tasks (empty when the simulation has none).
    """
    task_poly = with_polymorphic(TaskModel, [BatchTaskModel, CeleryTaskModel])
    return db.session.query(task_poly).filter_by(simulation_id=sim_id).all()

104 

105 

def fetch_celery_tasks_by_sim_id(sim_id: int) -> list[CeleryTaskModel]:
    """Return all ``CeleryTaskModel`` rows whose ``simulation_id`` is ``sim_id``.

    The list is empty when nothing matches.
    """
    query = db.session.query(CeleryTaskModel)
    return query.filter_by(simulation_id=sim_id).all()

110 

111 

def fetch_batch_tasks_by_sim_id(sim_id: int) -> list[BatchTaskModel]:
    """Return all ``BatchTaskModel`` rows whose ``simulation_id`` is ``sim_id``.

    The list is empty when nothing matches.
    """
    query = db.session.query(BatchTaskModel)
    return query.filter_by(simulation_id=sim_id).all()

116 

117 

def fetch_estimators_by_sim_id(sim_id: int) -> list[EstimatorModel]:
    """Return all ``EstimatorModel`` rows whose ``simulation_id`` is ``sim_id``.

    The list is empty when nothing matches.
    """
    query = db.session.query(EstimatorModel)
    return query.filter_by(simulation_id=sim_id).all()

122 

123 

def fetch_estimator_names_by_job_id(job_id: str) -> Optional[list[str]]:
    """Fetch the names of all estimators of the simulation with ``job_id``.

    Args:
        job_id: job identifier of a Celery or Batch simulation.

    Returns:
        List of estimator names (possibly empty) if the simulation exists,
        or ``None`` if no simulation is found for the provided job id.
    """
    simulation_id = fetch_simulation_id_by_job_id(job_id=job_id)
    # Compare against None explicitly so a falsy-but-valid id (0) is not
    # reported as "simulation not found".
    if simulation_id is None:
        return None
    name_rows = db.session.query(EstimatorModel.name).filter_by(simulation_id=simulation_id).all()
    return [name for (name, ) in name_rows]

135 

136 

def fetch_estimator_by_sim_id_and_est_name(sim_id: int, est_name: str) -> Optional[EstimatorModel]:
    """Fetch an estimator by simulation id and estimator name.

    Args:
        sim_id: id of the parent simulation.
        est_name: estimator name to match exactly.

    Returns:
        The first matching estimator, or ``None`` when none matches.
    """
    return db.session.query(EstimatorModel).filter_by(simulation_id=sim_id, name=est_name).first()

141 

142 

def fetch_pages_by_estimator_id(est_id: int) -> list[PageModel]:
    """Return all ``PageModel`` rows whose ``estimator_id`` is ``est_id``.

    The list is empty when nothing matches.
    """
    query = db.session.query(PageModel)
    return query.filter_by(estimator_id=est_id).all()

147 

148 

def fetch_page_by_est_id_and_page_number(est_id: int, page_number: int) -> Optional[PageModel]:
    """Fetch a single page of an estimator by page number.

    Args:
        est_id: id of the parent estimator.
        page_number: page number to match.

    Returns:
        The first matching page, or ``None`` when none matches.
    """
    return db.session.query(PageModel).filter_by(estimator_id=est_id, page_number=page_number).first()

153 

154 

def fetch_all_clusters() -> list[ClusterModel]:
    """Return every ``ClusterModel`` row (an empty list if there are none)."""
    return db.session.query(ClusterModel).all()

159 

160 

def fetch_cluster_by_id(cluster_id: int) -> Optional[ClusterModel]:
    """Fetch a cluster by its id.

    Args:
        cluster_id: id of the cluster.

    Returns:
        The first matching cluster, or ``None`` when no cluster has this id.
    """
    return db.session.query(ClusterModel).filter_by(id=cluster_id).first()

165 

166 

def fetch_input_by_sim_id(sim_id: int) -> Optional[InputModel]:
    """Fetch the input record attached to a simulation.

    Args:
        sim_id: id of the parent simulation.

    Returns:
        The first matching ``InputModel``, or ``None`` when the simulation
        has no input record.
    """
    return db.session.query(InputModel).filter_by(simulation_id=sim_id).first()

171 

172 

def fetch_logfiles_by_sim_id(sim_id: int) -> Optional[LogfilesModel]:
    """Fetch the logfiles record attached to a simulation.

    Args:
        sim_id: id of the parent simulation.

    Returns:
        The first matching ``LogfilesModel``, or ``None`` when the simulation
        has no logfiles record.
    """
    return db.session.query(LogfilesModel).filter_by(simulation_id=sim_id).first()

177 

178 

def update_task_state(task: Union[BatchTaskModel, CeleryTaskModel], update_dict: dict) -> None:
    """Apply ``update_dict`` to ``task`` and commit the session unconditionally.

    Args:
        task: task model whose ``update_state`` method receives ``update_dict``.
        update_dict: state changes passed through to ``task.update_state``.
    """
    task.update_state(update_dict)
    # Unlike update_simulation_state, the commit happens regardless of what
    # update_state returns.
    make_commit_to_db()

183 

184 

def update_simulation_state(simulation: Union[BatchSimulationModel, CelerySimulationModel], update_dict: dict) -> None:
    """Apply ``update_dict`` to ``simulation``; commit only if it reports a change.

    Args:
        simulation: simulation model whose ``update_state`` method receives
            ``update_dict``; a truthy return value triggers the commit.
        update_dict: state changes passed through to ``simulation.update_state``.
    """
    state_changed = simulation.update_state(update_dict)
    if not state_changed:
        logging.warning("Simulation state not updated, skipping commit")
        return
    db.session.commit()