Coverage for yaptide/persistence/db_methods.py: 70%

111 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-04 00:31 +0000

1import logging 

2from typing import Optional, Union 

3 

4from sqlalchemy import and_ 

5from sqlalchemy.orm import with_polymorphic 

6 

7from yaptide.persistence.database import db 

8from yaptide.persistence.models import (BatchSimulationModel, BatchTaskModel, CelerySimulationModel, CeleryTaskModel, 

9 ClusterModel, EstimatorModel, InputModel, KeycloakUserModel, LogfilesModel, 

10 PageModel, SimulationModel, TaskModel, UserModel, YaptideUserModel) 

11 

12 

def add_object_to_db(obj: db.Model, make_commit: bool = True) -> None:
    """Stage ``obj`` in the current session and optionally commit it.

    Args:
        obj: model instance handed to ``db.session.add``.
        make_commit: when True (the default) the session is committed right
            away via ``make_commit_to_db``; when False nothing is committed
            here and the object stays pending in the session.
    """
    session = db.session
    session.add(obj)
    if not make_commit:
        return
    make_commit_to_db()

18 

19 

def delete_object_from_db(obj: db.Model, make_commit: bool = True) -> None:
    """Mark ``obj`` for deletion in the current session and optionally commit.

    Args:
        obj: model instance handed to ``db.session.delete``.
        make_commit: when True (the default) the session is committed right
            away via ``make_commit_to_db``; when False the deletion is only
            staged and is not committed here.
    """
    session = db.session
    session.delete(obj)
    if not make_commit:
        return
    make_commit_to_db()

25 

26 

def make_commit_to_db() -> None:
    """Commit the current database session.

    If the commit fails, the session is rolled back before the original
    exception is re-raised. Without that rollback SQLAlchemy leaves the
    session in an inactive transaction state, and every later use of it
    fails until ``rollback()`` is called.

    Raises:
        Whatever ``db.session.commit()`` raises, re-raised unchanged.
    """
    try:
        db.session.commit()
    except Exception:
        # A failed flush/commit invalidates the session's transaction; roll it
        # back so later operations on the same session can proceed.
        db.session.rollback()
        raise

30 

31 

def fetch_user_by_id(user_id: int) -> Optional[Union[KeycloakUserModel, YaptideUserModel]]:
    """Fetch a user by id, whichever user subclass it belongs to.

    The query goes through ``with_polymorphic`` over ``UserModel`` with the
    ``YaptideUserModel`` and ``KeycloakUserModel`` subclasses, so the columns
    of both subclasses are loaded in the same query.

    Args:
        user_id: value of the user's ``id`` column.

    Returns:
        The matching user, or None when no user has this id.
    """
    user_poly = with_polymorphic(UserModel, [YaptideUserModel, KeycloakUserModel])
    return db.session.query(user_poly).filter_by(id=user_id).first()

37 

38 

def fetch_yaptide_user_by_username(username: str) -> Optional[YaptideUserModel]:
    """Fetch a ``YaptideUserModel`` user by username.

    Only ``YaptideUserModel`` rows are searched; Keycloak users with the same
    username are not considered.

    Args:
        username: value of the ``username`` column to match.

    Returns:
        The first matching user, or None if there is no such Yaptide user.
    """
    return db.session.query(YaptideUserModel).filter_by(username=username).first()

43 

44 

def fetch_keycloak_user_by_username(username: str) -> Optional[KeycloakUserModel]:
    """Fetch a ``KeycloakUserModel`` user by username.

    Only ``KeycloakUserModel`` rows are searched; Yaptide users with the same
    username are not considered.

    Args:
        username: value of the ``username`` column to match.

    Returns:
        The first matching user, or None if there is no such Keycloak user.
    """
    return db.session.query(KeycloakUserModel).filter_by(username=username).first()

49 

50 

def fetch_simulation_by_job_id(job_id: str) -> Optional[Union[BatchSimulationModel, CelerySimulationModel]]:
    """Fetch a simulation of either kind (batch or Celery) by its job id.

    ``with_polymorphic`` over ``SimulationModel`` loads the columns of both
    ``BatchSimulationModel`` and ``CelerySimulationModel`` in one query.

    Args:
        job_id: value of the simulation's ``job_id`` column.

    Returns:
        The first matching simulation, or None if no simulation has this job id.
    """
    simulation_poly = with_polymorphic(SimulationModel, [BatchSimulationModel, CelerySimulationModel])
    return db.session.query(simulation_poly).filter_by(job_id=job_id).first()

56 

57 

def fetch_simulation_id_by_job_id(job_id: str) -> Optional[int]:
    """Look up only the primary key of the simulation with ``job_id``.

    Works for both Celery and batch simulations because the query targets the
    shared ``SimulationModel.id`` column.

    Returns:
        The simulation id, or None when no simulation has this job id.
    """
    row = db.session.query(SimulationModel.id).filter_by(job_id=job_id).first()
    if row is None:
        return None
    (sim_id, ) = row
    return sim_id

65 

66 

def fetch_celery_simulation_by_job_id(job_id: str) -> Optional[CelerySimulationModel]:
    """Fetch a Celery simulation by its job id.

    Only ``CelerySimulationModel`` rows are searched.

    Args:
        job_id: value of the simulation's ``job_id`` column.

    Returns:
        The first matching Celery simulation, or None if there is none.
    """
    return db.session.query(CelerySimulationModel).filter_by(job_id=job_id).first()

71 

72 

def fetch_batch_simulation_by_job_id(job_id: str) -> Optional[BatchSimulationModel]:
    """Fetch a batch simulation by its job id.

    Only ``BatchSimulationModel`` rows are searched.

    Args:
        job_id: value of the simulation's ``job_id`` column.

    Returns:
        The first matching batch simulation, or None if there is none.
    """
    return db.session.query(BatchSimulationModel).filter_by(job_id=job_id).first()

77 

78 

def fetch_simulation_by_sim_id(sim_id: int) -> Optional[Union[BatchSimulationModel, CelerySimulationModel]]:
    """Fetch a simulation of either kind (batch or Celery) by its id.

    ``with_polymorphic`` over ``SimulationModel`` loads the columns of both
    ``BatchSimulationModel`` and ``CelerySimulationModel`` in one query.

    Args:
        sim_id: value of the simulation's ``id`` column.

    Returns:
        The matching simulation, or None if no simulation has this id.
    """
    simulation_poly = with_polymorphic(SimulationModel, [BatchSimulationModel, CelerySimulationModel])
    return db.session.query(simulation_poly).filter_by(id=sim_id).first()

84 

85 

def fetch_simulations_by_user_id(user_id: int) -> list[Union[BatchSimulationModel, CelerySimulationModel]]:
    """Fetch every simulation (batch and Celery alike) owned by ``user_id``.

    One polymorphic query is issued, so the returned list may contain both
    ``BatchSimulationModel`` and ``CelerySimulationModel`` instances.

    Args:
        user_id: value of the simulations' ``user_id`` column.

    Returns:
        List of the user's simulations; empty if the user has none.
    """
    simulation_poly = with_polymorphic(SimulationModel, [BatchSimulationModel, CelerySimulationModel])
    return db.session.query(simulation_poly).filter_by(user_id=user_id).all()

91 

92 

def fetch_task_by_sim_id_and_task_id(sim_id: int, task_id: str) -> Optional[Union[BatchTaskModel, CeleryTaskModel]]:
    """Fetch one task (batch or Celery) of a simulation by its task id.

    ``with_polymorphic`` over ``TaskModel`` loads the columns of both
    ``BatchTaskModel`` and ``CeleryTaskModel`` in one query.

    Args:
        sim_id: id of the simulation the task belongs to.
        task_id: value of the task's ``task_id`` column.

    Returns:
        The first matching task, or None if the simulation has no such task.
    """
    task_poly = with_polymorphic(TaskModel, [BatchTaskModel, CeleryTaskModel])
    return db.session.query(task_poly).filter_by(simulation_id=sim_id, task_id=task_id).first()

98 

99 

def fetch_tasks_by_sim_id(sim_id: int) -> list[Union[BatchTaskModel, CeleryTaskModel]]:
    """Fetch every task (batch and Celery alike) of simulation ``sim_id``.

    One polymorphic query is issued, so the returned list may contain both
    ``BatchTaskModel`` and ``CeleryTaskModel`` instances.

    Args:
        sim_id: id of the simulation whose tasks are wanted.

    Returns:
        List of the simulation's tasks; empty if it has none.
    """
    task_poly = with_polymorphic(TaskModel, [BatchTaskModel, CeleryTaskModel])
    return db.session.query(task_poly).filter_by(simulation_id=sim_id).all()

105 

106 

def fetch_celery_tasks_by_sim_id(sim_id: int) -> list[CeleryTaskModel]:
    """Return all ``CeleryTaskModel`` rows that belong to simulation ``sim_id``.

    The result is an empty list when the simulation has no Celery tasks.
    """
    celery_query = db.session.query(CeleryTaskModel)
    return celery_query.filter_by(simulation_id=sim_id).all()

111 

112 

def fetch_batch_tasks_by_sim_id(sim_id: int) -> list[BatchTaskModel]:
    """Return all ``BatchTaskModel`` rows that belong to simulation ``sim_id``.

    The result is an empty list when the simulation has no batch tasks.
    """
    batch_query = db.session.query(BatchTaskModel)
    return batch_query.filter_by(simulation_id=sim_id).all()

117 

118 

def fetch_estimators_by_sim_id(sim_id: int) -> list[EstimatorModel]:
    """Return all estimators attached to simulation ``sim_id`` (possibly none)."""
    estimator_query = db.session.query(EstimatorModel)
    return estimator_query.filter_by(simulation_id=sim_id).all()

123 

124 

def fetch_estimator_names_by_job_id(job_id: str) -> Optional[list[str]]:
    """Fetch the names of all estimators of the simulation identified by ``job_id``.

    Args:
        job_id: job identifier of the simulation. It is passed on to
            ``fetch_simulation_id_by_job_id``, which takes a ``str``.

    Returns:
        A list of estimator names (possibly empty) if the simulation exists,
        or None if no simulation is found for the provided job ID.
    """
    simulation_id = fetch_simulation_id_by_job_id(job_id=job_id)
    # Compare against None explicitly: a falsy but valid id (e.g. 0) must not
    # be reported as "simulation not found".
    if simulation_id is None:
        return None
    name_rows = db.session.query(EstimatorModel.name).filter_by(simulation_id=simulation_id).all()
    # Each row holds a single column (the name); unpack it.
    return [name for (name, ) in name_rows]

136 

137 

def fetch_estimator_by_sim_id_and_est_name(sim_id: int, est_name: str) -> Optional[EstimatorModel]:
    """Fetch an estimator of simulation ``sim_id`` by its ``name`` column.

    Args:
        sim_id: id of the owning simulation.
        est_name: estimator name to match.

    Returns:
        The first matching estimator, or None if there is none.
    """
    return db.session.query(EstimatorModel).filter_by(simulation_id=sim_id, name=est_name).first()

142 

143 

def fetch_estimator_by_sim_id_and_file_name(sim_id: int, file_name: str) -> Optional[EstimatorModel]:
    """Fetch an estimator of simulation ``sim_id`` by its ``file_name`` column.

    Unlike ``fetch_estimator_by_sim_id_and_est_name``, this matches on the
    estimator's file name, not its name.

    Args:
        sim_id: id of the owning simulation.
        file_name: estimator file name to match.

    Returns:
        The first matching estimator, or None if there is none.
    """
    return db.session.query(EstimatorModel).filter_by(simulation_id=sim_id, file_name=file_name).first()

148 

149 

def fetch_estimator_id_by_sim_id_and_est_name(sim_id: int, est_name: str) -> Optional[int]:
    """Look up only the primary key of the estimator named ``est_name`` in simulation ``sim_id``.

    Returns:
        The estimator id, or None when no such estimator exists.
    """
    row = (db.session.query(EstimatorModel.id)
           .filter_by(simulation_id=sim_id, name=est_name)
           .first())
    if row is None:
        return None
    (estimator_id, ) = row
    return estimator_id

154 

155 

def fetch_pages_by_estimator_id(est_id: int) -> list[PageModel]:
    """Return every page stored for estimator ``est_id`` (possibly none)."""
    page_query = db.session.query(PageModel)
    return page_query.filter_by(estimator_id=est_id).all()

160 

161 

def fetch_page_by_est_id_and_page_number(est_id: int, page_number: int) -> Optional[PageModel]:
    """Fetch a single page of estimator ``est_id`` by its page number.

    Args:
        est_id: id of the owning estimator.
        page_number: value of the page's ``page_number`` column.

    Returns:
        The first matching page, or None if the estimator has no such page.
    """
    return db.session.query(PageModel).filter_by(estimator_id=est_id, page_number=page_number).first()

166 

167 

def fetch_pages_by_est_id_and_page_numbers(est_id: int, page_numbers: list) -> list[PageModel]:
    """Fetch the pages of estimator ``est_id`` whose page number is in ``page_numbers``.

    Args:
        est_id: id of the owning estimator.
        page_numbers: page numbers to select (used in an SQL ``IN`` clause).
            Numbers with no matching page are simply absent from the result.

    Returns:
        List of matching pages; it is a list even when only one page matches.
    """
    return db.session.query(PageModel).filter(
        and_(PageModel.estimator_id == est_id, PageModel.page_number.in_(page_numbers))).all()

173 

174 

def fetch_pages_metadata_by_est_id(est_id: int) -> list:
    """Fetch lightweight metadata for every page of estimator ``est_id``.

    Only the ``page_number``, ``page_name`` and ``page_dimension`` columns are
    selected; no other ``PageModel`` columns are loaded and no ``PageModel``
    instances are built.

    Args:
        est_id: id of the owning estimator.

    Returns:
        List of result rows, each holding ``(page_number, page_name,
        page_dimension)`` in that order; empty if the estimator has no pages.
    """
    return db.session.query(PageModel.page_number, PageModel.page_name,
                            PageModel.page_dimension).filter_by(estimator_id=est_id).all()

180 

181 

def fetch_all_clusters() -> list[ClusterModel]:
    """Return every ``ClusterModel`` row in the database, unfiltered."""
    return db.session.query(ClusterModel).all()

186 

187 

def fetch_cluster_by_id(cluster_id: int) -> Optional[ClusterModel]:
    """Fetch a cluster by its id.

    Args:
        cluster_id: value of the cluster's ``id`` column.

    Returns:
        The matching cluster, or None if no cluster has this id.
    """
    return db.session.query(ClusterModel).filter_by(id=cluster_id).first()

192 

193 

def fetch_input_by_sim_id(sim_id: int) -> Optional[InputModel]:
    """Fetch the input record of simulation ``sim_id``.

    Args:
        sim_id: id of the owning simulation.

    Returns:
        The first matching ``InputModel``, or None if the simulation has no input stored.
    """
    return db.session.query(InputModel).filter_by(simulation_id=sim_id).first()

198 

199 

def fetch_logfiles_by_sim_id(sim_id: int) -> Optional[LogfilesModel]:
    """Fetch the logfiles record of simulation ``sim_id``.

    Args:
        sim_id: id of the owning simulation.

    Returns:
        The first matching ``LogfilesModel``, or None if the simulation has no logfiles stored.
    """
    return db.session.query(LogfilesModel).filter_by(simulation_id=sim_id).first()

204 

205 

def update_task_state(task: Union[BatchTaskModel, CeleryTaskModel], update_dict: dict) -> None:
    """Apply ``update_dict`` to ``task`` and commit the session.

    Args:
        task: task model instance to update.
        update_dict: passed unchanged to ``task.update_state``.

    Note:
        Unlike ``update_simulation_state``, the return value of
        ``task.update_state`` is not inspected; the commit always happens.
    """
    task.update_state(update_dict)
    # Commit through the module's shared helper so all writes here use one commit path.
    make_commit_to_db()

210 

211 

def update_simulation_state(simulation: Union[BatchSimulationModel, CelerySimulationModel], update_dict: dict) -> None:
    """Apply ``update_dict`` to ``simulation`` and commit only if the update was accepted.

    ``simulation.update_state(update_dict)`` is called. A truthy result
    triggers a commit; a falsy result skips the commit and logs a warning.

    Args:
        simulation: simulation model instance to update.
        update_dict: passed unchanged to ``simulation.update_state``.
    """
    if not simulation.update_state(update_dict):
        # Module-level logger instead of the root logger, so the source of the
        # warning is identifiable and filterable.
        logging.getLogger(__name__).warning("Simulation state not updated, skipping commit")
        return
    make_commit_to_db()