Coverage for yaptide/persistence/models.py: 90%
220 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-22 07:31 +0000
1# ---------- IMPORTANT ------------
2# Read documentation in persistency.md. It contains information about database development with flask-migrate.
4import gzip
5import json
6from datetime import datetime
8from sqlalchemy import Column, UniqueConstraint
9from sqlalchemy.orm import relationship
10from sqlalchemy.sql.functions import now
11from werkzeug.security import check_password_hash, generate_password_hash
13from yaptide.persistence.database import db
14from yaptide.utils.enums import EntityState, PlatformType
class UserModel(db.Model):
    """
    Base user model.

    Root of a joined-table inheritance hierarchy: `auth_provider` is the
    polymorphic discriminator, so each row is materialized as the matching
    subclass (e.g. YaptideUser, KeycloakUser).
    """

    __tablename__ = 'User'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    username: Column[str] = db.Column(db.String, nullable=False)
    # Discriminator column: holds the concrete subclass identity.
    auth_provider: Column[str] = db.Column(db.String, nullable=False)
    simulations = relationship("SimulationModel")

    # Same username may exist under different providers, but only once per provider.
    __table_args__ = (UniqueConstraint('username', 'auth_provider', name='_username_provider_uc'), )

    __mapper_args__ = {"polymorphic_identity": "User", "polymorphic_on": auth_provider, "with_polymorphic": "*"}

    def __repr__(self) -> str:
        return f'User #{self.id} {self.username}'
class YaptideUserModel(UserModel, db.Model):
    """Yaptide user model — local accounts authenticated by password hash."""

    __tablename__ = 'YaptideUser'
    # Shares the primary key with the base User row (joined-table inheritance).
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('User.id', ondelete="CASCADE"), primary_key=True)
    # Only the werkzeug hash is stored, never the plaintext password.
    password_hash: Column[str] = db.Column(db.String, nullable=False)

    __mapper_args__ = {"polymorphic_identity": "YaptideUser", "polymorphic_load": "inline"}

    def set_password(self, password: str):
        """Sets hashed password"""
        self.password_hash = generate_password_hash(password)

    def check_password(self, password: str) -> bool:
        """Checks password correctness"""
        return check_password_hash(self.password_hash, password)
class KeycloakUserModel(UserModel, db.Model):
    """PLGrid user model — users authenticated externally via Keycloak."""

    __tablename__ = 'KeycloakUser'
    # Shares the primary key with the base User row (joined-table inheritance).
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('User.id', ondelete="CASCADE"), primary_key=True)
    # Optional credential material; presumably used for cluster access — TODO confirm with callers.
    cert: Column[str] = db.Column(db.String, nullable=True)
    private_key: Column[str] = db.Column(db.String, nullable=True)

    __mapper_args__ = {"polymorphic_identity": "KeycloakUser", "polymorphic_load": "inline"}
class ClusterModel(db.Model):
    """Cluster info for specific user"""

    # NOTE(review): no user foreign key is visible here, so the per-user claim in the
    # docstring above could not be verified from this file — confirm against callers.
    __tablename__ = 'Cluster'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    cluster_name: Column[str] = db.Column(db.String, nullable=False)
    simulations = relationship("BatchSimulationModel")
class SimulationModel(db.Model):
    """
    Simulation model.

    Root of a joined-table inheritance hierarchy keyed on `platform`
    (e.g. CelerySimulation for 'direct', BatchSimulation for 'batch').
    """

    __tablename__ = 'Simulation'

    id: Column[int] = db.Column(db.Integer, primary_key=True)

    job_id: Column[str] = db.Column(db.String, nullable=False, unique=True, doc="Simulation job ID")

    user_id: Column[int] = db.Column(db.Integer, db.ForeignKey('User.id'), doc="User ID")
    start_time: Column[datetime] = db.Column(db.DateTime(timezone=True), default=now(), doc="Submission time")
    end_time: Column[datetime] = db.Column(db.DateTime(timezone=True),
                                           nullable=True,
                                           doc="Job end time (including merging)")
    title: Column[str] = db.Column(db.String, nullable=False, doc="Job title")
    platform: Column[str] = db.Column(db.String, nullable=False, doc="Execution platform name (i.e. 'direct', 'batch')")
    input_type: Column[str] = db.Column(db.String,
                                        nullable=False,
                                        doc="Input type (i.e. 'yaptide_project', 'input_files')")
    sim_type: Column[str] = db.Column(db.String,
                                      nullable=False,
                                      doc="Simulator type (i.e. 'shieldhit', 'topas', 'fluka')")
    job_state: Column[str] = db.Column(db.String,
                                       nullable=False,
                                       default=EntityState.UNKNOWN.value,
                                       doc="Simulation state (i.e. 'pending', 'running', 'completed', 'failed')")

    # Children are removed together with the simulation.
    tasks = relationship("TaskModel", cascade="delete")
    estimators = relationship("EstimatorModel", cascade="delete")
    inputs = relationship("InputModel", cascade="delete")
    logfiles = relationship("LogfilesModel", cascade="delete")

    __mapper_args__ = {"polymorphic_identity": "Simulation", "polymorphic_on": platform, "with_polymorphic": "*"}

    def update_state(self, update_dict: dict) -> bool:
        """
        Updating database is more costly than a simple query.
        Therefore we check first if update is needed and
        perform it only for such fields which exists and which have updated values.
        Returns bool value telling if it is required to commit changes to db.
        """
        # Terminal states are frozen: no further updates once completed/failed/canceled.
        if self.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value):
            return False
        db_commit_required = False
        if "job_state" in update_dict and self.job_state != update_dict["job_state"]:
            self.job_state = update_dict["job_state"]
            db_commit_required = True
        # Here we have a special case, `end_time` can be set only once
        # therefore we update it only if it not set previously (`self.end_time is None`)
        # and if update was requested (`"end_time" in update_dict`)
        if "end_time" in update_dict and self.end_time is None:
            # a convertion from string to datetime is needed, as in the POST payload end_time comes in string format
            self.end_time = datetime.strptime(update_dict["end_time"], '%Y-%m-%d %H:%M:%S.%f')
            db_commit_required = True
        return db_commit_required
class CelerySimulationModel(SimulationModel):
    """Celery simulation model — simulations executed on the 'direct' platform."""

    __tablename__ = 'CelerySimulation'
    # Shares the primary key with the base Simulation row (joined-table inheritance).
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id', ondelete="CASCADE"), primary_key=True)
    merge_id: Column[str] = db.Column(db.String, nullable=True, doc="Celery collect job ID")

    __mapper_args__ = {"polymorphic_identity": PlatformType.DIRECT.value, "polymorphic_load": "inline"}
class BatchSimulationModel(SimulationModel):
    """Batch simulation model — simulations executed on the 'batch' platform (cluster jobs)."""

    __tablename__ = 'BatchSimulation'
    # Shares the primary key with the base Simulation row (joined-table inheritance).
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id', ondelete="CASCADE"), primary_key=True)
    cluster_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Cluster.id'), nullable=False, doc="Cluster ID")
    job_dir: Column[str] = db.Column(db.String, nullable=True, doc="Simulation folder name")
    # typo fix in column doc: "jon ID" -> "job ID"
    array_id: Column[int] = db.Column(db.Integer, nullable=True, doc="Batch array job ID")
    collect_id: Column[int] = db.Column(db.Integer, nullable=True, doc="Batch collect job ID")

    __mapper_args__ = {"polymorphic_identity": PlatformType.BATCH.value, "polymorphic_load": "inline"}

    def update_state(self, update_dict):
        """
        Used to update fields in BatchSimulation on top of the base-class update.
        Returns boolean value telling if commit to database is required.
        """
        # The base class handles job_state/end_time and the terminal-state short-circuit.
        db_commit_required = super().update_state(update_dict)
        if "job_dir" in update_dict and self.job_dir != update_dict["job_dir"]:
            self.job_dir = update_dict["job_dir"]
            db_commit_required = True
        if "array_id" in update_dict and self.array_id != update_dict["array_id"]:
            self.array_id = update_dict["array_id"]
            db_commit_required = True
        if "collect_id" in update_dict and self.collect_id != update_dict["collect_id"]:
            self.collect_id = update_dict["collect_id"]
            db_commit_required = True
        return db_commit_required
def allowed_state_change(current_state: str, next_state: str):
    """
    Ensures that no such change like Completed -> Canceled happens.

    Returns True when the transition from `current_state` to `next_state` is permitted.
    Both arguments are plain state strings (EntityState values).
    """
    # Bug fix: the original compared `next_state` (a string) against the enum
    # *member* EntityState.CANCELED instead of its `.value`, so the membership
    # test was always False and the guard never blocked Failed/Completed -> Canceled.
    return not (current_state in [EntityState.FAILED.value, EntityState.COMPLETED.value]
                and next_state in [EntityState.CANCELED.value])
def value_changed(current_value: str, new_value: str):
    """
    Checks whether a value coming from an update payload differs from the one
    stored on the database object. Falsy new values (None, 0, "") are treated
    as "no update requested" and returned unchanged.
    """
    if not new_value:
        # Preserve the original short-circuit semantics: the falsy value itself
        # is returned (None / 0 / ""), which evaluates as "no change".
        return new_value
    return current_value != new_value
class TaskModel(db.Model):
    """
    Simulation task model.

    Root of a joined-table inheritance hierarchy keyed on `platform`
    (CeleryTask for 'direct', BatchTask for 'batch').
    """

    __tablename__ = 'Task'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer,
                                           db.ForeignKey('Simulation.id', ondelete="CASCADE"),
                                           doc="Simulation job ID (foreign key)")

    task_id: Column[int] = db.Column(db.Integer, nullable=False, doc="Task ID")
    requested_primaries: Column[int] = db.Column(db.Integer,
                                                 nullable=False,
                                                 default=0,
                                                 doc="Requested number of primaries")
    simulated_primaries: Column[int] = db.Column(db.Integer,
                                                 nullable=False,
                                                 default=0,
                                                 doc="Simulated number of primaries")
    task_state: Column[str] = db.Column(db.String,
                                        nullable=False,
                                        default=EntityState.PENDING.value,
                                        doc="Task state (i.e. 'pending', 'running', 'completed', 'failed')")
    estimated_time: Column[int] = db.Column(db.Integer, nullable=True, doc="Estimated time in seconds")
    start_time: Column[datetime] = db.Column(db.DateTime(timezone=True), nullable=True, doc="Task start time")
    end_time: Column[datetime] = db.Column(db.DateTime(timezone=True), nullable=True, doc="Task end time")
    platform: Column[str] = db.Column(db.String, nullable=False, doc="Execution platform name (i.e. 'direct', 'batch')")
    last_update_time: Column[datetime] = db.Column(db.DateTime(timezone=True),
                                                   default=now(),
                                                   doc="Task last update time")

    # A task number is unique within its simulation.
    __table_args__ = (UniqueConstraint('simulation_id', 'task_id', name='_simulation_id_task_id_uc'), )

    __mapper_args__ = {"polymorphic_identity": "Task", "polymorphic_on": platform, "with_polymorphic": "*"}

    def update_state(self, update_dict: dict):  # skipcq: PY-R1000
        """
        Updating database is more costly than a simple query.
        Therefore we check first if update is needed and
        perform it only for such fields which exists and which have updated values.
        """
        # Terminal states are frozen: no further updates once completed/failed/canceled.
        if self.task_state in (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value):
            return
        if value_changed(self.requested_primaries, update_dict.get("requested_primaries")):
            self.requested_primaries = update_dict["requested_primaries"]
        if value_changed(self.simulated_primaries, update_dict.get("simulated_primaries")):
            self.simulated_primaries = update_dict["simulated_primaries"]
        # State change is applied only when it is an allowed transition
        # (e.g. Completed -> Canceled is rejected by allowed_state_change).
        if value_changed(self.task_state, update_dict.get("task_state")) and allowed_state_change(
                self.task_state, update_dict["task_state"]):
            self.task_state = update_dict["task_state"]
            # On completion force the simulated counter to match the request.
            if self.task_state == EntityState.COMPLETED.value:
                self.simulated_primaries = self.requested_primaries
        # Here we have a special case, `estimated_time` cannot be set when `end_time` is set - it is meaningless
        have_estim_time = "estimated_time" in update_dict and self.estimated_time != update_dict["estimated_time"]
        end_time_not_set = self.end_time is None
        if have_estim_time and end_time_not_set:
            self.estimated_time = update_dict["estimated_time"]
        if "start_time" in update_dict and self.start_time is None:
            # a convertion from string to datetime is needed, as in the POST payload start_time comes in string format
            self.start_time = datetime.strptime(update_dict["start_time"], '%Y-%m-%d %H:%M:%S.%f')
        # Here we have a special case, `end_time` can be set only once
        # therefore we update it only if it not set previously (`self.end_time is None`)
        # and if update was requested (`"end_time" in update_dict`)
        if "end_time" in update_dict and self.end_time is None:
            # a convertion from string to datetime is needed, as in the POST payload end_time comes in string format
            self.end_time = datetime.strptime(update_dict["end_time"], '%Y-%m-%d %H:%M:%S.%f')
            # Once the task has ended an estimate is meaningless, so it is cleared.
            self.estimated_time = None
        self.last_update_time = now()

    def get_status_dict(self) -> dict:
        """Returns task information as a dictionary"""
        result = {
            "task_state": self.task_state,
            "requested_primaries": self.requested_primaries,
            "simulated_primaries": self.simulated_primaries,
            "last_update_time": self.last_update_time,
        }
        # Optional fields are included only when set (estimated_time is split
        # into hours/minutes/seconds from its value in seconds).
        if self.estimated_time:
            result["estimated_time"] = {
                "hours": self.estimated_time // 3600,
                "minutes": (self.estimated_time // 60) % 60,
                "seconds": self.estimated_time % 60,
            }
        if self.start_time:
            result["start_time"] = self.start_time
        if self.end_time:
            result["end_time"] = self.end_time
        return result
class CeleryTaskModel(TaskModel):
    """Celery task model — tasks executed on the 'direct' platform."""

    __tablename__ = 'CeleryTask'
    # Shares the primary key with the base Task row (joined-table inheritance).
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Task.id', ondelete="CASCADE"), primary_key=True)
    celery_id: Column[str] = db.Column(db.String, nullable=False, default="", doc="Celery task ID")

    def update_state(self, update_dict: dict):
        """Update method for CeleryTaskModel"""
        # Update the Celery-specific field first, then delegate the rest to TaskModel.
        if "celery_id" in update_dict and self.celery_id != update_dict["celery_id"]:
            self.celery_id = update_dict["celery_id"]
        return super().update_state(update_dict)

    __mapper_args__ = {"polymorphic_identity": PlatformType.DIRECT.value, "polymorphic_load": "inline"}
class BatchTaskModel(TaskModel):
    """Batch task model — tasks executed on the 'batch' platform; no extra fields."""

    __tablename__ = 'BatchTask'
    # Shares the primary key with the base Task row (joined-table inheritance).
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Task.id', ondelete="CASCADE"), primary_key=True)

    __mapper_args__ = {"polymorphic_identity": PlatformType.BATCH.value, "polymorphic_load": "inline"}
def decompress(data: bytes):
    """
    Inverse of `compress`: gunzip `data` and parse the contained UTF-8 JSON.

    A `None` input yields `None` (the JSON value `null`), mirroring how
    `compress(None)` stores nothing meaningful.
    """
    if data is None:
        # json.loads('null') in the original path; return its result directly.
        return None
    json_text = gzip.decompress(data).decode('utf-8')
    return json.loads(json_text)
def compress(data) -> bytes:
    """
    Inverse of `decompress`: serialize `data` to JSON and gzip the UTF-8 bytes.

    A `None` input produces an empty byte string, so nothing is stored for it.
    """
    if data is None:
        return b''
    json_bytes = json.dumps(data).encode('utf-8')
    return gzip.compress(json_bytes)
class InputModel(db.Model):
    """Simulation inputs model — stores the job input as gzip-compressed JSON."""

    __tablename__ = 'Input'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id', ondelete="CASCADE"))
    compressed_data: Column[bytes] = db.Column(db.LargeBinary)

    @property
    def data(self):
        # Transparent view over `compressed_data`: decompressed + JSON-decoded.
        return decompress(self.compressed_data)

    @data.setter
    def data(self, value):
        # Setting None is a deliberate no-op: existing compressed_data is kept.
        if value is not None:
            self.compressed_data = compress(value)
class EstimatorModel(db.Model):
    """Simulation single estimator model — metadata stored as gzip-compressed JSON."""

    __tablename__ = 'Estimator'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer,
                                           db.ForeignKey('Simulation.id', ondelete="CASCADE"),
                                           nullable=False)
    name: Column[str] = db.Column(db.String, nullable=False, doc="Estimator name")
    compressed_data: Column[bytes] = db.Column(db.LargeBinary, doc="Estimator metadata")
    # Pages are removed together with their estimator.
    pages = relationship("PageModel", cascade="delete")

    @property
    def data(self):
        # Transparent view over `compressed_data`: decompressed + JSON-decoded.
        return decompress(self.compressed_data)

    @data.setter
    def data(self, value):
        # Setting None is a deliberate no-op: existing compressed_data is kept.
        if value is not None:
            self.compressed_data = compress(value)
class PageModel(db.Model):
    """Estimator single page model — page payload stored as gzip-compressed JSON."""

    __tablename__ = 'Page'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    estimator_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Estimator.id', ondelete="CASCADE"), nullable=False)
    page_number: Column[int] = db.Column(db.Integer, nullable=False, doc="Page number")
    compressed_data: Column[bytes] = db.Column(db.LargeBinary, doc="Page json object - data, axes and metadata")

    @property
    def data(self):
        # Transparent view over `compressed_data`: decompressed + JSON-decoded.
        return decompress(self.compressed_data)

    @data.setter
    def data(self, value):
        # Setting None is a deliberate no-op: existing compressed_data is kept.
        if value is not None:
            self.compressed_data = compress(value)
class LogfilesModel(db.Model):
    """Simulation logfiles model — logfiles stored as one gzip-compressed JSON object."""

    __tablename__ = 'Logfiles'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer,
                                           db.ForeignKey('Simulation.id', ondelete="CASCADE"),
                                           nullable=False)
    compressed_data: Column[bytes] = db.Column(db.LargeBinary, doc="Json object containing logfiles")

    @property
    def data(self):
        # Transparent view over `compressed_data`: decompressed + JSON-decoded.
        return decompress(self.compressed_data)

    @data.setter
    def data(self, value):
        # Setting None is a deliberate no-op: existing compressed_data is kept.
        if value is not None:
            self.compressed_data = compress(value)
def create_all():
    """Creates all tables, to be used with Flask app context."""
    # Thin wrapper so callers do not need to import `db` directly.
    db.create_all()