Coverage for yaptide/persistence/models.py: 90%

220 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-22 07:31 +0000

1# ---------- IMPORTANT ------------ 

2# Read documentation in persistency.md. It contains information about database development with flask-migrate. 

3 

4import gzip 

5import json 

6from datetime import datetime 

7 

8from sqlalchemy import Column, UniqueConstraint 

9from sqlalchemy.orm import relationship 

10from sqlalchemy.sql.functions import now 

11from werkzeug.security import check_password_hash, generate_password_hash 

12 

13from yaptide.persistence.database import db 

14from yaptide.utils.enums import EntityState, PlatformType 

15 

16 

class UserModel(db.Model):
    """Base user model; auth providers subclass it via joined-table inheritance"""

    __tablename__ = 'User'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    username: Column[str] = db.Column(db.String, nullable=False)
    # Discriminator column: stores the polymorphic identity of the concrete subclass
    auth_provider: Column[str] = db.Column(db.String, nullable=False)
    simulations = relationship("SimulationModel")

    # The same username may appear once per auth provider
    __table_args__ = (UniqueConstraint('username', 'auth_provider', name='_username_provider_uc'), )

    __mapper_args__ = {"polymorphic_identity": "User", "polymorphic_on": auth_provider, "with_polymorphic": "*"}

    def __repr__(self) -> str:
        """Short human-readable description used in logs and the REPL"""
        return f'User #{self.id} {self.username}'

32 

33 

class YaptideUserModel(UserModel, db.Model):
    """Yaptide user model -- user authenticated with a locally stored password hash"""

    __tablename__ = 'YaptideUser'
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('User.id', ondelete="CASCADE"), primary_key=True)
    # Only the salted hash is persisted, never the plaintext password
    password_hash: Column[str] = db.Column(db.String, nullable=False)

    __mapper_args__ = {"polymorphic_identity": "YaptideUser", "polymorphic_load": "inline"}

    def set_password(self, password: str):
        """Sets hashed password (a salted hash of the given plaintext)"""
        self.password_hash = generate_password_hash(password)

    def check_password(self, password: str) -> bool:
        """Checks password correctness against the stored hash"""
        return check_password_hash(self.password_hash, password)

50 

51 

class KeycloakUserModel(UserModel, db.Model):
    """PLGrid user model -- user authenticated via Keycloak"""

    __tablename__ = 'KeycloakUser'
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('User.id', ondelete="CASCADE"), primary_key=True)
    # Certificate and private key (presumably for cluster access -- TODO confirm against callers);
    # both optional, so a Keycloak user can exist before credentials are provisioned
    cert: Column[str] = db.Column(db.String, nullable=True)
    private_key: Column[str] = db.Column(db.String, nullable=True)

    __mapper_args__ = {"polymorphic_identity": "KeycloakUser", "polymorphic_load": "inline"}

61 

62 

class ClusterModel(db.Model):
    """Cluster info for specific user"""

    __tablename__ = 'Cluster'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    cluster_name: Column[str] = db.Column(db.String, nullable=False)
    # Batch simulations submitted to this cluster
    simulations = relationship("BatchSimulationModel")

70 

71 

class SimulationModel(db.Model):
    """Simulation model; execution platforms subclass it via joined-table inheritance"""

    __tablename__ = 'Simulation'

    id: Column[int] = db.Column(db.Integer, primary_key=True)

    job_id: Column[str] = db.Column(db.String, nullable=False, unique=True, doc="Simulation job ID")

    user_id: Column[int] = db.Column(db.Integer, db.ForeignKey('User.id'), doc="User ID")
    start_time: Column[datetime] = db.Column(db.DateTime(timezone=True), default=now(), doc="Submission time")
    end_time: Column[datetime] = db.Column(db.DateTime(timezone=True),
                                           nullable=True,
                                           doc="Job end time (including merging)")
    title: Column[str] = db.Column(db.String, nullable=False, doc="Job title")
    # Discriminator column for the polymorphic hierarchy (see __mapper_args__)
    platform: Column[str] = db.Column(db.String, nullable=False, doc="Execution platform name (i.e. 'direct', 'batch')")
    input_type: Column[str] = db.Column(db.String,
                                        nullable=False,
                                        doc="Input type (i.e. 'yaptide_project', 'input_files')")
    sim_type: Column[str] = db.Column(db.String,
                                      nullable=False,
                                      doc="Simulator type (i.e. 'shieldhit', 'topas', 'fluka')")
    job_state: Column[str] = db.Column(db.String,
                                       nullable=False,
                                       default=EntityState.UNKNOWN.value,
                                       doc="Simulation state (i.e. 'pending', 'running', 'completed', 'failed')")

    # Child rows are deleted together with the simulation
    tasks = relationship("TaskModel", cascade="delete")
    estimators = relationship("EstimatorModel", cascade="delete")
    inputs = relationship("InputModel", cascade="delete")
    logfiles = relationship("LogfilesModel", cascade="delete")

    __mapper_args__ = {"polymorphic_identity": "Simulation", "polymorphic_on": platform, "with_polymorphic": "*"}

    def update_state(self, update_dict: dict) -> bool:
        """
        Updating database is more costly than a simple query.
        Therefore we check first if update is needed and
        perform it only for such fields which exists and which have updated values.
        Returns bool value telling if it is required to commit changes to db.
        """
        # Terminal states are never left again -- nothing to update
        if self.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value):
            return False
        db_commit_required = False
        if "job_state" in update_dict and self.job_state != update_dict["job_state"]:
            self.job_state = update_dict["job_state"]
            db_commit_required = True
        # Here we have a special case, `end_time` can be set only once
        # therefore we update it only if it not set previously (`self.end_time is None`)
        # and if update was requested (`"end_time" in update_dict`)
        if "end_time" in update_dict and self.end_time is None:
            # a conversion from string to datetime is needed, as in the POST payload end_time comes in string format
            self.end_time = datetime.strptime(update_dict["end_time"], '%Y-%m-%d %H:%M:%S.%f')
            db_commit_required = True
        return db_commit_required

127 

128 

class CelerySimulationModel(SimulationModel):
    """Celery simulation model -- simulation executed directly via Celery workers"""

    __tablename__ = 'CelerySimulation'
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id', ondelete="CASCADE"), primary_key=True)
    merge_id: Column[str] = db.Column(db.String, nullable=True, doc="Celery collect job ID")

    __mapper_args__ = {"polymorphic_identity": PlatformType.DIRECT.value, "polymorphic_load": "inline"}

137 

138 

class BatchSimulationModel(SimulationModel):
    """Batch simulation model -- simulation submitted to a cluster batch system"""

    __tablename__ = 'BatchSimulation'
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id', ondelete="CASCADE"), primary_key=True)
    cluster_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Cluster.id'), nullable=False, doc="Cluster ID")
    job_dir: Column[str] = db.Column(db.String, nullable=True, doc="Simulation folder name")
    # fixed typo in doc string: "jon" -> "job"
    array_id: Column[int] = db.Column(db.Integer, nullable=True, doc="Batch array job ID")
    collect_id: Column[int] = db.Column(db.Integer, nullable=True, doc="Batch collect job ID")

    __mapper_args__ = {"polymorphic_identity": PlatformType.BATCH.value, "polymorphic_load": "inline"}

    def update_state(self, update_dict: dict) -> bool:
        """
        Used to update fields in BatchSimulation.
        Returns boolean value telling if commit to database is required.
        """
        db_commit_required = super().update_state(update_dict)
        # Batch-specific bookkeeping fields: copy each one only when it is present
        # in the payload and actually differs from the stored value
        for field in ("job_dir", "array_id", "collect_id"):
            if field in update_dict and getattr(self, field) != update_dict[field]:
                setattr(self, field, update_dict[field])
                db_commit_required = True
        return db_commit_required

164 

165 

def allowed_state_change(current_state: str, next_state: str) -> bool:
    """
    Ensures that no such change like Completed -> Canceled happens.

    Returns True when moving from ``current_state`` to ``next_state`` is allowed,
    i.e. False only for a Failed/Completed -> Canceled transition.
    """
    # BUGFIX: the comparison must be against EntityState.CANCELED.value (a string).
    # `next_state` arrives as a plain string, so the previous membership test
    # against the raw enum member `EntityState.CANCELED` could never match.
    return not (current_state in (EntityState.FAILED.value, EntityState.COMPLETED.value)
                and next_state == EntityState.CANCELED.value)

170 

171 

def value_changed(current_value, new_value) -> bool:
    """
    Checks if value from update_dict differs from object in database.

    A falsy ``new_value`` (None, "", 0) is treated as "no update requested".
    Returns an explicit bool (the previous version leaked the falsy ``new_value``
    itself, e.g. None, when no update was requested). Parameters are intentionally
    unannotated: callers pass both strings and ints.
    """
    return bool(new_value) and current_value != new_value

175 

176 

class TaskModel(db.Model):
    """Simulation task model; execution platforms subclass it via joined-table inheritance"""

    __tablename__ = 'Task'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer,
                                           db.ForeignKey('Simulation.id', ondelete="CASCADE"),
                                           doc="Simulation job ID (foreign key)")

    task_id: Column[int] = db.Column(db.Integer, nullable=False, doc="Task ID")
    requested_primaries: Column[int] = db.Column(db.Integer,
                                                 nullable=False,
                                                 default=0,
                                                 doc="Requested number of primaries")
    simulated_primaries: Column[int] = db.Column(db.Integer,
                                                 nullable=False,
                                                 default=0,
                                                 doc="Simulated number of primaries")
    task_state: Column[str] = db.Column(db.String,
                                        nullable=False,
                                        default=EntityState.PENDING.value,
                                        doc="Task state (i.e. 'pending', 'running', 'completed', 'failed')")
    estimated_time: Column[int] = db.Column(db.Integer, nullable=True, doc="Estimated time in seconds")
    start_time: Column[datetime] = db.Column(db.DateTime(timezone=True), nullable=True, doc="Task start time")
    end_time: Column[datetime] = db.Column(db.DateTime(timezone=True), nullable=True, doc="Task end time")
    # Discriminator column for the polymorphic hierarchy (see __mapper_args__)
    platform: Column[str] = db.Column(db.String, nullable=False, doc="Execution platform name (i.e. 'direct', 'batch')")
    last_update_time: Column[datetime] = db.Column(db.DateTime(timezone=True),
                                                   default=now(),
                                                   doc="Task last update time")

    # A task number is unique within a single simulation
    __table_args__ = (UniqueConstraint('simulation_id', 'task_id', name='_simulation_id_task_id_uc'), )

    __mapper_args__ = {"polymorphic_identity": "Task", "polymorphic_on": platform, "with_polymorphic": "*"}

    def update_state(self, update_dict: dict):  # skipcq: PY-R1000
        """
        Updating database is more costly than a simple query.
        Therefore we check first if update is needed and
        perform it only for such fields which exists and which have updated values.
        """
        # Terminal states are never left again -- nothing to update
        if self.task_state in (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value):
            return
        if value_changed(self.requested_primaries, update_dict.get("requested_primaries")):
            self.requested_primaries = update_dict["requested_primaries"]
        if value_changed(self.simulated_primaries, update_dict.get("simulated_primaries")):
            self.simulated_primaries = update_dict["simulated_primaries"]
        # State changes additionally pass through allowed_state_change (no leaving terminal states)
        if value_changed(self.task_state, update_dict.get("task_state")) and allowed_state_change(
                self.task_state, update_dict["task_state"]):
            self.task_state = update_dict["task_state"]
            if self.task_state == EntityState.COMPLETED.value:
                # a completed task is considered to have simulated everything that was requested
                self.simulated_primaries = self.requested_primaries
        # Here we have a special case, `estimated_time` cannot be set when `end_time` is set - it is meaningless
        have_estim_time = "estimated_time" in update_dict and self.estimated_time != update_dict["estimated_time"]
        end_time_not_set = self.end_time is None
        if have_estim_time and end_time_not_set:
            self.estimated_time = update_dict["estimated_time"]
        if "start_time" in update_dict and self.start_time is None:
            # a conversion from string to datetime is needed, as in the POST payload start_time comes in string format
            self.start_time = datetime.strptime(update_dict["start_time"], '%Y-%m-%d %H:%M:%S.%f')
        # Here we have a special case, `end_time` can be set only once
        # therefore we update it only if it not set previously (`self.end_time is None`)
        # and if update was requested (`"end_time" in update_dict`)
        if "end_time" in update_dict and self.end_time is None:
            # a conversion from string to datetime is needed, as in the POST payload end_time comes in string format
            self.end_time = datetime.strptime(update_dict["end_time"], '%Y-%m-%d %H:%M:%S.%f')
            # once finished, the estimate is meaningless -- drop it
            self.estimated_time = None
        self.last_update_time = now()

    def get_status_dict(self) -> dict:
        """Returns task information as a dictionary; optional fields are included only when set"""
        result = {
            "task_state": self.task_state,
            "requested_primaries": self.requested_primaries,
            "simulated_primaries": self.simulated_primaries,
            "last_update_time": self.last_update_time,
        }
        if self.estimated_time:
            # split the estimate (stored as seconds) into h/m/s for presentation
            result["estimated_time"] = {
                "hours": self.estimated_time // 3600,
                "minutes": (self.estimated_time // 60) % 60,
                "seconds": self.estimated_time % 60,
            }
        if self.start_time:
            result["start_time"] = self.start_time
        if self.end_time:
            result["end_time"] = self.end_time
        return result

264 

265 

class CeleryTaskModel(TaskModel):
    """Celery task model -- task executed directly via a Celery worker"""

    __tablename__ = 'CeleryTask'
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Task.id', ondelete="CASCADE"), primary_key=True)
    celery_id: Column[str] = db.Column(db.String, nullable=False, default="", doc="Celery task ID")

    def update_state(self, update_dict: dict):
        """Updates celery_id when changed, then delegates the remaining fields to TaskModel.update_state"""
        if "celery_id" in update_dict and self.celery_id != update_dict["celery_id"]:
            self.celery_id = update_dict["celery_id"]
        return super().update_state(update_dict)

    __mapper_args__ = {"polymorphic_identity": PlatformType.DIRECT.value, "polymorphic_load": "inline"}

280 

281 

class BatchTaskModel(TaskModel):
    """Batch task model -- task executed via a cluster batch system; adds no extra fields"""

    __tablename__ = 'BatchTask'
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Task.id', ondelete="CASCADE"), primary_key=True)

    __mapper_args__ = {"polymorphic_identity": PlatformType.BATCH.value, "polymorphic_load": "inline"}

289 

290 

def decompress(data: bytes):
    """Gunzips the given blob and deserializes it from JSON; a missing blob yields None"""
    if data is None:
        # mirrors json.loads('null'): no stored payload maps to None
        return None
    # Decompress, decode to text, then parse the JSON document
    return json.loads(gzip.decompress(data).decode('utf-8'))

300 

301 

def compress(data) -> bytes:
    """Serializes the given object to JSON and gzips it; None maps to an empty byte string"""
    if data is None:
        return b''
    # Serialize to JSON text, encode to bytes, then gzip
    serialized: bytes = json.dumps(data).encode('utf-8')
    return gzip.compress(serialized)

312 

313 

class InputModel(db.Model):
    """Simulation inputs model; the JSON payload is stored gzip-compressed"""

    __tablename__ = 'Input'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id', ondelete="CASCADE"))
    compressed_data: Column[bytes] = db.Column(db.LargeBinary)

    @property
    def data(self):
        """Decompressed JSON payload (decoded on every access)"""
        return decompress(self.compressed_data)

    @data.setter
    def data(self, value):
        # None is ignored: assigning data = None leaves compressed_data untouched
        if value is not None:
            self.compressed_data = compress(value)

330 

331 

class EstimatorModel(db.Model):
    """Simulation single estimator model; metadata is stored gzip-compressed"""

    __tablename__ = 'Estimator'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer,
                                           db.ForeignKey('Simulation.id', ondelete="CASCADE"),
                                           nullable=False)
    name: Column[str] = db.Column(db.String, nullable=False, doc="Estimator name")
    compressed_data: Column[bytes] = db.Column(db.LargeBinary, doc="Estimator metadata")
    # Pages are deleted together with their estimator
    pages = relationship("PageModel", cascade="delete")

    @property
    def data(self):
        """Decompressed JSON metadata (decoded on every access)"""
        return decompress(self.compressed_data)

    @data.setter
    def data(self, value):
        # None is ignored: assigning data = None leaves compressed_data untouched
        if value is not None:
            self.compressed_data = compress(value)

352 

353 

class PageModel(db.Model):
    """Estimator single page model; the page JSON is stored gzip-compressed"""

    __tablename__ = 'Page'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    estimator_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Estimator.id', ondelete="CASCADE"), nullable=False)
    page_number: Column[int] = db.Column(db.Integer, nullable=False, doc="Page number")
    compressed_data: Column[bytes] = db.Column(db.LargeBinary, doc="Page json object - data, axes and metadata")

    @property
    def data(self):
        """Decompressed JSON page object (decoded on every access)"""
        return decompress(self.compressed_data)

    @data.setter
    def data(self, value):
        # None is ignored: assigning data = None leaves compressed_data untouched
        if value is not None:
            self.compressed_data = compress(value)

371 

372 

class LogfilesModel(db.Model):
    """Simulation logfiles model; the logfiles JSON is stored gzip-compressed"""

    __tablename__ = 'Logfiles'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer,
                                           db.ForeignKey('Simulation.id', ondelete="CASCADE"),
                                           nullable=False)
    compressed_data: Column[bytes] = db.Column(db.LargeBinary, doc="Json object containing logfiles")

    @property
    def data(self):
        """Decompressed JSON logfiles object (decoded on every access)"""
        return decompress(self.compressed_data)

    @data.setter
    def data(self, value):
        # None is ignored: assigning data = None leaves compressed_data untouched
        if value is not None:
            self.compressed_data = compress(value)

391 

392 

def create_all():
    """Creates all tables, to be used with Flask app context."""
    # Delegates to Flask-SQLAlchemy; requires an active application context
    db.create_all()