Coverage for yaptide/persistence/models.py: 94%

204 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-07-01 12:55 +0000

1import gzip 

2import json 

3from datetime import datetime 

4 

5from sqlalchemy import Column, UniqueConstraint 

6from sqlalchemy.orm import relationship 

7from sqlalchemy.sql.functions import now 

8from werkzeug.security import check_password_hash, generate_password_hash 

9 

10from yaptide.persistence.database import db 

11from yaptide.utils.enums import EntityState, PlatformType 

12 

13 

class UserModel(db.Model):
    """Base user record shared by every authentication provider.

    Concrete user kinds are subclasses told apart by ``auth_provider``, which
    is also the polymorphic discriminator. A username may be registered once
    per provider (see ``_username_provider_uc``).
    """

    __tablename__ = 'User'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    username: Column[str] = db.Column(db.String, nullable=False)
    # discriminator column: stores the polymorphic identity of the concrete class
    auth_provider: Column[str] = db.Column(db.String, nullable=False)
    simulations = relationship("SimulationModel")

    # the same username may exist once for each auth provider
    __table_args__ = (UniqueConstraint('username', 'auth_provider', name='_username_provider_uc'), )

    __mapper_args__ = {"polymorphic_on": auth_provider, "polymorphic_identity": "User", "with_polymorphic": "*"}

    def __repr__(self) -> str:
        return 'User #{} {}'.format(self.id, self.username)

29 

30 

class YaptideUserModel(UserModel, db.Model):
    """User authenticated with a locally stored password hash.

    Joined-table subclass of ``UserModel``. Its ``YaptideUser`` row shares the
    primary key of the parent ``User`` row, and the foreign key is declared
    with ``ondelete="CASCADE"``.
    """

    __tablename__ = 'YaptideUser'
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('User.id', ondelete="CASCADE"), primary_key=True)
    password_hash: Column[str] = db.Column(db.String, nullable=False)

    __mapper_args__ = {"polymorphic_identity": "YaptideUser", "polymorphic_load": "inline"}

    def set_password(self, password: str):
        """Store werkzeug's ``generate_password_hash`` of ``password``."""
        hashed = generate_password_hash(password)
        self.password_hash = hashed

    def check_password(self, password: str) -> bool:
        """Return True when ``password`` matches the stored hash."""
        stored_hash = self.password_hash
        return check_password_hash(stored_hash, password)

47 

48 

class KeycloakUserModel(UserModel, db.Model):
    """PLGrid user model (users authenticated through Keycloak).

    Joined-table subclass of ``UserModel`` with polymorphic identity
    ``"KeycloakUser"``. The foreign key to the parent ``User`` row is
    declared with ``ondelete="CASCADE"``.
    """

    __tablename__ = 'KeycloakUser'
    # shares the primary key with the parent ``User`` row
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('User.id', ondelete="CASCADE"), primary_key=True)
    # optional credential strings; presumably used to act on the user's behalf
    # (e.g. batch job submission) -- NOTE(review): confirm against callers
    cert: Column[str] = db.Column(db.String, nullable=True)
    private_key: Column[str] = db.Column(db.String, nullable=True)

    # "inline": this subclass is included in the base mapper's polymorphic loading
    __mapper_args__ = {"polymorphic_identity": "KeycloakUser", "polymorphic_load": "inline"}

58 

59 

class ClusterModel(db.Model):
    """Cluster info for specific user

    Stores a named compute cluster. ``simulations`` lists the
    ``BatchSimulationModel`` rows whose ``cluster_id`` references this row.
    NOTE(review): this table has no user column, so the association with a
    specific user presumably lives elsewhere -- confirm.
    """

    __tablename__ = 'Cluster'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    cluster_name: Column[str] = db.Column(db.String, nullable=False)
    # resolved through BatchSimulation.cluster_id -> Cluster.id
    simulations = relationship("BatchSimulationModel")

67 

68 

class SimulationModel(db.Model):
    """Simulation model

    Base row for every simulation job. ``platform`` is the polymorphic
    discriminator that selects the concrete subclass (see its column doc).
    """

    __tablename__ = 'Simulation'

    id: Column[int] = db.Column(db.Integer, primary_key=True)

    job_id: Column[str] = db.Column(db.String, nullable=False, unique=True, doc="Simulation job ID")

    user_id: Column[int] = db.Column(db.Integer, db.ForeignKey('User.id'), doc="User ID")
    start_time: Column[datetime] = db.Column(db.DateTime(timezone=True), default=now(), doc="Submission time")
    end_time: Column[datetime] = db.Column(db.DateTime(timezone=True),
                                           nullable=True,
                                           doc="Job end time (including merging)")
    title: Column[str] = db.Column(db.String, nullable=False, doc="Job title")
    platform: Column[str] = db.Column(db.String, nullable=False, doc="Execution platform name (i.e. 'direct', 'batch')")
    input_type: Column[str] = db.Column(db.String,
                                        nullable=False,
                                        doc="Input type (i.e. 'yaptide_project', 'input_files')")
    sim_type: Column[str] = db.Column(db.String,
                                      nullable=False,
                                      doc="Simulator type (i.e. 'shieldhit', 'topas', 'fluka')")
    job_state: Column[str] = db.Column(db.String,
                                       nullable=False,
                                       default=EntityState.UNKNOWN.value,
                                       doc="Simulation state (i.e. 'pending', 'running', 'completed', 'failed')")
    update_key_hash: Column[str] = db.Column(db.String,
                                             doc="Update key shared by tasks granting access to update themselves")
    tasks = relationship("TaskModel")
    estimators = relationship("EstimatorModel")

    __mapper_args__ = {"polymorphic_identity": "Simulation", "polymorphic_on": platform, "with_polymorphic": "*"}

    @staticmethod
    def _parse_timestamp(value: str) -> datetime:
        """Parse a timestamp string received in an update payload.

        Payloads use ``'%Y-%m-%d %H:%M:%S.%f'``. However, ``str(datetime)`` and
        ``datetime.isoformat(' ')`` drop the fractional part when microseconds
        are zero, so the variant without ``.%f`` is accepted as well.

        Raises:
            ValueError: if ``value`` matches neither format.
        """
        try:
            return datetime.strptime(value, '%Y-%m-%d %H:%M:%S.%f')
        except ValueError:
            # fall back to the whole-second form; re-raises ValueError if this fails too
            return datetime.strptime(value, '%Y-%m-%d %H:%M:%S')

    def set_update_key(self, update_key: str):
        """Sets hashed update key"""
        self.update_key_hash = generate_password_hash(update_key)

    def check_update_key(self, update_key: str) -> bool:
        """Checks update key correctness"""
        return check_password_hash(self.update_key_hash, update_key)

    def update_state(self, update_dict: dict) -> bool:
        """Apply a status update, writing only fields that actually change.

        Updating the database is more costly than a simple query. Each field is
        therefore assigned only when it is present in ``update_dict`` and
        differs from the stored value.

        Args:
            update_dict: payload that may contain ``"job_state"`` and
                ``"end_time"`` (a ``'%Y-%m-%d %H:%M:%S[.%f]'`` string).

        Returns:
            True when an attribute was modified and the session needs a commit.
            Always False once the simulation is completed, failed or canceled.
        """
        if self.job_state in (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value):
            return False
        db_commit_required = False
        if "job_state" in update_dict and self.job_state != update_dict["job_state"]:
            self.job_state = update_dict["job_state"]
            db_commit_required = True
        # `end_time` can be set only once: update it only if it was not set
        # previously (`self.end_time is None`) and an update was requested
        if "end_time" in update_dict and self.end_time is None:
            # the POST payload carries end_time as a string, so convert it to datetime
            self.end_time = self._parse_timestamp(update_dict["end_time"])
            db_commit_required = True
        return db_commit_required

131 

132 

class CelerySimulationModel(SimulationModel):
    """Celery simulation model

    Joined-table subclass of ``SimulationModel`` mapped to rows whose
    ``platform`` equals ``PlatformType.DIRECT.value``. The foreign key to the
    parent ``Simulation`` row is declared with ``ondelete="CASCADE"``.
    """

    __tablename__ = 'CelerySimulation'
    # shares the primary key with the parent ``Simulation`` row
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id', ondelete="CASCADE"), primary_key=True)
    # nullable: may be unset until a collect job exists
    merge_id: Column[str] = db.Column(db.String, nullable=True, doc="Celery collect job ID")

    __mapper_args__ = {"polymorphic_identity": PlatformType.DIRECT.value, "polymorphic_load": "inline"}

141 

142 

class BatchSimulationModel(SimulationModel):
    """Batch simulation model

    Joined-table subclass of ``SimulationModel`` mapped to rows whose
    ``platform`` equals ``PlatformType.BATCH.value``. Every row references the
    ``Cluster`` it was submitted to.
    """

    __tablename__ = 'BatchSimulation'
    # shares the primary key with the parent ``Simulation`` row
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id', ondelete="CASCADE"), primary_key=True)
    cluster_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Cluster.id'), nullable=False, doc="Cluster ID")
    job_dir: Column[str] = db.Column(db.String, nullable=True, doc="Simulation folder name")
    array_id: Column[int] = db.Column(db.Integer, nullable=True, doc="Batch array job ID")
    collect_id: Column[int] = db.Column(db.Integer, nullable=True, doc="Batch collect job ID")

    __mapper_args__ = {"polymorphic_identity": PlatformType.BATCH.value, "polymorphic_load": "inline"}

154 

155 

class TaskModel(db.Model):
    """Simulation task model

    One unit of work belonging to a simulation; the pair
    ``(simulation_id, task_id)`` is unique. ``platform`` is the polymorphic
    discriminator that selects the concrete subclass.
    """

    __tablename__ = 'Task'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer,
                                           db.ForeignKey('Simulation.id'),
                                           doc="Simulation job ID (foreign key)")

    task_id: Column[int] = db.Column(db.Integer, nullable=False, doc="Task ID")
    requested_primaries: Column[int] = db.Column(db.Integer,
                                                 nullable=False,
                                                 default=0,
                                                 doc="Requested number of primaries")
    simulated_primaries: Column[int] = db.Column(db.Integer,
                                                 nullable=False,
                                                 default=0,
                                                 doc="Simulated number of primaries")
    task_state: Column[str] = db.Column(db.String,
                                        nullable=False,
                                        default=EntityState.PENDING.value,
                                        doc="Task state (i.e. 'pending', 'running', 'completed', 'failed')")
    estimated_time: Column[int] = db.Column(db.Integer, nullable=True, doc="Estimated time in seconds")
    start_time: Column[datetime] = db.Column(db.DateTime(timezone=True), nullable=True, doc="Task start time")
    end_time: Column[datetime] = db.Column(db.DateTime(timezone=True), nullable=True, doc="Task end time")
    platform: Column[str] = db.Column(db.String, nullable=False, doc="Execution platform name (i.e. 'direct', 'batch')")
    last_update_time: Column[datetime] = db.Column(db.DateTime(timezone=True),
                                                   default=now(),
                                                   doc="Task last update time")

    __table_args__ = (UniqueConstraint('simulation_id', 'task_id', name='_simulation_id_task_id_uc'), )

    __mapper_args__ = {"polymorphic_identity": "Task", "polymorphic_on": platform, "with_polymorphic": "*"}

    @staticmethod
    def _parse_timestamp(value: str) -> datetime:
        """Parse a timestamp string received in an update payload.

        Payloads use ``'%Y-%m-%d %H:%M:%S.%f'``. However, ``str(datetime)`` and
        ``datetime.isoformat(' ')`` drop the fractional part when microseconds
        are zero, so the variant without ``.%f`` is accepted as well.

        Raises:
            ValueError: if ``value`` matches neither format.
        """
        try:
            return datetime.strptime(value, '%Y-%m-%d %H:%M:%S.%f')
        except ValueError:
            # fall back to the whole-second form; re-raises ValueError if this fails too
            return datetime.strptime(value, '%Y-%m-%d %H:%M:%S')

    def update_state(self, update_dict: dict):
        """Apply a status update, writing only fields that actually change.

        Updating the database is more costly than a simple query. Each field is
        therefore assigned only when it is present in ``update_dict`` and
        differs from the stored value. Does nothing once the task is
        completed, failed or canceled. Otherwise ``last_update_time`` is always
        refreshed.

        Args:
            update_dict: payload that may contain ``"requested_primaries"``,
                ``"simulated_primaries"``, ``"task_state"``, ``"estimated_time"``,
                and ``"start_time"`` / ``"end_time"`` (``'%Y-%m-%d %H:%M:%S[.%f]'``
                strings).
        """
        if self.task_state in (EntityState.COMPLETED.value, EntityState.FAILED.value, EntityState.CANCELED.value):
            return
        if "requested_primaries" in update_dict and self.requested_primaries != update_dict["requested_primaries"]:
            self.requested_primaries = update_dict["requested_primaries"]
        if "simulated_primaries" in update_dict and self.simulated_primaries != update_dict["simulated_primaries"]:
            self.simulated_primaries = update_dict["simulated_primaries"]
        if "task_state" in update_dict and self.task_state != update_dict["task_state"]:
            self.task_state = update_dict["task_state"]
        # `estimated_time` is meaningless once `end_time` is set, so it is not updated then
        have_estim_time = "estimated_time" in update_dict and self.estimated_time != update_dict["estimated_time"]
        end_time_not_set = self.end_time is None
        if have_estim_time and end_time_not_set:
            self.estimated_time = update_dict["estimated_time"]
        if "start_time" in update_dict and self.start_time is None:
            # the POST payload carries start_time as a string, so convert it to datetime
            self.start_time = self._parse_timestamp(update_dict["start_time"])
        # `end_time` can be set only once: update it only if it was not set
        # previously (`self.end_time is None`) and an update was requested
        if "end_time" in update_dict and self.end_time is None:
            # the POST payload carries end_time as a string, so convert it to datetime
            self.end_time = self._parse_timestamp(update_dict["end_time"])
            # a finished task has no remaining-time estimate
            self.estimated_time = None
        # SQL now() expression, evaluated by the database at flush time
        self.last_update_time = now()

    def get_status_dict(self) -> dict:
        """Return task information as a dictionary.

        Always contains ``task_state``, ``requested_primaries``,
        ``simulated_primaries`` and ``last_update_time``. ``estimated_time`` is
        included, split into hours/minutes/seconds, only when it is truthy (a
        value of 0 is omitted). ``start_time`` and ``end_time`` are included
        only when set.
        """
        result = {
            "task_state": self.task_state,
            "requested_primaries": self.requested_primaries,
            "simulated_primaries": self.simulated_primaries,
            "last_update_time": self.last_update_time,
        }
        if self.estimated_time:
            result["estimated_time"] = {
                "hours": self.estimated_time // 3600,
                "minutes": (self.estimated_time // 60) % 60,
                "seconds": self.estimated_time % 60,
            }
        if self.start_time:
            result["start_time"] = self.start_time
        if self.end_time:
            result["end_time"] = self.end_time
        return result

240 

241 

class CeleryTaskModel(TaskModel):
    """Task executed through Celery (``PlatformType.DIRECT``).

    Adds the Celery task identifier on top of the generic ``TaskModel`` columns.
    """

    __tablename__ = 'CeleryTask'
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Task.id', ondelete="CASCADE"), primary_key=True)
    celery_id: Column[str] = db.Column(db.String, nullable=False, default="", doc="Celery task ID")

    __mapper_args__ = {"polymorphic_identity": PlatformType.DIRECT.value, "polymorphic_load": "inline"}

    def update_state(self, update_dict: dict):
        """Record a newly supplied ``celery_id``, then run the generic task update.

        The Celery id is refreshed before the base-class terminal-state check,
        so it changes even for tasks that are already finished.
        """
        if "celery_id" in update_dict:
            incoming_id = update_dict["celery_id"]
            if incoming_id != self.celery_id:
                self.celery_id = incoming_id
        return super().update_state(update_dict)

256 

257 

class BatchTaskModel(TaskModel):
    """Batch task model

    Joined-table subclass of ``TaskModel`` mapped to rows whose ``platform``
    equals ``PlatformType.BATCH.value``. It adds no columns of its own.
    """

    __tablename__ = 'BatchTask'
    # shares the primary key with the parent ``Task`` row
    id: Column[int] = db.Column(db.Integer, db.ForeignKey('Task.id', ondelete="CASCADE"), primary_key=True)

    __mapper_args__ = {"polymorphic_identity": PlatformType.BATCH.value, "polymorphic_load": "inline"}

265 

266 

def decompress(data: bytes):
    """Decompress gzip data and deserialize the JSON it contains.

    Inverse of ``compress``. Both ``None`` and empty bytes (which is what
    ``compress(None)`` returns) decode to ``None`` instead of failing.

    Args:
        data: gzip-compressed UTF-8 JSON, or ``None`` / ``b''``.

    Returns:
        The deserialized JSON value, or ``None`` when there is no data.
    """
    if not data:
        # nothing stored; equivalent to json.loads('null')
        return None
    decompressed_bytes: bytes = gzip.decompress(data)
    return json.loads(decompressed_bytes.decode('utf-8'))

276 

277 

def compress(data) -> bytes:
    """Serialize ``data`` to JSON and gzip its UTF-8 encoding.

    ``None`` short-circuits to empty bytes instead of compressing ``null``.
    """
    if data is None:
        return b''
    encoded: bytes = json.dumps(data).encode('utf-8')
    return gzip.compress(encoded)

288 

289 

class InputModel(db.Model):
    """Simulation input, stored as gzip-compressed JSON."""

    __tablename__ = 'Input'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id'))
    compressed_data: Column[bytes] = db.Column(db.LargeBinary)

    @property
    def data(self):
        """Decoded input payload (None when ``compressed_data`` is NULL)."""
        raw = self.compressed_data
        return decompress(raw)

    @data.setter
    def data(self, value):
        # assigning None is ignored rather than clearing the stored payload
        if value is None:
            return
        self.compressed_data = compress(value)

306 

307 

class EstimatorModel(db.Model):
    """Single estimator of a simulation; its metadata is gzip-compressed JSON."""

    __tablename__ = 'Estimator'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id'), nullable=False)
    name: Column[str] = db.Column(db.String, nullable=False, doc="Estimator name")
    compressed_data: Column[bytes] = db.Column(db.LargeBinary, doc="Estimator metadata")

    @property
    def data(self):
        """Decoded estimator metadata (None when ``compressed_data`` is NULL)."""
        raw = self.compressed_data
        return decompress(raw)

    @data.setter
    def data(self, value):
        # assigning None is ignored rather than clearing the stored metadata
        if value is None:
            return
        self.compressed_data = compress(value)

325 

326 

class PageModel(db.Model):
    """One numbered page of an estimator, stored as gzip-compressed JSON."""

    __tablename__ = 'Page'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    estimator_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Estimator.id'), nullable=False)
    page_number: Column[int] = db.Column(db.Integer, nullable=False, doc="Page number")
    compressed_data: Column[bytes] = db.Column(db.LargeBinary, doc="Page json object - data, axes and metadata")

    @property
    def data(self):
        """Decoded page object (None when ``compressed_data`` is NULL)."""
        raw = self.compressed_data
        return decompress(raw)

    @data.setter
    def data(self, value):
        # assigning None is ignored rather than clearing the stored page
        if value is None:
            return
        self.compressed_data = compress(value)

344 

345 

class LogfilesModel(db.Model):
    """Logfiles of a simulation, kept as one gzip-compressed JSON object."""

    __tablename__ = 'Logfiles'
    id: Column[int] = db.Column(db.Integer, primary_key=True)
    simulation_id: Column[int] = db.Column(db.Integer, db.ForeignKey('Simulation.id'), nullable=False)
    compressed_data: Column[bytes] = db.Column(db.LargeBinary, doc="Json object containing logfiles")

    @property
    def data(self):
        """Decoded logfiles object (None when ``compressed_data`` is NULL)."""
        raw = self.compressed_data
        return decompress(raw)

    @data.setter
    def data(self, value):
        # assigning None is ignored rather than clearing the stored logfiles
        if value is None:
            return
        self.compressed_data = compress(value)

362 

363 

def create_all():
    """Creates all tables, to be used with Flask app context.

    Thin wrapper around ``db.create_all()``; must run inside an active Flask
    application context.
    """
    db.create_all()