Coverage for yaptide/persistence/models.py: 90%

223 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-06-10 10:08 +0000

1# ---------- IMPORTANT ------------ 

2# Read documentation in persistency.md. It contains information about database development with flask-migrate. 

3 

4import gzip 

5import json 

6from datetime import datetime 

7 

8from sqlalchemy import Column, UniqueConstraint 

9from sqlalchemy.orm import relationship 

10from sqlalchemy.sql.functions import now 

11from werkzeug.security import check_password_hash, generate_password_hash 

12 

13from yaptide.persistence.database import db 

14from yaptide.utils.enums import EntityState, PlatformType 

15 

16 

class UserModel(db.Model):
    """Base user model; concrete auth providers subclass it via joined-table inheritance"""

    __tablename__ = 'User'
    # Surrogate primary key, shared with the provider-specific subclass tables.
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    username: Column[str] = db.Column(db.String, nullable=False)
    # Discriminator column: holds the polymorphic identity of the subclass
    # (e.g. 'YaptideUser' or 'KeycloakUser').
    auth_provider: Column[str] = db.Column(db.String, nullable=False)
    # All simulations submitted by this user (SimulationModel.user_id foreign key).
    simulations = relationship("SimulationModel")

    # The same username may exist at most once per auth provider.
    __table_args__ = (UniqueConstraint('username', 'auth_provider', name='_username_provider_uc'), )

    __mapper_args__ = {"polymorphic_identity": "User", "polymorphic_on": auth_provider, "with_polymorphic": "*"}

    def __repr__(self) -> str:
        return f'User #{self.id} {self.username}'

32 

33 

class YaptideUserModel(UserModel, db.Model):
    """Yaptide user model - user authenticated locally with a password"""

    __tablename__ = 'YaptideUser'
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('User.id', ondelete="CASCADE"), primary_key=True)
    # Werkzeug-generated salted hash; the plain-text password is never stored.
    password_hash: Column[str] = db.Column(db.String, nullable=False)

    __mapper_args__ = {"polymorphic_identity": "YaptideUser", "polymorphic_load": "inline"}

    def set_password(self, password: str):
        """Hashes the given plain-text password and stores the hash"""
        self.password_hash = generate_password_hash(password)

    def check_password(self, password: str) -> bool:
        """Checks the given plain-text password against the stored hash"""
        return check_password_hash(self.password_hash, password)

50 

51 

class KeycloakUserModel(UserModel, db.Model):
    """PLGrid user model - user authenticated externally via Keycloak"""

    __tablename__ = 'KeycloakUser'
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('User.id', ondelete="CASCADE"), primary_key=True)
    # NOTE(review): presumably a user certificate/private-key pair used for
    # cluster access on the user's behalf - confirm with the batch submission code.
    cert: Column[str] = db.Column(db.String, nullable=True)
    private_key: Column[str] = db.Column(db.String, nullable=True)

    __mapper_args__ = {"polymorphic_identity": "KeycloakUser", "polymorphic_load": "inline"}

61 

62 

class ClusterModel(db.Model):
    """Cluster info for specific user"""

    __tablename__ = 'Cluster'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    cluster_name: Column[str] = db.Column(db.String, nullable=False)
    # Batch simulations submitted to this cluster (BatchSimulationModel.cluster_id foreign key).
    simulations = relationship("BatchSimulationModel")

70 

71 

class SimulationModel(db.Model):
    """Simulation (job) model; platform-specific subclasses extend it via joined-table inheritance"""

    __tablename__ = 'Simulation'

    id: Column[int] = db.Column(db.Integer, primary_key=True)

    job_id: Column[str] = db.Column(db.String, nullable=False, unique=True, doc="Simulation job ID")

    user_id: Column[int] = db.Column(db.Integer, db.ForeignKey('User.id'), doc="User ID")
    # `now()` is a SQLAlchemy SQL function expression - the timestamp is generated
    # by the database at insert time, not at class-definition time.
    start_time: Column[datetime] = db.Column(db.DateTime(timezone=True), default=now(), doc="Submission time")
    end_time: Column[datetime] = db.Column(db.DateTime(timezone=True),
                                           nullable=True,
                                           doc="Job end time (including merging)")
    title: Column[str] = db.Column(db.String, nullable=False, doc="Job title")
    platform: Column[str] = db.Column(db.String, nullable=False, doc="Execution platform name (i.e. 'direct', 'batch')")
    input_type: Column[str] = db.Column(db.String,
                                        nullable=False,
                                        doc="Input type (i.e. 'yaptide_project', 'input_files')")
    sim_type: Column[str] = db.Column(db.String,
                                      nullable=False,
                                      doc="Simulator type (i.e. 'shieldhit', 'topas', 'fluka')")
    job_state: Column[str] = db.Column(db.String,
                                       nullable=False,
                                       default=EntityState.UNKNOWN.value,
                                       doc="Simulation state (i.e. 'pending', 'running', 'completed', 'failed')")

    # Child rows are deleted together with their simulation.
    tasks = relationship("TaskModel", cascade="delete")
    estimators = relationship("EstimatorModel", cascade="delete")
    inputs = relationship("InputModel", cascade="delete")
    logfiles = relationship("LogfilesModel", cascade="delete")

    # `platform` is the discriminator selecting CelerySimulationModel / BatchSimulationModel.
    __mapper_args__ = {"polymorphic_identity": "Simulation", "polymorphic_on": platform, "with_polymorphic": "*"}

    def update_state(self, update_dict: dict) -> bool:
        """
        Updating database is more costly than a simple query.
        Therefore we check first if update is needed and
        perform it only for such fields which exists and which have updated values.
        Returns bool value telling if it is required to commit changes to db.
        """
        # Terminal states are never overwritten.
        if self.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value):
            return False
        db_commit_required = False
        if "job_state" in update_dict and self.job_state != update_dict["job_state"]:
            self.job_state = update_dict["job_state"]
            db_commit_required = True
        # Here we have a special case, `end_time` can be set only once
        # therefore we update it only if it not set previously (`self.end_time is None`)
        # and if update was requested (`"end_time" in update_dict`)
        if "end_time" in update_dict and self.end_time is None:
            # a conversion from string to datetime is needed, as in the POST payload end_time comes in string format
            self.end_time = datetime.strptime(update_dict["end_time"], '%Y-%m-%d %H:%M:%S.%f')
            db_commit_required = True
        return db_commit_required

127 

128 

class CelerySimulationModel(SimulationModel):
    """Celery simulation model - job executed directly by Celery workers"""

    __tablename__ = 'CelerySimulation'
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id', ondelete="CASCADE"), primary_key=True)
    merge_id: Column[str] = db.Column(db.String, nullable=True, doc="Celery collect job ID")

    # Selected when SimulationModel.platform equals PlatformType.DIRECT.value.
    __mapper_args__ = {"polymorphic_identity": PlatformType.DIRECT.value, "polymorphic_load": "inline"}

137 

138 

class BatchSimulationModel(SimulationModel):
    """Batch simulation model - job submitted to an external batch cluster"""

    __tablename__ = 'BatchSimulation'
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id', ondelete="CASCADE"), primary_key=True)
    cluster_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Cluster.id'), nullable=False, doc="Cluster ID")
    job_dir: Column[str] = db.Column(db.String, nullable=True, doc="Simulation folder name")
    # typo fix: "jon ID" -> "job ID"
    array_id: Column[int] = db.Column(db.Integer, nullable=True, doc="Batch array job ID")
    collect_id: Column[int] = db.Column(db.Integer, nullable=True, doc="Batch collect job ID")

    # Selected when SimulationModel.platform equals PlatformType.BATCH.value.
    __mapper_args__ = {"polymorphic_identity": PlatformType.BATCH.value, "polymorphic_load": "inline"}

    def update_state(self, update_dict: dict) -> bool:
        """
        Used to update fields in BatchSimulation (and, via the base class, the common
        job fields). Returns boolean value telling if commit to database is required.
        """
        db_commit_required = super().update_state(update_dict)
        # All three batch-specific fields follow the same rule: update only when the
        # key is present in the payload and the value actually changed.
        for field in ("job_dir", "array_id", "collect_id"):
            if field in update_dict and getattr(self, field) != update_dict[field]:
                setattr(self, field, update_dict[field])
                db_commit_required = True
        return db_commit_required

164 

165 

def allowed_state_change(current_state: str, next_state: str) -> bool:
    """
    Ensures that no transition out of a terminal state happens,
    e.g. Completed -> Canceled or Failed -> Canceled.

    Both arguments are state strings (values of EntityState members).
    """
    terminal_states = (EntityState.FAILED.value, EntityState.COMPLETED.value)
    # BUGFIX: the original compared the incoming state *string* against the
    # EntityState.CANCELED enum *member* (`next_state in [EntityState.CANCELED]`),
    # which never matches, so the guard always returned True. Everywhere else in
    # this module states are compared via `.value`; do the same here.
    return not (current_state in terminal_states and next_state == EntityState.CANCELED.value)

170 

171 

def value_changed(current_value: str, new_value: str):
    """
    Tells whether the value coming from an update_dict differs from the one
    stored in the database. A falsy new value (None, empty string, 0) is
    returned unchanged, so it never counts as a change.
    """
    if not new_value:
        return new_value
    return current_value != new_value

175 

176 

class TaskModel(db.Model):
    """Simulation task model - a single unit of work within a simulation job"""

    __tablename__ = 'Task'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer,
                                           db.ForeignKey('Simulation.id', ondelete="CASCADE"),
                                           doc="Simulation job ID (foreign key)")

    task_id: Column[int] = db.Column(db.Integer, nullable=False, doc="Task ID")
    requested_primaries: Column[int] = db.Column(db.Integer,
                                                 nullable=False,
                                                 default=0,
                                                 doc="Requested number of primaries")
    simulated_primaries: Column[int] = db.Column(db.Integer,
                                                 nullable=False,
                                                 default=0,
                                                 doc="Simulated number of primaries")
    task_state: Column[str] = db.Column(db.String,
                                        nullable=False,
                                        default=EntityState.PENDING.value,
                                        doc="Task state (i.e. 'pending', 'running', 'completed', 'failed')")
    estimated_time: Column[int] = db.Column(db.Integer, nullable=True, doc="Estimated time in seconds")
    start_time: Column[datetime] = db.Column(db.DateTime(timezone=True), nullable=True, doc="Task start time")
    end_time: Column[datetime] = db.Column(db.DateTime(timezone=True), nullable=True, doc="Task end time")
    platform: Column[str] = db.Column(db.String, nullable=False, doc="Execution platform name (i.e. 'direct', 'batch')")
    # `now()` is a SQLAlchemy SQL function expression, evaluated by the database.
    last_update_time: Column[datetime] = db.Column(db.DateTime(timezone=True),
                                                   default=now(),
                                                   doc="Task last update time")

    # A task number may appear at most once per simulation.
    __table_args__ = (UniqueConstraint('simulation_id', 'task_id', name='_simulation_id_task_id_uc'), )

    # `platform` is the discriminator selecting CeleryTaskModel / BatchTaskModel.
    __mapper_args__ = {"polymorphic_identity": "Task", "polymorphic_on": platform, "with_polymorphic": "*"}

    def update_state(self, update_dict: dict):  # skipcq: PY-R1000
        """
        Updating database is more costly than a simple query.
        Therefore we check first if update is needed and
        perform it only for such fields which exists and which have updated values.
        """
        # Terminal task states are never overwritten.
        if self.task_state in (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value):
            return
        if value_changed(self.requested_primaries, update_dict.get("requested_primaries")):
            self.requested_primaries = update_dict["requested_primaries"]
        if value_changed(self.simulated_primaries, update_dict.get("simulated_primaries")):
            self.simulated_primaries = update_dict["simulated_primaries"]
        if value_changed(self.task_state, update_dict.get("task_state")) and allowed_state_change(
                self.task_state, update_dict["task_state"]):
            self.task_state = update_dict["task_state"]
            # A completed task is considered to have simulated everything requested.
            if self.task_state == EntityState.COMPLETED.value:
                self.simulated_primaries = self.requested_primaries
        # Here we have a special case, `estimated_time` cannot be set when `end_time` is set - it is meaningless
        have_estim_time = "estimated_time" in update_dict and self.estimated_time != update_dict["estimated_time"]
        end_time_not_set = self.end_time is None
        if have_estim_time and end_time_not_set:
            self.estimated_time = update_dict["estimated_time"]
        if "start_time" in update_dict and self.start_time is None:
            # a conversion from string to datetime is needed, as in the POST payload start_time comes in string format
            self.start_time = datetime.strptime(update_dict["start_time"], '%Y-%m-%d %H:%M:%S.%f')
        # Here we have a special case, `end_time` can be set only once
        # therefore we update it only if it not set previously (`self.end_time is None`)
        # and if update was requested (`"end_time" in update_dict`)
        if "end_time" in update_dict and self.end_time is None:
            # a conversion from string to datetime is needed, as in the POST payload end_time comes in string format
            self.end_time = datetime.strptime(update_dict["end_time"], '%Y-%m-%d %H:%M:%S.%f')
            # once the task has an end time, any remaining time estimate is meaningless
            self.estimated_time = None
        self.last_update_time = now()

    def get_status_dict(self) -> dict:
        """Returns task information as a dictionary"""
        result = {
            "task_state": self.task_state,
            "requested_primaries": self.requested_primaries,
            "simulated_primaries": self.simulated_primaries,
            "last_update_time": self.last_update_time,
            # NOTE(review): this exposes the database primary key under the
            # "task_id" key, not the `task_id` column - confirm API consumers
            # actually expect the primary key here.
            "task_id": self.id
        }
        if self.estimated_time:
            # split the estimate (stored in seconds) into h/m/s for presentation
            result["estimated_time"] = {
                "hours": self.estimated_time // 3600,
                "minutes": (self.estimated_time // 60) % 60,
                "seconds": self.estimated_time % 60,
            }
        if self.start_time:
            result["start_time"] = self.start_time
        if self.end_time:
            result["end_time"] = self.end_time
        return result

265 

266 

class CeleryTaskModel(TaskModel):
    """Celery task model - task variant executed directly by a Celery worker"""

    __tablename__ = 'CeleryTask'
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Task.id', ondelete="CASCADE"), primary_key=True)
    celery_id: Column[str] = db.Column(db.String, nullable=False, default="", doc="Celery task ID")

    def update_state(self, update_dict: dict):
        """Stores a changed Celery task ID, then delegates to the generic task update"""
        if "celery_id" in update_dict:
            incoming_id = update_dict["celery_id"]
            if incoming_id != self.celery_id:
                self.celery_id = incoming_id
        return super().update_state(update_dict)

    __mapper_args__ = {"polymorphic_identity": PlatformType.DIRECT.value, "polymorphic_load": "inline"}

281 

282 

class BatchTaskModel(TaskModel):
    """Batch task model - task variant executed within a batch cluster job"""

    __tablename__ = 'BatchTask'
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Task.id', ondelete="CASCADE"), primary_key=True)

    # Selected when TaskModel.platform equals PlatformType.BATCH.value.
    __mapper_args__ = {"polymorphic_identity": PlatformType.BATCH.value, "polymorphic_load": "inline"}

290 

291 

def decompress(data: bytes):
    """
    Decompresses gzip-compressed bytes and deserializes the contained JSON.
    A None input yields None (equivalent to parsing the JSON literal 'null').
    """
    if data is None:
        return None
    # gunzip first, then decode the UTF-8 text and parse it as JSON
    return json.loads(gzip.decompress(data).decode('utf-8'))

301 

302 

def compress(data) -> bytes:
    """
    Serializes a JSON-compatible object and gzip-compresses the result.
    A None input yields empty bytes.
    """
    if data is None:
        return b''
    # serialize to JSON text, encode as UTF-8, then gzip
    return gzip.compress(json.dumps(data).encode('utf-8'))

313 

314 

class InputModel(db.Model):
    """Simulation inputs model - stores the gzip-compressed input JSON of a job"""

    __tablename__ = 'Input'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id', ondelete="CASCADE"))
    # Gzip-compressed JSON payload; read/written through the `data` property below.
    compressed_data: Column[bytes] = db.Column(db.LargeBinary)

    @property
    def data(self):
        """Decompressed, JSON-deserialized input payload (None when nothing is stored)"""
        return decompress(self.compressed_data)

    @data.setter
    def data(self, value):
        # A None value is silently ignored, so stored data cannot be cleared via this setter.
        if value is not None:
            self.compressed_data = compress(value)

331 

332 

class EstimatorModel(db.Model):
    """Simulation single estimator model"""

    __tablename__ = 'Estimator'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer,
                                           db.ForeignKey('Simulation.id', ondelete="CASCADE"),
                                           nullable=False)
    name: Column[str] = db.Column(db.String, nullable=False, doc="Human readable estimator name")
    file_name: Column[str] = db.Column(db.String,
                                       nullable=False,
                                       doc="Estimator name extracted from file generated by simulator")
    # Gzip-compressed JSON; read/written through the `data` property below.
    compressed_data: Column[bytes] = db.Column(db.LargeBinary, doc="Estimator metadata")
    # Pages are removed together with their estimator.
    pages = relationship("PageModel", cascade="delete")

    @property
    def data(self):
        """Decompressed, JSON-deserialized estimator metadata (None when nothing is stored)"""
        return decompress(self.compressed_data)

    @data.setter
    def data(self, value):
        # A None value is silently ignored, so stored data cannot be cleared via this setter.
        if value is not None:
            self.compressed_data = compress(value)

356 

357 

class PageModel(db.Model):
    """Estimator single page model"""

    __tablename__ = 'Page'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    page_name: Column[str] = db.Column(db.String, nullable=False, doc="Page name")
    estimator_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Estimator.id', ondelete="CASCADE"), nullable=False)
    page_number: Column[int] = db.Column(db.Integer, nullable=False, doc="Page number")
    # Gzip-compressed JSON; read/written through the `data` property below.
    compressed_data: Column[bytes] = db.Column(db.LargeBinary, doc="Page json object - data, axes and metadata")
    page_dimension: Column[int] = db.Column(db.Integer, nullable=False, doc="Dimension of data")

    @property
    def data(self):
        """Decompressed, JSON-deserialized page content (None when nothing is stored)"""
        return decompress(self.compressed_data)

    @data.setter
    def data(self, value):
        # A None value is silently ignored, so stored data cannot be cleared via this setter.
        if value is not None:
            self.compressed_data = compress(value)

377 

378 

class LogfilesModel(db.Model):
    """Simulation logfiles model"""

    __tablename__ = 'Logfiles'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer,
                                           db.ForeignKey('Simulation.id', ondelete="CASCADE"),
                                           nullable=False)
    # Gzip-compressed JSON; read/written through the `data` property below.
    compressed_data: Column[bytes] = db.Column(db.LargeBinary, doc="Json object containing logfiles")

    @property
    def data(self):
        """Decompressed, JSON-deserialized logfiles object (None when nothing is stored)"""
        return decompress(self.compressed_data)

    @data.setter
    def data(self, value):
        # A None value is silently ignored, so stored data cannot be cleared via this setter.
        if value is not None:
            self.compressed_data = compress(value)

397 

398 

def create_all():
    """
    Creates all tables defined by the models in this module.
    To be used within a Flask app context, since `db` is bound to the application.
    """
    db.create_all()