Coverage for yaptide/persistence/models.py: 90%
223 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-04 00:31 +0000
1# ---------- IMPORTANT ------------
2# Read documentation in persistency.md. It contains information about database development with flask-migrate.
4import gzip
5import json
6from datetime import datetime
8from sqlalchemy import Column, UniqueConstraint
9from sqlalchemy.orm import relationship
10from sqlalchemy.sql.functions import now
11from werkzeug.security import check_password_hash, generate_password_hash
13from yaptide.persistence.database import db
14from yaptide.utils.enums import EntityState, PlatformType
class UserModel(db.Model):
    """Base user model; concrete users are distinguished by `auth_provider`."""

    __tablename__ = 'User'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    username: Column[str] = db.Column(db.String, nullable=False)
    # discriminator column for the polymorphic subclasses below
    auth_provider: Column[str] = db.Column(db.String, nullable=False)
    simulations = relationship("SimulationModel")

    # a user is unique per (username, auth_provider) pair, not per username alone
    __table_args__ = (UniqueConstraint('username', 'auth_provider', name='_username_provider_uc'), )

    # joined-table inheritance: `auth_provider` selects the concrete subclass
    __mapper_args__ = {"polymorphic_identity": "User", "polymorphic_on": auth_provider, "with_polymorphic": "*"}

    def __repr__(self) -> str:
        return f'User #{self.id} {self.username}'
class YaptideUserModel(UserModel, db.Model):
    """User authenticated locally by yaptide with a password hash."""

    __tablename__ = 'YaptideUser'
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('User.id', ondelete="CASCADE"), primary_key=True)
    # only the werkzeug-generated hash is stored, never the plain-text password
    password_hash: Column[str] = db.Column(db.String, nullable=False)

    __mapper_args__ = {"polymorphic_identity": "YaptideUser", "polymorphic_load": "inline"}

    def set_password(self, password: str):
        """Hashes the given plain-text password and stores the hash."""
        self.password_hash = generate_password_hash(password)

    def check_password(self, password: str) -> bool:
        """Checks the given plain-text password against the stored hash."""
        return check_password_hash(self.password_hash, password)
class KeycloakUserModel(UserModel, db.Model):
    """PLGrid user model (authenticated externally via Keycloak)."""

    __tablename__ = 'KeycloakUser'
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('User.id', ondelete="CASCADE"), primary_key=True)
    # NOTE(review): presumably a grid proxy certificate and its key used for
    # cluster access - confirm against the code that populates these fields
    cert: Column[str] = db.Column(db.String, nullable=True)
    private_key: Column[str] = db.Column(db.String, nullable=True)

    __mapper_args__ = {"polymorphic_identity": "KeycloakUser", "polymorphic_load": "inline"}
class ClusterModel(db.Model):
    """Cluster info for specific user"""

    __tablename__ = 'Cluster'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    cluster_name: Column[str] = db.Column(db.String, nullable=False)
    # batch simulations submitted to this cluster (see BatchSimulationModel.cluster_id)
    simulations = relationship("BatchSimulationModel")
class SimulationModel(db.Model):
    """Base simulation model; concrete subclasses are selected by the `platform` column."""

    __tablename__ = 'Simulation'

    id: Column[int] = db.Column(db.Integer, primary_key=True)

    job_id: Column[str] = db.Column(db.String, nullable=False, unique=True, doc="Simulation job ID")

    user_id: Column[int] = db.Column(db.Integer, db.ForeignKey('User.id'), doc="User ID")
    # default=now() is a SQL function expression, so the timestamp is produced at insert time
    start_time: Column[datetime] = db.Column(db.DateTime(timezone=True), default=now(), doc="Submission time")
    end_time: Column[datetime] = db.Column(db.DateTime(timezone=True),
                                           nullable=True,
                                           doc="Job end time (including merging)")
    title: Column[str] = db.Column(db.String, nullable=False, doc="Job title")
    platform: Column[str] = db.Column(db.String, nullable=False, doc="Execution platform name (i.e. 'direct', 'batch')")
    input_type: Column[str] = db.Column(db.String,
                                        nullable=False,
                                        doc="Input type (i.e. 'yaptide_project', 'input_files')")
    sim_type: Column[str] = db.Column(db.String,
                                      nullable=False,
                                      doc="Simulator type (i.e. 'shieldhit', 'topas', 'fluka')")
    job_state: Column[str] = db.Column(db.String,
                                       nullable=False,
                                       default=EntityState.UNKNOWN.value,
                                       doc="Simulation state (i.e. 'pending', 'running', 'completed', 'failed')")

    # child rows are removed together with the simulation
    tasks = relationship("TaskModel", cascade="delete")
    estimators = relationship("EstimatorModel", cascade="delete")
    inputs = relationship("InputModel", cascade="delete")
    logfiles = relationship("LogfilesModel", cascade="delete")

    # joined-table inheritance: `platform` selects the concrete subclass
    __mapper_args__ = {"polymorphic_identity": "Simulation", "polymorphic_on": platform, "with_polymorphic": "*"}

    def update_state(self, update_dict: dict) -> bool:
        """
        Applies fields from `update_dict` onto this simulation.

        Updating database is more costly than a simple query.
        Therefore we check first if update is needed and
        perform it only for such fields which exists and which have updated values.
        Returns bool value telling if it is required to commit changes to db.
        """
        # terminal states are final: ignore any further updates
        if self.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value):
            return False
        db_commit_required = False
        if "job_state" in update_dict and self.job_state != update_dict["job_state"]:
            self.job_state = update_dict["job_state"]
            db_commit_required = True
        # Here we have a special case, `end_time` can be set only once
        # therefore we update it only if it not set previously (`self.end_time is None`)
        # and if update was requested (`"end_time" in update_dict`)
        if "end_time" in update_dict and self.end_time is None:
            # a conversion from string to datetime is needed, as in the POST payload end_time comes in string format
            self.end_time = datetime.strptime(update_dict["end_time"], '%Y-%m-%d %H:%M:%S.%f')
            db_commit_required = True
        return db_commit_required
class CelerySimulationModel(SimulationModel):
    """Simulation executed directly via Celery workers."""

    __tablename__ = 'CelerySimulation'
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id', ondelete="CASCADE"), primary_key=True)
    merge_id: Column[str] = db.Column(db.String, nullable=True, doc="Celery collect job ID")

    __mapper_args__ = {"polymorphic_identity": PlatformType.DIRECT.value, "polymorphic_load": "inline"}
class BatchSimulationModel(SimulationModel):
    """Simulation executed through a batch scheduler on a remote cluster."""

    __tablename__ = 'BatchSimulation'
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id', ondelete="CASCADE"), primary_key=True)
    cluster_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Cluster.id'), nullable=False, doc="Cluster ID")
    job_dir: Column[str] = db.Column(db.String, nullable=True, doc="Simulation folder name")
    array_id: Column[int] = db.Column(db.Integer, nullable=True, doc="Batch array job ID")
    collect_id: Column[int] = db.Column(db.Integer, nullable=True, doc="Batch collect job ID")

    __mapper_args__ = {"polymorphic_identity": PlatformType.BATCH.value, "polymorphic_load": "inline"}

    def update_state(self, update_dict):
        """Used to update fields in BatchSimulation. Returns boolean value if commit to database is required"""
        db_commit_required = super().update_state(update_dict)
        # apply each batch-specific field only when it is present and actually changed
        for field in ("job_dir", "array_id", "collect_id"):
            if field in update_dict and getattr(self, field) != update_dict[field]:
                setattr(self, field, update_dict[field])
                db_commit_required = True
        return db_commit_required
def allowed_state_change(current_state: str, next_state: str) -> bool:
    """Ensures that no such change like Completed -> Canceled happens.

    Once an entity reached a terminal FAILED or COMPLETED state it must not
    be moved to CANCELED anymore.
    """
    # `next_state` is a plain string, so it must be compared against the enum
    # *value*; the previous `next_state in [EntityState.CANCELED]` compared a
    # string with an Enum member and could never match, making the guard a no-op.
    return not (current_state in (EntityState.FAILED.value, EntityState.COMPLETED.value)
                and next_state == EntityState.CANCELED.value)
def value_changed(current_value, new_value) -> bool:
    """Checks if value from update_dict differs from object in database.

    Returns False when `new_value` is falsy (None, 0, "" ...): callers use
    `update_dict.get(...)`, so a missing key arrives here as None and must be
    treated as "no update". Note this also means an explicit update to 0 or ""
    is ignored. Coerced to a real bool (previously a falsy `new_value` leaked
    out as the return value itself, e.g. None).
    """
    return bool(new_value) and current_value != new_value
class TaskModel(db.Model):
    """Simulation task model; tasks belong to a simulation and are unique per (simulation_id, task_id)."""

    __tablename__ = 'Task'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer,
                                           db.ForeignKey('Simulation.id', ondelete="CASCADE"),
                                           doc="Simulation job ID (foreign key)")

    task_id: Column[int] = db.Column(db.Integer, nullable=False, doc="Task ID")
    requested_primaries: Column[int] = db.Column(db.Integer,
                                                 nullable=False,
                                                 default=0,
                                                 doc="Requested number of primaries")
    simulated_primaries: Column[int] = db.Column(db.Integer,
                                                 nullable=False,
                                                 default=0,
                                                 doc="Simulated number of primaries")
    task_state: Column[str] = db.Column(db.String,
                                        nullable=False,
                                        default=EntityState.PENDING.value,
                                        doc="Task state (i.e. 'pending', 'running', 'completed', 'failed')")
    estimated_time: Column[int] = db.Column(db.Integer, nullable=True, doc="Estimated time in seconds")
    start_time: Column[datetime] = db.Column(db.DateTime(timezone=True), nullable=True, doc="Task start time")
    end_time: Column[datetime] = db.Column(db.DateTime(timezone=True), nullable=True, doc="Task end time")
    platform: Column[str] = db.Column(db.String, nullable=False, doc="Execution platform name (i.e. 'direct', 'batch')")
    last_update_time: Column[datetime] = db.Column(db.DateTime(timezone=True),
                                                   default=now(),
                                                   doc="Task last update time")

    # a task number is unique within one simulation, not globally
    __table_args__ = (UniqueConstraint('simulation_id', 'task_id', name='_simulation_id_task_id_uc'), )

    # joined-table inheritance: `platform` selects the concrete subclass
    __mapper_args__ = {"polymorphic_identity": "Task", "polymorphic_on": platform, "with_polymorphic": "*"}

    def update_state(self, update_dict: dict):  # skipcq: PY-R1000
        """
        Applies fields from `update_dict` onto this task.

        Updating database is more costly than a simple query.
        Therefore we check first if update is needed and
        perform it only for such fields which exists and which have updated values.
        """
        # terminal states are final: ignore any further updates
        if self.task_state in (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value):
            return
        if value_changed(self.requested_primaries, update_dict.get("requested_primaries")):
            self.requested_primaries = update_dict["requested_primaries"]
        if value_changed(self.simulated_primaries, update_dict.get("simulated_primaries")):
            self.simulated_primaries = update_dict["simulated_primaries"]
        if value_changed(self.task_state, update_dict.get("task_state")) and allowed_state_change(
                self.task_state, update_dict["task_state"]):
            self.task_state = update_dict["task_state"]
            # a completed task is assumed to have simulated all requested primaries
            if self.task_state == EntityState.COMPLETED.value:
                self.simulated_primaries = self.requested_primaries
        # Here we have a special case, `estimated_time` cannot be set when `end_time` is set - it is meaningless
        have_estim_time = "estimated_time" in update_dict and self.estimated_time != update_dict["estimated_time"]
        end_time_not_set = self.end_time is None
        if have_estim_time and end_time_not_set:
            self.estimated_time = update_dict["estimated_time"]
        if "start_time" in update_dict and self.start_time is None:
            # a conversion from string to datetime is needed, as in the POST payload start_time comes in string format
            self.start_time = datetime.strptime(update_dict["start_time"], '%Y-%m-%d %H:%M:%S.%f')
        # Here we have a special case, `end_time` can be set only once
        # therefore we update it only if it not set previously (`self.end_time is None`)
        # and if update was requested (`"end_time" in update_dict`)
        if "end_time" in update_dict and self.end_time is None:
            # a conversion from string to datetime is needed, as in the POST payload end_time comes in string format
            self.end_time = datetime.strptime(update_dict["end_time"], '%Y-%m-%d %H:%M:%S.%f')
            # a finished task has no remaining estimated time
            self.estimated_time = None
        self.last_update_time = now()

    def get_status_dict(self) -> dict:
        """Returns task information as a dictionary, omitting fields which are not set."""
        result = {
            "task_state": self.task_state,
            "requested_primaries": self.requested_primaries,
            "simulated_primaries": self.simulated_primaries,
            "last_update_time": self.last_update_time,
        }
        # split the estimate (stored as seconds) into hours/minutes/seconds for presentation
        if self.estimated_time:
            result["estimated_time"] = {
                "hours": self.estimated_time // 3600,
                "minutes": (self.estimated_time // 60) % 60,
                "seconds": self.estimated_time % 60,
            }
        if self.start_time:
            result["start_time"] = self.start_time
        if self.end_time:
            result["end_time"] = self.end_time
        return result
class CeleryTaskModel(TaskModel):
    """Task executed directly via Celery."""

    __tablename__ = 'CeleryTask'
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Task.id', ondelete="CASCADE"), primary_key=True)
    celery_id: Column[str] = db.Column(db.String, nullable=False, default="", doc="Celery task ID")

    def update_state(self, update_dict: dict):
        """Stores the Celery task ID when changed, then applies the common task update."""
        if "celery_id" in update_dict and self.celery_id != update_dict["celery_id"]:
            self.celery_id = update_dict["celery_id"]
        return super().update_state(update_dict)

    __mapper_args__ = {"polymorphic_identity": PlatformType.DIRECT.value, "polymorphic_load": "inline"}
class BatchTaskModel(TaskModel):
    """Task executed through a batch scheduler; adds no columns beyond TaskModel."""

    __tablename__ = 'BatchTask'
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Task.id', ondelete="CASCADE"), primary_key=True)

    __mapper_args__ = {"polymorphic_identity": PlatformType.BATCH.value, "polymorphic_load": "inline"}
def decompress(data: bytes):
    """Gunzips `data` and parses the contained JSON document.

    A `None` input yields `None` (equivalent to the JSON literal `null`).
    """
    if data is None:
        return json.loads('null')
    decoded_json: str = gzip.decompress(data).decode('utf-8')
    return json.loads(decoded_json)
def compress(data) -> bytes:
    """Serializes `data` to JSON and gzips the result.

    A `None` input yields empty bytes (nothing to store).
    """
    if data is None:
        return b''
    json_text: str = json.dumps(data)
    return gzip.compress(json_text.encode('utf-8'))
class InputModel(db.Model):
    """Simulation inputs model"""

    __tablename__ = 'Input'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id', ondelete="CASCADE"))
    # gzip-compressed JSON payload; accessed through the `data` property below
    compressed_data: Column[bytes] = db.Column(db.LargeBinary)

    @property
    def data(self):
        """Transparently decompressed JSON content (None when nothing is stored)."""
        return decompress(self.compressed_data)

    @data.setter
    def data(self, value):
        # assigning None is a no-op: previously stored data is kept
        if value is not None:
            self.compressed_data = compress(value)
class EstimatorModel(db.Model):
    """Simulation single estimator model"""

    __tablename__ = 'Estimator'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer,
                                           db.ForeignKey('Simulation.id', ondelete="CASCADE"),
                                           nullable=False)
    name: Column[str] = db.Column(db.String, nullable=False, doc="Human readable estimator name")
    file_name: Column[str] = db.Column(db.String,
                                       nullable=False,
                                       doc="Estimator name extracted from file generated by simulator")
    # gzip-compressed JSON; accessed through the `data` property below
    compressed_data: Column[bytes] = db.Column(db.LargeBinary, doc="Estimator metadata")
    pages = relationship("PageModel", cascade="delete")

    @property
    def data(self):
        """Transparently decompressed JSON content (None when nothing is stored)."""
        return decompress(self.compressed_data)

    @data.setter
    def data(self, value):
        # assigning None is a no-op: previously stored data is kept
        if value is not None:
            self.compressed_data = compress(value)
class PageModel(db.Model):
    """Estimator single page model"""

    __tablename__ = 'Page'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    page_name: Column[str] = db.Column(db.String, nullable=False, doc="Page name")
    estimator_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Estimator.id', ondelete="CASCADE"), nullable=False)
    page_number: Column[int] = db.Column(db.Integer, nullable=False, doc="Page number")
    # gzip-compressed JSON; accessed through the `data` property below
    compressed_data: Column[bytes] = db.Column(db.LargeBinary, doc="Page json object - data, axes and metadata")
    page_dimension: Column[int] = db.Column(db.Integer, nullable=False, doc="Dimension of data")

    @property
    def data(self):
        """Transparently decompressed JSON content (None when nothing is stored)."""
        return decompress(self.compressed_data)

    @data.setter
    def data(self, value):
        # assigning None is a no-op: previously stored data is kept
        if value is not None:
            self.compressed_data = compress(value)
class LogfilesModel(db.Model):
    """Simulation logfiles model"""

    __tablename__ = 'Logfiles'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer,
                                           db.ForeignKey('Simulation.id', ondelete="CASCADE"),
                                           nullable=False)
    # gzip-compressed JSON; accessed through the `data` property below
    compressed_data: Column[bytes] = db.Column(db.LargeBinary, doc="Json object containing logfiles")

    @property
    def data(self):
        """Transparently decompressed JSON content (None when nothing is stored)."""
        return decompress(self.compressed_data)

    @data.setter
    def data(self, value):
        # assigning None is a no-op: previously stored data is kept
        if value is not None:
            self.compressed_data = compress(value)
def create_all():
    """Creates all tables defined by the models above; to be used within a Flask app context."""
    db.create_all()