Coverage for yaptide/persistence/db_methods.py: 71%
98 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-22 07:31 +0000
1import logging
2from typing import Optional, Union
4from sqlalchemy.orm import with_polymorphic
6from yaptide.persistence.database import db
7from yaptide.persistence.models import (BatchSimulationModel, BatchTaskModel, CelerySimulationModel, CeleryTaskModel,
8 ClusterModel, EstimatorModel, InputModel, KeycloakUserModel, LogfilesModel,
9 PageModel, SimulationModel, TaskModel, UserModel, YaptideUserModel)
def add_object_to_db(obj: db.Model, make_commit: bool = True) -> None:
    """Add ``obj`` to the current database session, optionally committing.

    :param obj: model instance passed to ``db.session.add``.
    :param make_commit: if True (the default) the session is committed
        immediately; if False the object is only staged and the caller must
        commit later.
    """
    session = db.session
    session.add(obj)
    if not make_commit:
        return
    make_commit_to_db()
def delete_object_from_db(obj: db.Model, make_commit: bool = True) -> None:
    """Mark ``obj`` for deletion in the current database session, optionally committing.

    :param obj: model instance passed to ``db.session.delete``.
    :param make_commit: if True (the default) the session is committed
        immediately; if False the deletion is only staged and the caller must
        commit later.
    """
    session = db.session
    session.delete(obj)
    if not make_commit:
        return
    make_commit_to_db()
def make_commit_to_db():
    """Commit every pending change held by the current ``db.session``."""
    session = db.session
    session.commit()
def fetch_user_by_id(user_id: int) -> Optional[Union[KeycloakUserModel, YaptideUserModel]]:
    """Fetch a user of either subtype by primary key.

    ``with_polymorphic`` is applied to ``UserModel`` so the columns of both
    ``YaptideUserModel`` and ``KeycloakUserModel`` are part of the query.

    :param user_id: primary key of the user.
    :return: the matching user, or None if no user has this id
        (``Query.first()`` returns None on an empty result).
    """
    user_poly = with_polymorphic(UserModel, [YaptideUserModel, KeycloakUserModel])
    return db.session.query(user_poly).filter_by(id=user_id).first()
def fetch_yaptide_user_by_username(username: str) -> Optional[YaptideUserModel]:
    """Fetch a ``YaptideUserModel`` by username.

    :param username: username to look up.
    :return: the first matching user, or None if there is none.
    """
    return db.session.query(YaptideUserModel).filter_by(username=username).first()
def fetch_keycloak_user_by_username(username: str) -> Optional[KeycloakUserModel]:
    """Fetch a ``KeycloakUserModel`` by username.

    :param username: username to look up.
    :return: the first matching user, or None if there is none.
    """
    return db.session.query(KeycloakUserModel).filter_by(username=username).first()
def fetch_simulation_by_job_id(job_id: str) -> Optional[Union[BatchSimulationModel, CelerySimulationModel]]:
    """Fetch a simulation of either subtype by its job id.

    ``with_polymorphic`` is applied to ``SimulationModel`` so the columns of
    both ``BatchSimulationModel`` and ``CelerySimulationModel`` are part of the
    query.

    :param job_id: job id of the simulation.
    :return: the first matching simulation, or None if there is none.
    """
    simulation_poly = with_polymorphic(SimulationModel, [BatchSimulationModel, CelerySimulationModel])
    return db.session.query(simulation_poly).filter_by(job_id=job_id).first()
def fetch_simulation_id_by_job_id(job_id: str) -> Optional[int]:
    """Look up the primary key of the simulation with the given job id.

    Only the ``SimulationModel.id`` column is selected, and the base
    ``SimulationModel`` is queried, so this covers both Celery and Batch
    simulations without loading a model instance.

    :param job_id: job id of the simulation.
    :return: the simulation id, or None if no simulation is found.
    """
    row = db.session.query(SimulationModel.id).filter_by(job_id=job_id).first()
    if row is None:
        return None
    return row[0]
def fetch_celery_simulation_by_job_id(job_id: str) -> Optional[CelerySimulationModel]:
    """Fetch a ``CelerySimulationModel`` by job id.

    :param job_id: job id of the simulation.
    :return: the first matching Celery simulation, or None if there is none.
    """
    return db.session.query(CelerySimulationModel).filter_by(job_id=job_id).first()
def fetch_batch_simulation_by_job_id(job_id: str) -> Optional[BatchSimulationModel]:
    """Fetch a ``BatchSimulationModel`` by job id.

    :param job_id: job id of the simulation.
    :return: the first matching Batch simulation, or None if there is none.
    """
    return db.session.query(BatchSimulationModel).filter_by(job_id=job_id).first()
def fetch_simulation_by_sim_id(sim_id: int) -> Optional[Union[BatchSimulationModel, CelerySimulationModel]]:
    """Fetch a simulation of either subtype by primary key.

    ``with_polymorphic`` is applied to ``SimulationModel`` so the columns of
    both ``BatchSimulationModel`` and ``CelerySimulationModel`` are part of the
    query.

    :param sim_id: primary key of the simulation.
    :return: the matching simulation, or None if there is none.
    """
    simulation_poly = with_polymorphic(SimulationModel, [BatchSimulationModel, CelerySimulationModel])
    return db.session.query(simulation_poly).filter_by(id=sim_id).first()
def fetch_simulations_by_user_id(user_id: int) -> list[Union[BatchSimulationModel, CelerySimulationModel]]:
    """Fetch all simulations (of either subtype) owned by a user.

    A single polymorphic query is issued, so the returned list may mix Batch
    and Celery simulations. That is why the element type, not the list type,
    is the Union.

    :param user_id: id of the owning user.
    :return: list of simulations; empty if the user has none.
    """
    simulation_poly = with_polymorphic(SimulationModel, [BatchSimulationModel, CelerySimulationModel])
    return db.session.query(simulation_poly).filter_by(user_id=user_id).all()
def fetch_task_by_sim_id_and_task_id(sim_id: int, task_id: str) -> Optional[Union[BatchTaskModel, CeleryTaskModel]]:
    """Fetch a task of either subtype by simulation id and task id.

    ``with_polymorphic`` is applied to ``TaskModel`` so the columns of both
    ``BatchTaskModel`` and ``CeleryTaskModel`` are part of the query.

    :param sim_id: id of the simulation the task belongs to.
    :param task_id: task id within that simulation.
    :return: the first matching task, or None if there is none.
    """
    task_poly = with_polymorphic(TaskModel, [BatchTaskModel, CeleryTaskModel])
    return db.session.query(task_poly).filter_by(simulation_id=sim_id, task_id=task_id).first()
def fetch_tasks_by_sim_id(sim_id: int) -> list[Union[BatchTaskModel, CeleryTaskModel]]:
    """Fetch all tasks (of either subtype) belonging to a simulation.

    A single polymorphic query is issued, so the returned list may mix Batch
    and Celery tasks. That is why the element type, not the list type, is
    the Union.

    :param sim_id: id of the simulation.
    :return: list of tasks; empty if the simulation has none.
    """
    task_poly = with_polymorphic(TaskModel, [BatchTaskModel, CeleryTaskModel])
    return db.session.query(task_poly).filter_by(simulation_id=sim_id).all()
def fetch_celery_tasks_by_sim_id(sim_id: int) -> list[CeleryTaskModel]:
    """Return every ``CeleryTaskModel`` whose ``simulation_id`` equals ``sim_id``."""
    query = db.session.query(CeleryTaskModel)
    query = query.filter_by(simulation_id=sim_id)
    return query.all()
def fetch_batch_tasks_by_sim_id(sim_id: int) -> list[BatchTaskModel]:
    """Return every ``BatchTaskModel`` whose ``simulation_id`` equals ``sim_id``."""
    query = db.session.query(BatchTaskModel)
    query = query.filter_by(simulation_id=sim_id)
    return query.all()
def fetch_estimators_by_sim_id(sim_id: int) -> list[EstimatorModel]:
    """Return every ``EstimatorModel`` whose ``simulation_id`` equals ``sim_id``."""
    query = db.session.query(EstimatorModel)
    query = query.filter_by(simulation_id=sim_id)
    return query.all()
def fetch_estimator_names_by_job_id(job_id: str) -> Optional[list[str]]:
    """Fetch the names of all estimators of the simulation with the given job id.

    :param job_id: job id of the simulation (a string, like every other
        job id in this module).
    :return: list of estimator names (empty if the simulation has no
        estimators), or None if no simulation is found for ``job_id``.
    """
    simulation_id = fetch_simulation_id_by_job_id(job_id=job_id)
    # Compare with None explicitly: a valid but falsy id (0) must not be
    # mistaken for "simulation not found".
    if simulation_id is None:
        return None
    rows = db.session.query(EstimatorModel.name).filter_by(simulation_id=simulation_id).all()
    # Each row is a one-element tuple holding the name column.
    return [name for (name, ) in rows]
def fetch_estimator_by_sim_id_and_est_name(sim_id: int, est_name: str) -> Optional[EstimatorModel]:
    """Fetch an estimator by simulation id and estimator name.

    :param sim_id: id of the simulation.
    :param est_name: name of the estimator.
    :return: the first matching estimator, or None if there is none.
    """
    return db.session.query(EstimatorModel).filter_by(simulation_id=sim_id, name=est_name).first()
def fetch_pages_by_estimator_id(est_id: int) -> list[PageModel]:
    """Return every ``PageModel`` whose ``estimator_id`` equals ``est_id``."""
    query = db.session.query(PageModel)
    query = query.filter_by(estimator_id=est_id)
    return query.all()
def fetch_page_by_est_id_and_page_number(est_id: int, page_number: int) -> Optional[PageModel]:
    """Fetch a single page by estimator id and page number.

    :param est_id: id of the estimator.
    :param page_number: page number within that estimator.
    :return: the first matching page, or None if there is none.
    """
    return db.session.query(PageModel).filter_by(estimator_id=est_id, page_number=page_number).first()
def fetch_all_clusters() -> list[ClusterModel]:
    """Return every ``ClusterModel`` row, unfiltered."""
    query = db.session.query(ClusterModel)
    return query.all()
def fetch_cluster_by_id(cluster_id: int) -> Optional[ClusterModel]:
    """Fetch a cluster by primary key.

    :param cluster_id: primary key of the cluster.
    :return: the matching cluster, or None if there is none.
    """
    return db.session.query(ClusterModel).filter_by(id=cluster_id).first()
def fetch_input_by_sim_id(sim_id: int) -> Optional[InputModel]:
    """Fetch the input record of a simulation.

    :param sim_id: id of the simulation.
    :return: the first matching input, or None if there is none.
    """
    return db.session.query(InputModel).filter_by(simulation_id=sim_id).first()
def fetch_logfiles_by_sim_id(sim_id: int) -> Optional[LogfilesModel]:
    """Fetch the logfiles record of a simulation.

    :param sim_id: id of the simulation.
    :return: the first matching logfiles record, or None if there is none.
    """
    return db.session.query(LogfilesModel).filter_by(simulation_id=sim_id).first()
def update_task_state(task: Union[BatchTaskModel, CeleryTaskModel], update_dict: dict) -> None:
    """Apply ``update_dict`` through ``task.update_state`` and commit the session.

    The commit is unconditional: whatever ``task.update_state`` returns is
    ignored here.
    """
    task.update_state(update_dict)
    session = db.session
    session.commit()
def update_simulation_state(simulation: Union[BatchSimulationModel, CelerySimulationModel], update_dict: dict) -> None:
    """Apply ``update_dict`` through ``simulation.update_state`` and commit conditionally.

    A truthy return from ``simulation.update_state`` triggers a commit. A falsy
    return logs a warning and skips the commit (presumably meaning nothing was
    changed; confirm against the model implementation).
    """
    was_updated = simulation.update_state(update_dict)
    if not was_updated:
        logging.warning("Simulation state not updated, skipping commit")
        return
    db.session.commit()