Coverage for yaptide/persistence/models.py: 90%
223 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-06-10 10:08 +0000
1# ---------- IMPORTANT ------------
2# Read documentation in persistency.md. It contains information about database development with flask-migrate.
4import gzip
5import json
6from datetime import datetime
8from sqlalchemy import Column, UniqueConstraint
9from sqlalchemy.orm import relationship
10from sqlalchemy.sql.functions import now
11from werkzeug.security import check_password_hash, generate_password_hash
13from yaptide.persistence.database import db
14from yaptide.utils.enums import EntityState, PlatformType
class UserModel(db.Model):
    """Base user model; concrete users discriminate on `auth_provider`."""

    __tablename__ = 'User'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    username: Column[str] = db.Column(db.String, nullable=False)
    auth_provider: Column[str] = db.Column(db.String, nullable=False)
    simulations = relationship("SimulationModel")

    # The same username may exist under different auth providers, but not twice under one
    __table_args__ = (UniqueConstraint('username', 'auth_provider', name='_username_provider_uc'), )

    # Joined-table inheritance: subclasses (YaptideUserModel, KeycloakUserModel)
    # are selected by the value stored in `auth_provider`
    __mapper_args__ = {"polymorphic_identity": "User", "polymorphic_on": auth_provider, "with_polymorphic": "*"}

    def __repr__(self) -> str:
        return f'User #{self.id} {self.username}'
class YaptideUserModel(UserModel, db.Model):
    """Yaptide user model - user authenticated with a locally stored password hash"""

    __tablename__ = 'YaptideUser'
    # Row is removed together with its parent User row
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('User.id', ondelete="CASCADE"), primary_key=True)
    # Werkzeug-generated hash; the plain password is never stored
    password_hash: Column[str] = db.Column(db.String, nullable=False)

    __mapper_args__ = {"polymorphic_identity": "YaptideUser", "polymorphic_load": "inline"}

    def set_password(self, password: str):
        """Sets hashed password"""
        self.password_hash = generate_password_hash(password)

    def check_password(self, password: str) -> bool:
        """Checks password correctness"""
        return check_password_hash(self.password_hash, password)
class KeycloakUserModel(UserModel, db.Model):
    """PLGrid user model - user authenticated via Keycloak, no local password"""

    __tablename__ = 'KeycloakUser'
    # Row is removed together with its parent User row
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('User.id', ondelete="CASCADE"), primary_key=True)
    # presumably the user's grid proxy certificate and key - TODO confirm against usage
    cert: Column[str] = db.Column(db.String, nullable=True)
    private_key: Column[str] = db.Column(db.String, nullable=True)

    __mapper_args__ = {"polymorphic_identity": "KeycloakUser", "polymorphic_load": "inline"}
class ClusterModel(db.Model):
    """Cluster info for specific user"""

    __tablename__ = 'Cluster'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    cluster_name: Column[str] = db.Column(db.String, nullable=False)
    # Batch simulations submitted to this cluster
    simulations = relationship("BatchSimulationModel")
class SimulationModel(db.Model):
    """Base simulation model; concrete simulations discriminate on `platform`."""

    __tablename__ = 'Simulation'

    id: Column[int] = db.Column(db.Integer, primary_key=True)

    job_id: Column[str] = db.Column(db.String, nullable=False, unique=True, doc="Simulation job ID")

    user_id: Column[int] = db.Column(db.Integer, db.ForeignKey('User.id'), doc="User ID")
    start_time: Column[datetime] = db.Column(db.DateTime(timezone=True), default=now(), doc="Submission time")
    end_time: Column[datetime] = db.Column(db.DateTime(timezone=True),
                                           nullable=True,
                                           doc="Job end time (including merging)")
    title: Column[str] = db.Column(db.String, nullable=False, doc="Job title")
    platform: Column[str] = db.Column(db.String, nullable=False, doc="Execution platform name (i.e. 'direct', 'batch')")
    input_type: Column[str] = db.Column(db.String,
                                        nullable=False,
                                        doc="Input type (i.e. 'yaptide_project', 'input_files')")
    sim_type: Column[str] = db.Column(db.String,
                                      nullable=False,
                                      doc="Simulator type (i.e. 'shieldhit', 'topas', 'fluka')")
    job_state: Column[str] = db.Column(db.String,
                                       nullable=False,
                                       default=EntityState.UNKNOWN.value,
                                       doc="Simulation state (i.e. 'pending', 'running', 'completed', 'failed')")

    # Child rows are deleted together with the simulation
    tasks = relationship("TaskModel", cascade="delete")
    estimators = relationship("EstimatorModel", cascade="delete")
    inputs = relationship("InputModel", cascade="delete")
    logfiles = relationship("LogfilesModel", cascade="delete")

    # Joined-table inheritance: CelerySimulationModel / BatchSimulationModel
    # are selected by the value stored in `platform`
    __mapper_args__ = {"polymorphic_identity": "Simulation", "polymorphic_on": platform, "with_polymorphic": "*"}

    def update_state(self, update_dict: dict) -> bool:
        """
        Updating database is more costly than a simple query.
        Therefore we check first if update is needed and
        perform it only for such fields which exists and which have updated values.
        Returns bool value telling if it is required to commit changes to db.
        """
        # Terminal states are final - no further updates are applied
        if self.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value):
            return False
        db_commit_required = False
        if "job_state" in update_dict and self.job_state != update_dict["job_state"]:
            self.job_state = update_dict["job_state"]
            db_commit_required = True
        # Here we have a special case, `end_time` can be set only once
        # therefore we update it only if it not set previously (`self.end_time is None`)
        # and if update was requested (`"end_time" in update_dict`)
        if "end_time" in update_dict and self.end_time is None:
            # a convertion from string to datetime is needed, as in the POST payload end_time comes in string format
            # NOTE(review): format requires microseconds - a payload without them would raise ValueError; confirm callers
            self.end_time = datetime.strptime(update_dict["end_time"], '%Y-%m-%d %H:%M:%S.%f')
            db_commit_required = True
        return db_commit_required
class CelerySimulationModel(SimulationModel):
    """Celery simulation model - simulation executed directly via Celery workers"""

    __tablename__ = 'CelerySimulation'
    # Row is removed together with its parent Simulation row
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id', ondelete="CASCADE"), primary_key=True)
    merge_id: Column[str] = db.Column(db.String, nullable=True, doc="Celery collect job ID")

    __mapper_args__ = {"polymorphic_identity": PlatformType.DIRECT.value, "polymorphic_load": "inline"}
class BatchSimulationModel(SimulationModel):
    """Batch simulation model - simulation submitted to a SLURM-like batch cluster"""

    __tablename__ = 'BatchSimulation'
    # Row is removed together with its parent Simulation row
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id', ondelete="CASCADE"), primary_key=True)
    cluster_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Cluster.id'), nullable=False, doc="Cluster ID")
    job_dir: Column[str] = db.Column(db.String, nullable=True, doc="Simulation folder name")
    # typo fix: "jon" -> "job"
    array_id: Column[int] = db.Column(db.Integer, nullable=True, doc="Batch array job ID")
    collect_id: Column[int] = db.Column(db.Integer, nullable=True, doc="Batch collect job ID")

    __mapper_args__ = {"polymorphic_identity": PlatformType.BATCH.value, "polymorphic_load": "inline"}

    def update_state(self, update_dict: dict) -> bool:
        """
        Used to update fields in BatchSimulation.
        Returns boolean value telling if commit to database is required.
        """
        db_commit_required = super().update_state(update_dict)
        # Apply only the fields that are present in the payload and actually
        # changed, to avoid needless database writes.
        for field in ("job_dir", "array_id", "collect_id"):
            if field in update_dict and getattr(self, field) != update_dict[field]:
                setattr(self, field, update_dict[field])
                db_commit_required = True
        return db_commit_required
def allowed_state_change(current_state: str, next_state: str) -> bool:
    """
    Ensures that no such change like Completed -> Canceled happens.

    Returns False only when the current state is terminal (failed/completed)
    and the requested next state is 'canceled'.

    Bug fix: the original compared the plain string `next_state` against the
    enum member itself (`next_state in [EntityState.CANCELED]`), while every
    other state comparison in this module uses `.value`; unless EntityState is
    a str-enum that membership test could never match.
    """
    terminal_states = (EntityState.FAILED.value, EntityState.COMPLETED.value)
    return not (current_state in terminal_states and next_state == EntityState.CANCELED.value)
172def value_changed(current_value: str, new_value: str):
173 """checks if value from update_dict differs from object in database"""
174 return new_value and current_value != new_value
class TaskModel(db.Model):
    """Simulation task model - one of many work units belonging to a simulation."""

    __tablename__ = 'Task'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer,
                                           db.ForeignKey('Simulation.id', ondelete="CASCADE"),
                                           doc="Simulation job ID (foreign key)")

    task_id: Column[int] = db.Column(db.Integer, nullable=False, doc="Task ID")
    requested_primaries: Column[int] = db.Column(db.Integer,
                                                 nullable=False,
                                                 default=0,
                                                 doc="Requested number of primaries")
    simulated_primaries: Column[int] = db.Column(db.Integer,
                                                 nullable=False,
                                                 default=0,
                                                 doc="Simulated number of primaries")
    task_state: Column[str] = db.Column(db.String,
                                        nullable=False,
                                        default=EntityState.PENDING.value,
                                        doc="Task state (i.e. 'pending', 'running', 'completed', 'failed')")
    estimated_time: Column[int] = db.Column(db.Integer, nullable=True, doc="Estimated time in seconds")
    start_time: Column[datetime] = db.Column(db.DateTime(timezone=True), nullable=True, doc="Task start time")
    end_time: Column[datetime] = db.Column(db.DateTime(timezone=True), nullable=True, doc="Task end time")
    platform: Column[str] = db.Column(db.String, nullable=False, doc="Execution platform name (i.e. 'direct', 'batch')")
    last_update_time: Column[datetime] = db.Column(db.DateTime(timezone=True),
                                                   default=now(),
                                                   doc="Task last update time")

    # A task number may not repeat within one simulation
    __table_args__ = (UniqueConstraint('simulation_id', 'task_id', name='_simulation_id_task_id_uc'), )

    # Joined-table inheritance: CeleryTaskModel / BatchTaskModel selected by `platform`
    __mapper_args__ = {"polymorphic_identity": "Task", "polymorphic_on": platform, "with_polymorphic": "*"}

    def update_state(self, update_dict: dict):  # skipcq: PY-R1000
        """
        Updating database is more costly than a simple query.
        Therefore we check first if update is needed and
        perform it only for such fields which exists and which have updated values.
        """
        # Terminal states are final - no further updates are applied
        if self.task_state in (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value):
            return
        if value_changed(self.requested_primaries, update_dict.get("requested_primaries")):
            self.requested_primaries = update_dict["requested_primaries"]
        if value_changed(self.simulated_primaries, update_dict.get("simulated_primaries")):
            self.simulated_primaries = update_dict["simulated_primaries"]
        if value_changed(self.task_state, update_dict.get("task_state")) and allowed_state_change(
                self.task_state, update_dict["task_state"]):
            self.task_state = update_dict["task_state"]
            # A completed task is assumed to have simulated everything requested
            if self.task_state == EntityState.COMPLETED.value:
                self.simulated_primaries = self.requested_primaries
        # Here we have a special case, `estimated_time` cannot be set when `end_time` is set - it is meaningless
        have_estim_time = "estimated_time" in update_dict and self.estimated_time != update_dict["estimated_time"]
        end_time_not_set = self.end_time is None
        if have_estim_time and end_time_not_set:
            self.estimated_time = update_dict["estimated_time"]
        if "start_time" in update_dict and self.start_time is None:
            # a convertion from string to datetime is needed, as in the POST payload start_time comes in string format
            self.start_time = datetime.strptime(update_dict["start_time"], '%Y-%m-%d %H:%M:%S.%f')
        # Here we have a special case, `end_time` can be set only once
        # therefore we update it only if it not set previously (`self.end_time is None`)
        # and if update was requested (`"end_time" in update_dict`)
        if "end_time" in update_dict and self.end_time is None:
            # a convertion from string to datetime is needed, as in the POST payload end_time comes in string format
            self.end_time = datetime.strptime(update_dict["end_time"], '%Y-%m-%d %H:%M:%S.%f')
            # Once finished, any remaining time estimate is meaningless
            self.estimated_time = None
        self.last_update_time = now()

    def get_status_dict(self) -> dict:
        """Returns task information as a dictionary"""
        result = {
            "task_state": self.task_state,
            "requested_primaries": self.requested_primaries,
            "simulated_primaries": self.simulated_primaries,
            "last_update_time": self.last_update_time,
            "task_id": self.id
        }
        # Optional fields are included only when present
        if self.estimated_time:
            # Split the estimate (stored in seconds) into h/m/s components
            result["estimated_time"] = {
                "hours": self.estimated_time // 3600,
                "minutes": (self.estimated_time // 60) % 60,
                "seconds": self.estimated_time % 60,
            }
        if self.start_time:
            result["start_time"] = self.start_time
        if self.end_time:
            result["end_time"] = self.end_time
        return result
class CeleryTaskModel(TaskModel):
    """Celery task model - task executed directly by a Celery worker"""

    __tablename__ = 'CeleryTask'
    # Row is removed together with its parent Task row
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Task.id', ondelete="CASCADE"), primary_key=True)
    celery_id: Column[str] = db.Column(db.String, nullable=False, default="", doc="Celery task ID")

    def update_state(self, update_dict: dict):
        """Update method for CeleryTaskModel - also tracks the Celery task ID"""
        if "celery_id" in update_dict and self.celery_id != update_dict["celery_id"]:
            self.celery_id = update_dict["celery_id"]
        return super().update_state(update_dict)

    __mapper_args__ = {"polymorphic_identity": PlatformType.DIRECT.value, "polymorphic_load": "inline"}
class BatchTaskModel(TaskModel):
    """Batch task model - task executed as part of a batch array job; no extra fields"""

    __tablename__ = 'BatchTask'
    # Row is removed together with its parent Task row
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Task.id', ondelete="CASCADE"), primary_key=True)

    __mapper_args__ = {"polymorphic_identity": PlatformType.BATCH.value, "polymorphic_load": "inline"}
def decompress(data: bytes):
    """Gunzip `data` and deserialize the contained JSON; None input yields None."""
    if data is None:
        # mirrors json.loads('null') for a missing payload
        return None
    payload = gzip.decompress(data).decode('utf-8')
    return json.loads(payload)
303def compress(data) -> bytes:
304 """Serializes JSON and compresses data"""
305 compressed_bytes = b''
306 if data is not None:
307 # Serialize the JSON
308 serialized_data: str = json.dumps(data)
309 # Compress the data
310 bytes_to_compress: bytes = serialized_data.encode('utf-8')
311 compressed_bytes = gzip.compress(bytes_to_compress)
312 return compressed_bytes
class InputModel(db.Model):
    """Simulation inputs model - gzip-compressed JSON blob of simulation input"""

    __tablename__ = 'Input'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id', ondelete="CASCADE"))
    compressed_data: Column[bytes] = db.Column(db.LargeBinary)

    @property
    def data(self):
        """Decompressed JSON payload stored in `compressed_data`."""
        return decompress(self.compressed_data)

    @data.setter
    def data(self, value):
        # Setting None is a no-op: the stored compressed_data is kept unchanged
        if value is not None:
            self.compressed_data = compress(value)
class EstimatorModel(db.Model):
    """Simulation single estimator model"""

    __tablename__ = 'Estimator'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer,
                                           db.ForeignKey('Simulation.id', ondelete="CASCADE"),
                                           nullable=False)
    name: Column[str] = db.Column(db.String, nullable=False, doc="Human readable estimator name")
    file_name: Column[str] = db.Column(db.String,
                                       nullable=False,
                                       doc="Estimator name extracted from file generated by simulator")
    compressed_data: Column[bytes] = db.Column(db.LargeBinary, doc="Estimator metadata")
    # Pages are deleted together with their estimator
    pages = relationship("PageModel", cascade="delete")

    @property
    def data(self):
        """Decompressed JSON payload stored in `compressed_data`."""
        return decompress(self.compressed_data)

    @data.setter
    def data(self, value):
        # Setting None is a no-op: the stored compressed_data is kept unchanged
        if value is not None:
            self.compressed_data = compress(value)
class PageModel(db.Model):
    """Estimator single page model"""

    __tablename__ = 'Page'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    page_name: Column[str] = db.Column(db.String, nullable=False, doc="Page name")
    estimator_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Estimator.id', ondelete="CASCADE"), nullable=False)
    page_number: Column[int] = db.Column(db.Integer, nullable=False, doc="Page number")
    compressed_data: Column[bytes] = db.Column(db.LargeBinary, doc="Page json object - data, axes and metadata")
    page_dimension: Column[int] = db.Column(db.Integer, nullable=False, doc="Dimension of data")

    @property
    def data(self):
        """Decompressed JSON payload stored in `compressed_data`."""
        return decompress(self.compressed_data)

    @data.setter
    def data(self, value):
        # Setting None is a no-op: the stored compressed_data is kept unchanged
        if value is not None:
            self.compressed_data = compress(value)
class LogfilesModel(db.Model):
    """Simulation logfiles model"""

    __tablename__ = 'Logfiles'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer,
                                           db.ForeignKey('Simulation.id', ondelete="CASCADE"),
                                           nullable=False)
    compressed_data: Column[bytes] = db.Column(db.LargeBinary, doc="Json object containing logfiles")

    @property
    def data(self):
        """Decompressed JSON payload stored in `compressed_data`."""
        return decompress(self.compressed_data)

    @data.setter
    def data(self, value):
        # Setting None is a no-op: the stored compressed_data is kept unchanged
        if value is not None:
            self.compressed_data = compress(value)
def create_all():
    """Creates all tables, to be used with Flask app context."""
    db.create_all()