Coverage for yaptide/persistence/db_methods.py: 70%
111 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-04 00:31 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-04 00:31 +0000
1import logging
2from typing import Optional, Union
4from sqlalchemy import and_
5from sqlalchemy.orm import with_polymorphic
7from yaptide.persistence.database import db
8from yaptide.persistence.models import (BatchSimulationModel, BatchTaskModel, CelerySimulationModel, CeleryTaskModel,
9 ClusterModel, EstimatorModel, InputModel, KeycloakUserModel, LogfilesModel,
10 PageModel, SimulationModel, TaskModel, UserModel, YaptideUserModel)
def add_object_to_db(obj: db.Model, make_commit: bool = True) -> None:
    """Stage ``obj`` for insertion into the current database session.

    Args:
        obj: Model instance to add.
        make_commit: When true (the default) the session is committed right
            away through ``make_commit_to_db``; when false the caller is
            responsible for committing later.
    """
    session = db.session
    session.add(obj)
    if not make_commit:
        return
    make_commit_to_db()
def delete_object_from_db(obj: db.Model, make_commit: bool = True) -> None:
    """Mark ``obj`` for deletion in the current database session.

    Args:
        obj: Model instance to delete.
        make_commit: When true (the default) the session is committed right
            away through ``make_commit_to_db``; when false the deletion stays
            pending until the caller commits.
    """
    session = db.session
    session.delete(obj)
    if not make_commit:
        return
    make_commit_to_db()
def make_commit_to_db():
    """Commit all pending changes of the current database session."""
    session = db.session
    session.commit()
def fetch_user_by_id(user_id: int) -> Optional[Union[KeycloakUserModel, YaptideUserModel]]:
    """Fetches user by id.

    The query goes through ``with_polymorphic`` over ``UserModel`` so the
    ``YaptideUserModel`` and ``KeycloakUserModel`` subclasses are loaded in the
    same query.

    Args:
        user_id: Primary key of the user.

    Returns:
        The matching user, or ``None`` if no user has the given id.
    """
    user_poly = with_polymorphic(UserModel, [YaptideUserModel, KeycloakUserModel])
    return db.session.query(user_poly).filter_by(id=user_id).first()
def fetch_yaptide_user_by_username(username: str) -> Optional[YaptideUserModel]:
    """Fetches a ``YaptideUserModel`` user by username.

    Args:
        username: Username to look up.

    Returns:
        The matching user, or ``None`` if no ``YaptideUserModel`` row has
        this username.
    """
    return db.session.query(YaptideUserModel).filter_by(username=username).first()
def fetch_keycloak_user_by_username(username: str) -> Optional[KeycloakUserModel]:
    """Fetches a ``KeycloakUserModel`` user by username.

    Args:
        username: Username to look up.

    Returns:
        The matching user, or ``None`` if no ``KeycloakUserModel`` row has
        this username.
    """
    return db.session.query(KeycloakUserModel).filter_by(username=username).first()
def fetch_simulation_by_job_id(job_id: str) -> Optional[Union[BatchSimulationModel, CelerySimulationModel]]:
    """Fetches simulation by job id.

    The query is polymorphic over ``SimulationModel`` with the
    ``BatchSimulationModel`` and ``CelerySimulationModel`` subclasses.

    Args:
        job_id: Job identifier of the simulation.

    Returns:
        The matching simulation, or ``None`` if no simulation has this job id.
    """
    simulation_poly = with_polymorphic(SimulationModel, [BatchSimulationModel, CelerySimulationModel])
    return db.session.query(simulation_poly).filter_by(job_id=job_id).first()
def fetch_simulation_id_by_job_id(job_id: str) -> Optional[int]:
    """Look up only the primary key of the simulation with the given job id.

    Queries the base ``SimulationModel.id`` column, so Celery and Batch
    simulations are both covered.

    Args:
        job_id: Job identifier of the simulation.

    Returns:
        The simulation id, or ``None`` if no simulation matches ``job_id``.
    """
    row = db.session.query(SimulationModel.id).filter_by(job_id=job_id).first()
    if not row:
        return None
    return row[0]
def fetch_celery_simulation_by_job_id(job_id: str) -> Optional[CelerySimulationModel]:
    """Fetches celery simulation by job id.

    Args:
        job_id: Job identifier of the simulation.

    Returns:
        The matching ``CelerySimulationModel``, or ``None`` if none has this
        job id.
    """
    return db.session.query(CelerySimulationModel).filter_by(job_id=job_id).first()
def fetch_batch_simulation_by_job_id(job_id: str) -> Optional[BatchSimulationModel]:
    """Fetches batch simulation by job id.

    Args:
        job_id: Job identifier of the simulation.

    Returns:
        The matching ``BatchSimulationModel``, or ``None`` if none has this
        job id.
    """
    return db.session.query(BatchSimulationModel).filter_by(job_id=job_id).first()
def fetch_simulation_by_sim_id(sim_id: int) -> Optional[Union[BatchSimulationModel, CelerySimulationModel]]:
    """Fetches simulation by sim id.

    The query is polymorphic over ``SimulationModel`` with the
    ``BatchSimulationModel`` and ``CelerySimulationModel`` subclasses.

    Args:
        sim_id: Primary key of the simulation.

    Returns:
        The matching simulation, or ``None`` if no simulation has this id.
    """
    simulation_poly = with_polymorphic(SimulationModel, [BatchSimulationModel, CelerySimulationModel])
    return db.session.query(simulation_poly).filter_by(id=sim_id).first()
def fetch_simulations_by_user_id(user_id: int) -> list[Union[BatchSimulationModel, CelerySimulationModel]]:
    """Fetches simulations by user id.

    The query is polymorphic over ``SimulationModel``, so a single returned
    list may contain both ``BatchSimulationModel`` and
    ``CelerySimulationModel`` instances.

    Args:
        user_id: Id of the owning user.

    Returns:
        All simulations of the user (empty list if there are none).
    """
    simulation_poly = with_polymorphic(SimulationModel, [BatchSimulationModel, CelerySimulationModel])
    return db.session.query(simulation_poly).filter_by(user_id=user_id).all()
def fetch_task_by_sim_id_and_task_id(sim_id: int, task_id: str) -> Optional[Union[BatchTaskModel, CeleryTaskModel]]:
    """Fetches task by simulation id and task id.

    The query is polymorphic over ``TaskModel`` with the ``BatchTaskModel``
    and ``CeleryTaskModel`` subclasses.

    Args:
        sim_id: Id of the simulation owning the task.
        task_id: Identifier of the task within that simulation.

    Returns:
        The matching task, or ``None`` if no task matches both ids.
    """
    task_poly = with_polymorphic(TaskModel, [BatchTaskModel, CeleryTaskModel])
    return db.session.query(task_poly).filter_by(simulation_id=sim_id, task_id=task_id).first()
def fetch_tasks_by_sim_id(sim_id: int) -> list[Union[BatchTaskModel, CeleryTaskModel]]:
    """Fetches tasks by simulation id.

    The query is polymorphic over ``TaskModel``, so the result is a single
    list whose elements are ``BatchTaskModel`` or ``CeleryTaskModel``
    instances.

    Args:
        sim_id: Id of the simulation owning the tasks.

    Returns:
        All tasks of the simulation (empty list if there are none).
    """
    task_poly = with_polymorphic(TaskModel, [BatchTaskModel, CeleryTaskModel])
    return db.session.query(task_poly).filter_by(simulation_id=sim_id).all()
def fetch_celery_tasks_by_sim_id(sim_id: int) -> list[CeleryTaskModel]:
    """Return every ``CeleryTaskModel`` row attached to simulation ``sim_id``.

    An empty list is returned when the simulation has no celery tasks.
    """
    query = db.session.query(CeleryTaskModel)
    return query.filter_by(simulation_id=sim_id).all()
def fetch_batch_tasks_by_sim_id(sim_id: int) -> list[BatchTaskModel]:
    """Return every ``BatchTaskModel`` row attached to simulation ``sim_id``.

    An empty list is returned when the simulation has no batch tasks.
    """
    query = db.session.query(BatchTaskModel)
    return query.filter_by(simulation_id=sim_id).all()
def fetch_estimators_by_sim_id(sim_id: int) -> list[EstimatorModel]:
    """Return all ``EstimatorModel`` rows belonging to simulation ``sim_id``.

    An empty list is returned when the simulation has no estimators.
    """
    query = db.session.query(EstimatorModel)
    return query.filter_by(simulation_id=sim_id).all()
def fetch_estimator_names_by_job_id(job_id: str) -> Optional[list[str]]:
    """Fetches estimator names by job id.

    Args:
        job_id: Job identifier of the simulation (the same string id accepted
            by ``fetch_simulation_id_by_job_id``).

    Returns:
        A list of estimator names if a simulation with ``job_id`` exists
        (possibly empty when it has no estimators), or ``None`` if no
        simulation is found for the provided job id.
    """
    simulation_id = fetch_simulation_id_by_job_id(job_id=job_id)
    # Explicit None check: a falsy id would otherwise be mistaken for "not found".
    if simulation_id is None:
        return None
    estimator_names_tuples = db.session.query(EstimatorModel.name).filter_by(simulation_id=simulation_id).all()
    # Each row holds a single column (the name); unpack it.
    return [name for (name, ) in estimator_names_tuples]
def fetch_estimator_by_sim_id_and_est_name(sim_id: int, est_name: str) -> Optional[EstimatorModel]:
    """Fetches estimator by simulation id and estimator name.

    Args:
        sim_id: Id of the simulation owning the estimator.
        est_name: Value of the estimator's ``name`` column.

    Returns:
        The matching estimator, or ``None`` if there is none.
    """
    return db.session.query(EstimatorModel).filter_by(simulation_id=sim_id, name=est_name).first()
def fetch_estimator_by_sim_id_and_file_name(sim_id: int, file_name: str) -> Optional[EstimatorModel]:
    """Fetches estimator by simulation id and estimator file name.

    Args:
        sim_id: Id of the simulation owning the estimator.
        file_name: Value of the estimator's ``file_name`` column.

    Returns:
        The matching estimator, or ``None`` if there is none.
    """
    return db.session.query(EstimatorModel).filter_by(simulation_id=sim_id, file_name=file_name).first()
def fetch_estimator_id_by_sim_id_and_est_name(sim_id: int, est_name: str) -> Optional[int]:
    """Look up only the primary key of an estimator.

    Args:
        sim_id: Id of the simulation owning the estimator.
        est_name: Value of the estimator's ``name`` column.

    Returns:
        The estimator id, or ``None`` if no estimator matches.
    """
    row = db.session.query(EstimatorModel.id).filter_by(simulation_id=sim_id, name=est_name).first()
    if not row:
        return None
    return row[0]
def fetch_pages_by_estimator_id(est_id: int) -> list[PageModel]:
    """Return every ``PageModel`` row that belongs to estimator ``est_id``.

    An empty list is returned when the estimator has no pages.
    """
    query = db.session.query(PageModel)
    return query.filter_by(estimator_id=est_id).all()
def fetch_page_by_est_id_and_page_number(est_id: int, page_number: int) -> Optional[PageModel]:
    """Fetches page by estimator id and page number.

    Args:
        est_id: Id of the estimator owning the page.
        page_number: Value of the page's ``page_number`` column.

    Returns:
        The matching page, or ``None`` if there is none.
    """
    return db.session.query(PageModel).filter_by(estimator_id=est_id, page_number=page_number).first()
def fetch_pages_by_est_id_and_page_numbers(est_id: int, page_numbers: list[int]) -> list[PageModel]:
    """Fetches all pages of an estimator whose page number is in ``page_numbers``.

    Args:
        est_id: Id of the estimator owning the pages.
        page_numbers: Page numbers to select.

    Returns:
        The matching pages (possibly fewer than requested if some numbers do
        not exist). An empty ``page_numbers`` yields an empty list without
        querying the database.
    """
    # An empty IN list can never match; skip the round-trip entirely.
    if not page_numbers:
        return []
    return db.session.query(PageModel).filter(
        and_(PageModel.estimator_id == est_id, PageModel.page_number.in_(page_numbers))).all()
def fetch_pages_metadata_by_est_id(est_id: int) -> list:
    """Fetches metadata of all pages belonging to an estimator.

    Only the ``page_number``, ``page_name`` and ``page_dimension`` columns are
    selected, not full ``PageModel`` objects.

    Args:
        est_id: Id of the estimator owning the pages.

    Returns:
        A list of result rows, one per page, each holding
        ``(page_number, page_name, page_dimension)`` in that order. Empty if
        the estimator has no pages.
    """
    return db.session.query(PageModel.page_number, PageModel.page_name,
                            PageModel.page_dimension).filter_by(estimator_id=est_id).all()
def fetch_all_clusters() -> list[ClusterModel]:
    """Return every ``ClusterModel`` row stored in the database."""
    query = db.session.query(ClusterModel)
    return query.all()
def fetch_cluster_by_id(cluster_id: int) -> Optional[ClusterModel]:
    """Fetches cluster by id.

    Args:
        cluster_id: Primary key of the cluster.

    Returns:
        The matching cluster, or ``None`` if no cluster has this id.
    """
    return db.session.query(ClusterModel).filter_by(id=cluster_id).first()
def fetch_input_by_sim_id(sim_id: int) -> Optional[InputModel]:
    """Fetches input by simulation id.

    Args:
        sim_id: Id of the simulation.

    Returns:
        The simulation's ``InputModel``, or ``None`` if it has none.
    """
    return db.session.query(InputModel).filter_by(simulation_id=sim_id).first()
def fetch_logfiles_by_sim_id(sim_id: int) -> Optional[LogfilesModel]:
    """Fetches logfiles by simulation id.

    Args:
        sim_id: Id of the simulation.

    Returns:
        The simulation's ``LogfilesModel``, or ``None`` if it has none.
    """
    return db.session.query(LogfilesModel).filter_by(simulation_id=sim_id).first()
def update_task_state(task: Union[BatchTaskModel, CeleryTaskModel], update_dict: dict) -> None:
    """Apply ``update_dict`` to ``task`` and always commit the session.

    The value returned by ``task.update_state`` is ignored; the commit
    happens unconditionally.
    """
    task.update_state(update_dict)
    session = db.session
    session.commit()
def update_simulation_state(simulation: Union[BatchSimulationModel, CelerySimulationModel], update_dict: dict) -> None:
    """Apply ``update_dict`` to ``simulation`` and commit only on success.

    A falsy return value from ``simulation.update_state`` is treated as
    "nothing was updated": a warning is logged and no commit is made.
    """
    state_changed = simulation.update_state(update_dict)
    if not state_changed:
        logging.warning("Simulation state not updated, skipping commit")
        return
    db.session.commit()