Coverage for yaptide/persistence/models.py: 94%
204 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-07-01 12:55 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-07-01 12:55 +0000
1import gzip
2import json
3from datetime import datetime
5from sqlalchemy import Column, UniqueConstraint
6from sqlalchemy.orm import relationship
7from sqlalchemy.sql.functions import now
8from werkzeug.security import check_password_hash, generate_password_hash
10from yaptide.persistence.database import db
11from yaptide.utils.enums import EntityState, PlatformType
class UserModel(db.Model):
    """Base user record shared by every authentication provider.

    Rows are polymorphic on ``auth_provider``: that column stores the
    polymorphic identity of the concrete subclass a row belongs to.
    """

    __tablename__ = 'User'

    id: Column[int] = db.Column(db.Integer, primary_key=True)
    username: Column[str] = db.Column(db.String, nullable=False)
    # discriminator column used by the polymorphic mapping below
    auth_provider: Column[str] = db.Column(db.String, nullable=False)
    simulations = relationship("SimulationModel")

    # the same username may exist once per authentication provider
    __table_args__ = (UniqueConstraint('username', 'auth_provider', name='_username_provider_uc'), )

    __mapper_args__ = {
        "polymorphic_identity": "User",
        "polymorphic_on": auth_provider,
        "with_polymorphic": "*",
    }

    def __repr__(self) -> str:
        user_id, name = self.id, self.username
        return f'User #{user_id} {name}'
class YaptideUserModel(UserModel, db.Model):
    """Yaptide user model: a user whose password hash is stored in the database."""

    __tablename__ = 'YaptideUser'
    # joined-table inheritance: the primary key doubles as a foreign key to User.id
    id: Column[int] = db.Column(db.Integer,
                                db.ForeignKey('User.id', ondelete="CASCADE"),
                                primary_key=True)
    password_hash: Column[str] = db.Column(db.String, nullable=False)

    __mapper_args__ = {
        "polymorphic_identity": "YaptideUser",
        "polymorphic_load": "inline",
    }

    def set_password(self, password: str):
        """Hash ``password`` with werkzeug and store only the resulting hash."""
        hashed = generate_password_hash(password)
        self.password_hash = hashed

    def check_password(self, password: str) -> bool:
        """Return True when ``password`` matches the stored hash."""
        stored_hash = self.password_hash
        return check_password_hash(stored_hash, password)
class KeycloakUserModel(UserModel, db.Model):
    """PLGrid user model, registered under the 'KeycloakUser' polymorphic identity.

    Optionally keeps a certificate and a private key for the user.
    """

    __tablename__ = 'KeycloakUser'
    # joined-table inheritance: the primary key doubles as a foreign key to User.id
    id: Column[int] = db.Column(db.Integer,
                                db.ForeignKey('User.id', ondelete="CASCADE"),
                                primary_key=True)
    cert: Column[str] = db.Column(db.String, nullable=True)
    private_key: Column[str] = db.Column(db.String, nullable=True)

    __mapper_args__ = {
        "polymorphic_identity": "KeycloakUser",
        "polymorphic_load": "inline",
    }
class ClusterModel(db.Model):
    """Cluster info; batch simulations point at a row here via their ``cluster_id``."""

    __tablename__ = 'Cluster'

    id: Column[int] = db.Column(db.Integer, primary_key=True)
    cluster_name: Column[str] = db.Column(db.String, nullable=False)
    simulations = relationship("BatchSimulationModel")
class SimulationModel(db.Model):
    """Simulation model.

    Rows are polymorphic on ``platform``; platform-specific subclasses add
    their own columns.
    """

    __tablename__ = 'Simulation'

    id: Column[int] = db.Column(db.Integer, primary_key=True)

    job_id: Column[str] = db.Column(db.String, nullable=False, unique=True, doc="Simulation job ID")

    user_id: Column[int] = db.Column(db.Integer, db.ForeignKey('User.id'), doc="User ID")
    start_time: Column[datetime] = db.Column(db.DateTime(timezone=True), default=now(), doc="Submission time")
    end_time: Column[datetime] = db.Column(db.DateTime(timezone=True),
                                           nullable=True,
                                           doc="Job end time (including merging)")
    title: Column[str] = db.Column(db.String, nullable=False, doc="Job title")
    platform: Column[str] = db.Column(db.String, nullable=False, doc="Execution platform name (i.e. 'direct', 'batch')")
    input_type: Column[str] = db.Column(db.String,
                                        nullable=False,
                                        doc="Input type (i.e. 'yaptide_project', 'input_files')")
    sim_type: Column[str] = db.Column(db.String,
                                      nullable=False,
                                      doc="Simulator type (i.e. 'shieldhit', 'topas', 'fluka')")
    job_state: Column[str] = db.Column(db.String,
                                       nullable=False,
                                       default=EntityState.UNKNOWN.value,
                                       doc="Simulation state (i.e. 'pending', 'running', 'completed', 'failed')")
    update_key_hash: Column[str] = db.Column(db.String,
                                             doc="Update key shared by tasks granting access to update themselves")
    tasks = relationship("TaskModel")
    estimators = relationship("EstimatorModel")

    __mapper_args__ = {"polymorphic_identity": "Simulation", "polymorphic_on": platform, "with_polymorphic": "*"}

    def set_update_key(self, update_key: str):
        """Sets hashed update key"""
        self.update_key_hash = generate_password_hash(update_key)

    def check_update_key(self, update_key: str) -> bool:
        """Checks update key correctness"""
        return check_password_hash(self.update_key_hash, update_key)

    @staticmethod
    def _parse_timestamp(value: str) -> datetime:
        """Parse a 'YYYY-MM-DD HH:MM:SS[.ffffff]' string into a naive datetime.

        The fractional-seconds part is optional because ``str(datetime)`` omits
        it when the microsecond component is zero.

        Raises:
            ValueError: if ``value`` matches neither layout.
        """
        try:
            return datetime.strptime(value, '%Y-%m-%d %H:%M:%S.%f')
        except ValueError:
            return datetime.strptime(value, '%Y-%m-%d %H:%M:%S')

    def update_state(self, update_dict: dict) -> bool:
        """
        Apply state changes from ``update_dict`` and report whether a commit is needed.

        Updating database is more costly than a simple query, so only fields that are
        present in ``update_dict`` and whose values differ are written. Simulations that
        are already completed, failed or canceled are never modified.

        Args:
            update_dict: may contain ``job_state`` and/or ``end_time``
                (``end_time`` as a 'YYYY-MM-DD HH:MM:SS[.ffffff]' string).

        Returns:
            True if any attribute was changed and the caller should commit to db.

        Raises:
            ValueError: if ``end_time`` is present but cannot be parsed.
        """
        if self.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value):
            return False
        db_commit_required = False
        if "job_state" in update_dict and self.job_state != update_dict["job_state"]:
            self.job_state = update_dict["job_state"]
            db_commit_required = True
        # `end_time` can be set only once: update it only if it was not set previously
        # (`self.end_time is None`) and an update was requested (`"end_time" in update_dict`)
        if "end_time" in update_dict and self.end_time is None:
            # the POST payload carries end_time as a string, so it is converted to datetime
            self.end_time = self._parse_timestamp(update_dict["end_time"])
            db_commit_required = True
        return db_commit_required
class CelerySimulationModel(SimulationModel):
    """Celery simulation model, mapped under the PlatformType.DIRECT identity."""

    __tablename__ = 'CelerySimulation'
    # joined-table inheritance: the primary key doubles as a foreign key to Simulation.id
    id: Column[int] = db.Column(db.Integer,
                                db.ForeignKey('Simulation.id', ondelete="CASCADE"),
                                primary_key=True)
    merge_id: Column[str] = db.Column(db.String, doc="Celery collect job ID", nullable=True)

    __mapper_args__ = {
        "polymorphic_identity": PlatformType.DIRECT.value,
        "polymorphic_load": "inline",
    }
class BatchSimulationModel(SimulationModel):
    """Batch simulation model, mapped under the PlatformType.BATCH identity.

    Each batch simulation belongs to a cluster (``cluster_id`` is required).
    """

    __tablename__ = 'BatchSimulation'
    # joined-table inheritance: the primary key doubles as a foreign key to Simulation.id
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id', ondelete="CASCADE"), primary_key=True)
    cluster_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Cluster.id'), nullable=False, doc="Cluster ID")
    job_dir: Column[str] = db.Column(db.String, nullable=True, doc="Simulation folder name")
    array_id: Column[int] = db.Column(db.Integer, nullable=True, doc="Batch array job ID")
    collect_id: Column[int] = db.Column(db.Integer, nullable=True, doc="Batch collect job ID")

    __mapper_args__ = {"polymorphic_identity": PlatformType.BATCH.value, "polymorphic_load": "inline"}
class TaskModel(db.Model):
    """Simulation task model.

    Rows are polymorphic on ``platform``; each (simulation_id, task_id) pair is unique.
    """

    __tablename__ = 'Task'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer,
                                           db.ForeignKey('Simulation.id'),
                                           doc="Simulation job ID (foreign key)")

    task_id: Column[int] = db.Column(db.Integer, nullable=False, doc="Task ID")
    requested_primaries: Column[int] = db.Column(db.Integer,
                                                 nullable=False,
                                                 default=0,
                                                 doc="Requested number of primaries")
    simulated_primaries: Column[int] = db.Column(db.Integer,
                                                 nullable=False,
                                                 default=0,
                                                 doc="Simulated number of primaries")
    task_state: Column[str] = db.Column(db.String,
                                        nullable=False,
                                        default=EntityState.PENDING.value,
                                        doc="Task state (i.e. 'pending', 'running', 'completed', 'failed')")
    estimated_time: Column[int] = db.Column(db.Integer, nullable=True, doc="Estimated time in seconds")
    start_time: Column[datetime] = db.Column(db.DateTime(timezone=True), nullable=True, doc="Task start time")
    end_time: Column[datetime] = db.Column(db.DateTime(timezone=True), nullable=True, doc="Task end time")
    platform: Column[str] = db.Column(db.String, nullable=False, doc="Execution platform name (i.e. 'direct', 'batch')")
    last_update_time: Column[datetime] = db.Column(db.DateTime(timezone=True),
                                                   default=now(),
                                                   doc="Task last update time")

    __table_args__ = (UniqueConstraint('simulation_id', 'task_id', name='_simulation_id_task_id_uc'), )

    __mapper_args__ = {"polymorphic_identity": "Task", "polymorphic_on": platform, "with_polymorphic": "*"}

    @staticmethod
    def _parse_timestamp(value: str) -> datetime:
        """Parse a 'YYYY-MM-DD HH:MM:SS[.ffffff]' string into a naive datetime.

        The fractional-seconds part is optional because ``str(datetime)`` omits
        it when the microsecond component is zero.

        Raises:
            ValueError: if ``value`` matches neither layout.
        """
        try:
            return datetime.strptime(value, '%Y-%m-%d %H:%M:%S.%f')
        except ValueError:
            return datetime.strptime(value, '%Y-%m-%d %H:%M:%S')

    def update_state(self, update_dict: dict):
        """
        Apply the changes in ``update_dict`` to this task.

        Updating database is more costly than a simple query, so only fields that are
        present in ``update_dict`` and whose values differ are written. Tasks that are
        already completed, failed or canceled are left untouched. Otherwise
        ``last_update_time`` is refreshed on every call.

        Args:
            update_dict: may contain ``requested_primaries``, ``simulated_primaries``,
                ``task_state``, ``estimated_time``, ``start_time`` and ``end_time``
                (the two times as 'YYYY-MM-DD HH:MM:SS[.ffffff]' strings).

        Raises:
            ValueError: if ``start_time`` or ``end_time`` is present but cannot be parsed.
        """
        if self.task_state in (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value):
            return
        if "requested_primaries" in update_dict and self.requested_primaries != update_dict["requested_primaries"]:
            self.requested_primaries = update_dict["requested_primaries"]
        if "simulated_primaries" in update_dict and self.simulated_primaries != update_dict["simulated_primaries"]:
            self.simulated_primaries = update_dict["simulated_primaries"]
        if "task_state" in update_dict and self.task_state != update_dict["task_state"]:
            self.task_state = update_dict["task_state"]
        # `estimated_time` is meaningless once `end_time` is set, so it is only updated before that
        have_estim_time = "estimated_time" in update_dict and self.estimated_time != update_dict["estimated_time"]
        end_time_not_set = self.end_time is None
        if have_estim_time and end_time_not_set:
            self.estimated_time = update_dict["estimated_time"]
        if "start_time" in update_dict and self.start_time is None:
            # the POST payload carries start_time as a string, so it is converted to datetime
            self.start_time = self._parse_timestamp(update_dict["start_time"])
        # `end_time` can be set only once: update it only if it was not set previously
        # (`self.end_time is None`) and an update was requested (`"end_time" in update_dict`)
        if "end_time" in update_dict and self.end_time is None:
            # the POST payload carries end_time as a string, so it is converted to datetime
            self.end_time = self._parse_timestamp(update_dict["end_time"])
            # a finished task has no remaining time to estimate
            self.estimated_time = None
        # assigns the SQL now() expression, which SQLAlchemy embeds into the UPDATE on flush
        self.last_update_time = now()

    def get_status_dict(self) -> dict:
        """Return task information as a dictionary.

        Always contains ``task_state``, ``requested_primaries``, ``simulated_primaries``
        and ``last_update_time``. ``estimated_time`` (split into hours/minutes/seconds),
        ``start_time`` and ``end_time`` are included only when they are truthy.
        """
        result = {
            "task_state": self.task_state,
            "requested_primaries": self.requested_primaries,
            "simulated_primaries": self.simulated_primaries,
            "last_update_time": self.last_update_time,
        }
        if self.estimated_time:
            result["estimated_time"] = {
                "hours": self.estimated_time // 3600,
                "minutes": (self.estimated_time // 60) % 60,
                "seconds": self.estimated_time % 60,
            }
        if self.start_time:
            result["start_time"] = self.start_time
        if self.end_time:
            result["end_time"] = self.end_time
        return result
class CeleryTaskModel(TaskModel):
    """Celery task model, mapped under the PlatformType.DIRECT identity.

    Adds the Celery task identifier on top of the generic task columns.
    """

    __tablename__ = 'CeleryTask'
    # joined-table inheritance: the primary key doubles as a foreign key to Task.id
    id: Column[int] = db.Column(db.Integer,
                                db.ForeignKey('Task.id', ondelete="CASCADE"),
                                primary_key=True)
    celery_id: Column[str] = db.Column(db.String, nullable=False, default="", doc="Celery task ID")

    __mapper_args__ = {
        "polymorphic_identity": PlatformType.DIRECT.value,
        "polymorphic_load": "inline",
    }

    def update_state(self, update_dict: dict):
        """Record a changed ``celery_id``, then delegate to ``TaskModel.update_state``."""
        if "celery_id" in update_dict:
            incoming_id = update_dict["celery_id"]
            if self.celery_id != incoming_id:
                self.celery_id = incoming_id
        return super().update_state(update_dict)
class BatchTaskModel(TaskModel):
    """Batch task model, mapped under the PlatformType.BATCH identity."""

    __tablename__ = 'BatchTask'
    # joined-table inheritance: the primary key doubles as a foreign key to Task.id
    id: Column[int] = db.Column(db.Integer,
                                db.ForeignKey('Task.id', ondelete="CASCADE"),
                                primary_key=True)

    __mapper_args__ = {
        "polymorphic_identity": PlatformType.BATCH.value,
        "polymorphic_load": "inline",
    }
def decompress(data: bytes):
    """Decompress gzip ``data`` and deserialize the UTF-8 JSON it contains.

    ``None`` and empty bytes both decode to ``None``. Empty bytes is what
    ``compress`` returns for ``None``, so the two functions round-trip.

    Args:
        data: gzip-compressed UTF-8 JSON, empty bytes, or None.

    Returns:
        The deserialized JSON value, or None for missing/empty data.

    Raises:
        Whatever ``gzip.decompress`` raises for data that is not valid gzip, or
        ``json.JSONDecodeError`` if the decompressed payload is not valid JSON.
    """
    if not data:
        # NULL column values and compress(None) output (b'') both mean JSON null
        return None
    decompressed_bytes: bytes = gzip.decompress(data)
    return json.loads(decompressed_bytes.decode('utf-8'))
def compress(data) -> bytes:
    """Serialize ``data`` to JSON, encode it as UTF-8 and gzip the result.

    Returns empty bytes when ``data`` is None (other falsy values are still compressed).
    """
    if data is None:
        return b''
    payload = json.dumps(data).encode('utf-8')
    return gzip.compress(payload)
class InputModel(db.Model):
    """Simulation inputs model.

    The payload lives compressed in ``compressed_data``; the ``data`` property
    exposes it as decoded JSON values.
    """

    __tablename__ = 'Input'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id'))
    compressed_data: Column[bytes] = db.Column(db.LargeBinary)

    @property
    def data(self):
        """Decompressed, JSON-decoded contents of ``compressed_data``."""
        raw = self.compressed_data
        return decompress(raw)

    @data.setter
    def data(self, value):
        # assigning None is a no-op: the stored payload is kept unchanged
        if value is None:
            return
        self.compressed_data = compress(value)
class EstimatorModel(db.Model):
    """Simulation single estimator model.

    Estimator metadata lives compressed in ``compressed_data``; the ``data``
    property exposes it as decoded JSON values.
    """

    __tablename__ = 'Estimator'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id'), nullable=False)
    name: Column[str] = db.Column(db.String, doc="Estimator name", nullable=False)
    compressed_data: Column[bytes] = db.Column(db.LargeBinary, doc="Estimator metadata")

    @property
    def data(self):
        """Decompressed, JSON-decoded contents of ``compressed_data``."""
        raw = self.compressed_data
        return decompress(raw)

    @data.setter
    def data(self, value):
        # assigning None is a no-op: the stored payload is kept unchanged
        if value is None:
            return
        self.compressed_data = compress(value)
class PageModel(db.Model):
    """Estimator single page model.

    The page JSON object lives compressed in ``compressed_data``; the ``data``
    property exposes it as decoded JSON values.
    """

    __tablename__ = 'Page'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    estimator_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Estimator.id'), nullable=False)
    page_number: Column[int] = db.Column(db.Integer, doc="Page number", nullable=False)
    compressed_data: Column[bytes] = db.Column(db.LargeBinary, doc="Page json object - data, axes and metadata")

    @property
    def data(self):
        """Decompressed, JSON-decoded contents of ``compressed_data``."""
        raw = self.compressed_data
        return decompress(raw)

    @data.setter
    def data(self, value):
        # assigning None is a no-op: the stored payload is kept unchanged
        if value is None:
            return
        self.compressed_data = compress(value)
class LogfilesModel(db.Model):
    """Simulation logfiles model.

    The logfiles JSON object lives compressed in ``compressed_data``; the
    ``data`` property exposes it as decoded JSON values.
    """

    __tablename__ = 'Logfiles'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id'), nullable=False)
    compressed_data: Column[bytes] = db.Column(db.LargeBinary, doc="Json object containing logfiles")

    @property
    def data(self):
        """Decompressed, JSON-decoded contents of ``compressed_data``."""
        raw = self.compressed_data
        return decompress(raw)

    @data.setter
    def data(self, value):
        # assigning None is a no-op: the stored payload is kept unchanged
        if value is None:
            return
        self.compressed_data = compress(value)
def create_all():
    """Create every table registered on ``db``.

    Must be called inside a Flask application context.
    """
    db.create_all()