Coverage for yaptide/admin/simulator_storage.py: 18%

266 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-04 00:31 +0000

1import platform 

2import shutil 

3import tarfile 

4import tempfile 

5import zipfile 

6from base64 import urlsafe_b64encode 

7from enum import IntEnum, auto 

8from pathlib import Path 

9 

10import boto3 

11import click 

12import cryptography 

13import requests 

14from botocore.exceptions import (ClientError, EndpointConnectionError, NoCredentialsError) 

15from cryptography.fernet import Fernet 

16from cryptography.hazmat.primitives import hashes 

17from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC 

18 

19 

class SimulatorType(IntEnum):
    """Enumeration of the particle-transport simulators handled by this module.

    Values are consecutive integers starting at 1, in declaration order.
    """

    shieldhit = 1
    fluka = 2
    topas = 3

26 

27 

def extract_shieldhit_from_tar_gz(archive_path: Path, unpacking_directory: Path, member_name: str,
                                  destination_dir: Path):
    """Extract ``bin/<member_name>`` from a tar.gz archive and move it into ``destination_dir``.

    Every regular-file member whose base name is ``member_name`` and whose parent directory is
    ``bin`` is extracted into ``unpacking_directory`` and then moved to
    ``destination_dir / member_name`` (an existing file there is replaced by the move).

    Members with an absolute path or a ``..`` component are skipped so that a crafted archive
    cannot write outside ``unpacking_directory``. Links and directories with a matching name are
    skipped as well: only a real file can be installed by moving it out of the scratch directory.

    :param archive_path: gzip-compressed tar archive to read.
    :param unpacking_directory: scratch directory the member is extracted into first.
    :param member_name: base name of the file to install, e.g. ``shieldhit``.
    :param destination_dir: directory the file is moved into.
    """
    with tarfile.open(archive_path, "r:gz") as tar:
        for member in tar.getmembers():
            member_path = Path(member.name)
            if member_path.name != member_name or member_path.parent.name != 'bin':
                continue
            # tar member names are POSIX paths; reject anything that could escape the unpack dir
            if not member.isfile() or member.name.startswith('/') or '..' in member_path.parts:
                click.echo(f"Skipping unsafe archive member {member.name}", err=True)
                continue
            click.echo(f"Extracting {member.name}")
            tar.extract(member, unpacking_directory)
            # move to installation path
            local_file_path = unpacking_directory / member.name
            click.echo(f"Moving {local_file_path} to {destination_dir}")
            shutil.move(local_file_path, destination_dir / member_name)

41 

42 

def extract_shieldhit_from_zip(archive_path: Path, unpacking_dir: Path, member_name: str, destination_dir: Path):
    """Extract the file named ``member_name`` from a zip archive and move it into ``destination_dir``.

    All archive members are listed on stdout. Every non-directory member whose base name equals
    ``member_name`` (in any directory of the archive) is extracted into ``unpacking_dir`` and moved
    to ``destination_dir / member_name``. If that destination file already exists it is kept
    and the freshly extracted copy is left in ``unpacking_dir``.

    :param archive_path: zip archive to read.
    :param unpacking_dir: scratch directory members are extracted into first.
    :param member_name: base name of the file to install, e.g. ``shieldhit.exe``.
    :param destination_dir: directory the file is moved into.
    """
    destination_file_path = destination_dir / member_name
    with zipfile.ZipFile(archive_path) as zip_handle:
        for member in zip_handle.infolist():
            click.echo(f"Member: {member.filename}")
            if member.is_dir() or Path(member.filename).name != member_name:
                continue
            click.echo(f"Extracting {member.filename}")
            # ZipFile.extract sanitises the member path (drive letters, absolute roots, '..'),
            # so use the path it actually wrote instead of re-joining member.filename.
            local_file_path = Path(zip_handle.extract(member, unpacking_dir))
            if destination_file_path.exists():
                click.echo(f"{destination_file_path} already exists, keeping it")
                continue
            click.echo(f"Moving {local_file_path} to {destination_file_path}")
            shutil.move(local_file_path, destination_file_path)

59 

60 

def download_shieldhit_demo_version(destination_dir: Path) -> bool:
    """Download the SHIELD-HIT12A demo binary from shieldhit.org into ``destination_dir``.

    The Linux ``.tar.gz`` build is used everywhere except on Windows, where the ``.zip`` build is
    used and the binary is called ``shieldhit.exe``.

    :param destination_dir: directory the binary is installed into; created if missing.
    :return: True if the binary is present in ``destination_dir`` afterwards, False if the
        download failed (network error, timeout, non-2xx status) or the archive could not be
        extracted.
    """
    demo_version_url = 'https://shieldhit.org/download/DEMO/shield_hit12a_x86_64_demo_gfortran_v1.1.0.tar.gz'
    binary_name = 'shieldhit'
    if platform.system() == 'Windows':
        demo_version_url = 'https://shieldhit.org/download/DEMO/shield_hit12a_win64_demo_v1.1.0.zip'
        binary_name = 'shieldhit.exe'
    # seconds to wait for connecting / for each read; without it requests may block forever
    timeout_seconds = 60

    with tempfile.TemporaryDirectory() as tmpdir_name:
        tmp_dir = Path(tmpdir_name)
        click.echo(f"Downloading from {demo_version_url} to {tmpdir_name}")
        headers = {'User-Agent': 'Mozilla/5.0 (Windows NT x.y; rv:10.0) Gecko/20100101 Firefox/10.0'}
        try:
            response = requests.get(demo_version_url, headers=headers, timeout=timeout_seconds)
            # without this an HTML error page would be saved and "extracted" as the archive
            response.raise_for_status()
        except requests.RequestException as e:
            click.echo(f"Download of {demo_version_url} failed: {e}", err=True)
            return False

        temp_file_archive = tmp_dir / Path(demo_version_url).name
        temp_file_archive.write_bytes(response.content)
        click.echo(f"Saved to {temp_file_archive} with size {temp_file_archive.stat().st_size} bytes")

        click.echo(f"Extracting {temp_file_archive} to {destination_dir}")
        destination_dir.mkdir(parents=True, exist_ok=True)
        try:
            if temp_file_archive.suffix == '.gz':
                extract_shieldhit_from_tar_gz(temp_file_archive, tmp_dir, binary_name,
                                              destination_dir=destination_dir)
            else:
                extract_shieldhit_from_zip(temp_file_archive, tmp_dir, binary_name,
                                           destination_dir=destination_dir)
        except (tarfile.TarError, zipfile.BadZipFile, OSError) as e:
            click.echo(f"Extraction of {temp_file_archive} failed: {e}", err=True)
            return False

    if not (destination_dir / binary_name).exists():
        click.echo(f"{binary_name} was not found in {demo_version_url}", err=True)
        return False
    return True

93 

94 

def check_if_s3_connection_is_working(s3_client: boto3.client) -> bool:
    """Probe the S3 endpoint by listing buckets.

    :param s3_client: boto3 S3 client to test.
    :return: True if ``list_buckets`` succeeds. On missing credentials, an unreachable endpoint
        or an S3 ``ClientError`` the reason is echoed to stderr and False is returned; any other
        exception propagates.
    """
    try:
        s3_client.list_buckets()
    except NoCredentialsError as err:
        problem = f"No credentials found. Check your access key and secret key. {err}"
    except EndpointConnectionError as err:
        problem = f"Could not connect to the specified endpoint. {err}"
    except ClientError as err:
        problem = f"An error occurred while connecting to S3: {err.response['Error']['Message']}"
    else:
        return True
    click.echo(problem, err=True)
    return False

109 

110 

def download_shieldhit_from_s3(
    destination_dir: Path,
    endpoint: str,
    access_key: str,
    secret_key: str,
    password: str,
    salt: str,
    bucket: str,
    key: str,
    decrypt: bool = True,
) -> bool:
    """Fetch the SHIELD-HIT12A binary stored under ``bucket/key`` into ``destination_dir``.

    The file is saved as ``shieldhit`` (``shieldhit.exe`` on Windows) and, when ``decrypt`` is
    set, decrypted with ``password``/``salt`` via ``download_file``.

    :return: False if the S3 connection/bucket/key validation or the download fails, else True.
    """
    s3_client = boto3.client("s3",
                             endpoint_url=endpoint,
                             aws_access_key_id=access_key,
                             aws_secret_access_key=secret_key)

    if not validate_connection_data(bucket=bucket, key=key, s3_client=s3_client):
        return False

    if not destination_dir.exists():
        destination_dir.mkdir(parents=True, exist_ok=True)

    # the Windows build of the binary carries the '.exe' suffix
    binary_name = 'shieldhit.exe' if platform.system() == 'Windows' else 'shieldhit'

    downloaded = download_file(key=key,
                               bucket=bucket,
                               s3_client=s3_client,
                               decrypt=decrypt,
                               password=password,
                               salt=salt,
                               destination_file_path=destination_dir / binary_name)
    return bool(downloaded)

151 

152 

def download_shieldhit_from_s3_or_from_website(
    destination_dir: Path,
    endpoint: str,
    access_key: str,
    secret_key: str,
    password: str,
    salt: str,
    bucket: str,
    key: str,
    decrypt: bool = True,
) -> bool:
    """Install SHIELD-HIT12A from S3; fall back to the shieldhit.org demo version if that fails.

    The S3 arguments are forwarded unchanged to ``download_shieldhit_from_s3``; on failure
    ``download_shieldhit_demo_version`` is tried with the same ``destination_dir``.

    :return: True if either the S3 download or the demo download succeeded, False if both failed.
    """
    download_ok = download_shieldhit_from_s3(destination_dir=destination_dir,
                                             endpoint=endpoint,
                                             access_key=access_key,
                                             secret_key=secret_key,
                                             password=password,
                                             salt=salt,
                                             bucket=bucket,
                                             key=key,
                                             decrypt=decrypt)
    if download_ok:
        click.echo('SHIELD-HIT12A downloaded from S3')
        return True

    click.echo('SHIELD-HIT12A download failed, trying to download demo version from shieldhit.org website')
    if download_shieldhit_demo_version(destination_dir=destination_dir):
        click.echo('SHIELD-HIT12A demo version downloaded from shieldhit.org website')
        return True

    # both sources failed: report on stderr and let the caller react to the return value
    click.echo('SHIELD-HIT12A demo version download failed', err=True)
    return False

183 

184 

def _iter_s3_object_versions_with_tags(s3_client, bucket: str, key: str):
    """Yield ``(version_id, tags)`` for each stored version of exactly ``key`` in ``bucket``.

    ``tags`` is the version's tag set as a ``{tag_key: tag_value}`` dict. S3 errors propagate as
    ``ClientError``.
    """
    response = s3_client.list_object_versions(Bucket=bucket, Prefix=key)
    for object_version in response.get("Versions", []):
        # ``Prefix`` also matches longer keys (``key.bak`` ...); their version ids are not valid
        # for ``key`` itself, so only the exact key is considered.
        if object_version["Key"] != key:
            continue
        version_id = object_version["VersionId"]
        tagging = s3_client.get_object_tagging(Bucket=bucket, Key=key, VersionId=version_id)
        yield version_id, {tag["Key"]: tag["Value"] for tag in tagging["TagSet"]}


def _download_s3_object_version(s3_client, bucket: str, key: str, version_id: str, destination: Path) -> None:
    """Download one specific version of ``bucket/key`` to the local file ``destination``."""
    # the handle is closed before the archive is read back, so no bytes stay in a write buffer
    with open(destination, "wb") as file_handle:
        s3_client.download_fileobj(Bucket=bucket,
                                   Key=key,
                                   Fileobj=file_handle,
                                   ExtraArgs={"VersionId": version_id})


def _download_geant4_archives(s3_client, geant4_bucket: str, topas_version: str, target_dir: Path) -> list:
    """Download every object version in ``geant4_bucket`` tagged as compatible with ``topas_version``.

    A version qualifies when its ``topas_versions`` tag (comma-separated list) contains
    ``topas_version``. Returns the local archive paths in download order; S3 errors propagate as
    ``ClientError``.
    """
    archives = []
    listing = s3_client.list_objects_v2(Bucket=geant4_bucket)
    for s3_object in listing.get("Contents", []):
        geant4_key = s3_object["Key"]
        for version_id, tags in _iter_s3_object_versions_with_tags(s3_client, geant4_bucket, geant4_key):
            if "topas_versions" not in tags:
                continue
            supported_versions = [item.strip() for item in tags["topas_versions"].split(",")]
            if topas_version not in supported_versions:
                continue
            archive_path = target_dir / f"geant4_{len(archives)}.tar"
            click.echo(f"Downloading {geant4_key} for TOPAS version {topas_version} "
                       f"from {geant4_bucket} to {archive_path}")
            _download_s3_object_version(s3_client, geant4_bucket, geant4_key, version_id, archive_path)
            archives.append(archive_path)
    return archives


def _extract_tar_archive(archive_path: Path, destination: Path) -> None:
    """Unpack a tar archive (uncompressed, or gzip/bz2/xz compressed) into ``destination``."""
    with tarfile.open(archive_path, mode="r:*") as tar:
        tar.extractall(path=destination)


def download_topas_from_s3(download_dir: Path, endpoint: str, access_key: str, secret_key: str, bucket: str, key: str,
                           version: str, geant4_bucket: str) -> bool:
    """Download TOPAS and its Geant4 data files from S3 and unpack them into ``download_dir``.

    TOPAS is the first version of ``bucket/key`` listed by S3 whose ``version`` tag equals
    ``version``. It is unpacked into ``download_dir``, and ``download_dir/topas/bin/topas`` is
    made executable. Every object version in ``geant4_bucket`` whose ``topas_versions`` tag lists
    ``version`` is unpacked into ``download_dir/geant4_files_path``. All downloads go to a
    temporary directory that is removed afterwards.

    :return: True on success. False if the connection check fails, no matching TOPAS version
        exists, an S3 call fails, or the Geant4 directory cannot be created.
    :raises OSError: if the TOPAS archive does not contain ``topas/bin/topas``.
    """
    s3_client = boto3.client("s3",
                             aws_access_key_id=access_key,
                             aws_secret_access_key=secret_key,
                             endpoint_url=endpoint)

    if not validate_connection_data(bucket, key, s3_client):
        return False

    with tempfile.TemporaryDirectory() as tmpdir_name:
        tmp_dir = Path(tmpdir_name)
        topas_archive = tmp_dir / "topas.tar"

        # Download TOPAS tar (a single version, so repeated tags cannot corrupt the archive)
        try:
            topas_version_id = next(
                (version_id for version_id, tags in _iter_s3_object_versions_with_tags(s3_client, bucket, key)
                 if tags.get("version") == version),
                None,
            )
            if topas_version_id is None:
                click.echo(f"Could not find TOPAS version {version} in bucket {bucket}, file {key}", err=True)
                return False
            click.echo(f"Downloading {key}, version {version} from {bucket} to {topas_archive}")
            _download_s3_object_version(s3_client, bucket, key, topas_version_id, topas_archive)
        except ClientError as e:
            click.echo(f"Failed to download TOPAS from S3 with error: {e.response['Error']['Message']}", err=True)
            return False

        # Download GEANT4 tar files
        try:
            geant4_archives = _download_geant4_archives(s3_client, geant4_bucket, version, tmp_dir)
        except ClientError as e:
            click.echo(f"Failed to download Geant4 data from S3 with error: {e.response['Error']['Message']}",
                       err=True)
            return False

        click.echo(f"Unpacking {topas_archive} to {download_dir}")
        _extract_tar_archive(topas_archive, download_dir)
        topas_extracted_path = download_dir / "topas" / "bin" / "topas"
        topas_extracted_path.chmod(0o700)
        click.echo(f"Installed TOPAS into {download_dir}")

        geant4_files_path = download_dir / "geant4_files_path"
        try:
            geant4_files_path.mkdir(exist_ok=True)
        except OSError as e:
            click.echo(f"Could not create directory {geant4_files_path}: {e}", err=True)
            return False
        for archive_path in geant4_archives:
            click.echo(f"Unpacking {archive_path} to {geant4_files_path}")
            _extract_tar_archive(archive_path, geant4_files_path)
        click.echo(f"Installed Geant4 files into {geant4_files_path}")
    return True

288 

289 

def extract_fluka_from_tar_gz(archive_path: Path, unpacking_directory: Path, destination_dir: Path) -> bool:
    """Unpack a FLUKA tar.gz archive and copy its contents into ``destination_dir / 'fluka'``.

    If the archive holds a single top-level directory (a wrapper such as ``fluka-x.y/``), the
    contents of that directory are copied. Otherwise the whole unpacked tree is copied. Files
    already present under ``destination_dir / 'fluka'`` are overwritten.

    :param archive_path: gzip-compressed tar archive to read.
    :param unpacking_directory: scratch directory the archive is unpacked into; created if missing.
    :param destination_dir: directory in which the ``fluka`` directory is created or updated.
    :return: True if anything was copied, False if the archive was empty.
    :raises tarfile.TarError: if ``archive_path`` is not a readable tar.gz archive.
    """
    unpacking_directory = Path(unpacking_directory)
    fluka_dir = Path(destination_dir) / 'fluka'
    # an empty archive creates nothing, so make sure the directory exists before listing it
    unpacking_directory.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(path=unpacking_directory)
    content = list(unpacking_directory.iterdir())
    if not content:
        return False
    # only strip the wrapper when it really is a directory; copytree cannot copy a lone file
    source = content[0] if len(content) == 1 and content[0].is_dir() else unpacking_directory
    shutil.copytree(source, fluka_dir, dirs_exist_ok=True)
    return True

302 

303 

def download_fluka_from_s3(download_dir: Path, endpoint: str, access_key: str, secret_key: str, bucket: str,
                           password: str, salt: str, key: str) -> bool:
    """Download the encrypted FLUKA archive from S3, decrypt it and install it under ``download_dir``.

    The object ``bucket/key`` is downloaded and decrypted with ``password``/``salt`` into a
    temporary directory, then unpacked with ``extract_fluka_from_tar_gz`` into
    ``download_dir / 'fluka'``.

    :return: True on success. False if connection validation, download/decryption or unpacking
        fails (the reason is echoed).
    """
    s3_client = boto3.client("s3",
                             aws_access_key_id=access_key,
                             aws_secret_access_key=secret_key,
                             endpoint_url=endpoint)

    if not validate_connection_data(bucket, key, s3_client):
        return False

    with tempfile.TemporaryDirectory() as tmpdir_name:
        tmp_dir = Path(tmpdir_name).resolve()
        tmp_archive = tmp_dir / 'fluka.tgz'
        unpacking_dir = tmp_dir / 'fluka'
        downloaded = download_file(key=key,
                                   bucket=bucket,
                                   s3_client=s3_client,
                                   decrypt=True,
                                   password=password,
                                   salt=salt,
                                   destination_file_path=tmp_archive)
        if not downloaded:
            return False
        try:
            return extract_fluka_from_tar_gz(archive_path=tmp_archive,
                                             unpacking_directory=unpacking_dir,
                                             destination_dir=download_dir)
        except (tarfile.TarError, OSError) as e:
            # a corrupt archive or a failed copy is reported like every other failure here
            click.echo(f"Could not unpack FLUKA archive {key}: {e}", err=True)
            return False

333 

334 

def upload_file_to_s3(bucket: str,
                      file_path: Path,
                      endpoint: str,
                      access_key: str,
                      secret_key: str,
                      encrypt: bool = False,
                      encryption_password: str = '',
                      encryption_salt: str = '') -> bool:
    """Upload ``file_path`` to ``bucket`` under the key ``file_path.name``, optionally encrypted.

    The bucket is created if it does not exist. With ``encrypt`` set, the uploaded body is the
    Fernet token produced by ``encrypt_file`` using ``encryption_password``/``encryption_salt``.

    :return: True on success. False if the S3 connection check fails or S3 reports a
        ``ClientError`` while creating the bucket or uploading.
    """
    # Create S3 client
    s3_client = boto3.client(
        "s3",
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
        endpoint_url=endpoint,
    )
    if not check_if_s3_connection_is_working(s3_client):
        click.echo("S3 connection failed", err=True)
        return False

    # read (and optionally encrypt) the file once, before touching the bucket
    if encrypt:
        click.echo(f"Encrypting file {file_path}")
        file_contents = encrypt_file(file_path, encryption_password, encryption_salt)
    else:
        file_contents = file_path.read_bytes()

    try:
        existing_buckets = {entry["Name"] for entry in s3_client.list_buckets()["Buckets"]}
        if bucket not in existing_buckets:
            click.echo(f"Bucket {bucket} does not exist. Creating.")
            s3_client.create_bucket(Bucket=bucket)
        click.echo(f"Uploading file {file_path}")
        s3_client.put_object(Body=file_contents, Bucket=bucket, Key=file_path.name)
    except ClientError as e:
        # click.echo's second positional argument is the output *file*, so format the message
        click.echo(f"Upload failed with error: {e.response['Error']['Message']}", err=True)
        return False
    return True

373 

374 

def encrypt_file(file_path: Path, password: str, salt: str) -> bytes:
    """Return the Fernet token encrypting the bytes of ``file_path``.

    The Fernet key is derived from ``password`` and ``salt`` with ``derive_key``.
    """
    cipher = Fernet(derive_key(password, salt))
    # skipcq: PTC-W6004
    plaintext = file_path.read_bytes()
    return cipher.encrypt(plaintext)

383 

384 

def decrypt_file(file_path: Path, password: str, salt: str) -> bytes:
    """Decrypt the Fernet token stored in ``file_path``.

    The Fernet key is derived from ``password`` and ``salt`` with ``derive_key``.

    :return: the decrypted bytes, or ``b''`` (after an error message on stderr) when the token
        is rejected, e.g. because the password/salt do not match.
    """
    cipher = Fernet(derive_key(password, salt))
    # skipcq: PTC-W6004
    token = file_path.read_bytes()
    try:
        return cipher.decrypt(token)
    except cryptography.fernet.InvalidToken:
        click.echo("Decryption failed - invalid token (password+salt)", err=True)
        return b''

397 

398 

def validate_connection_data(bucket: str, key: str, s3_client) -> bool:
    """Check that S3 is reachable and that ``bucket``/``key`` name an accessible object.

    The checks run in order: connection, non-empty bucket name, non-empty key, ``head_bucket``,
    ``head_object``. The first failing check echoes a message to stderr and makes the function
    return False.
    """
    if not check_if_s3_connection_is_working(s3_client):
        click.echo("S3 connection failed", err=True)
        return False

    for value, complaint in ((bucket, "Bucket name is empty"), (key, "Key is empty")):
        if not value:
            click.echo(complaint, err=True)
            return False

    # each probe raises ClientError when the bucket / object cannot be accessed
    probes = (
        (lambda: s3_client.head_bucket(Bucket=bucket),
         f"Problem accessing bucket named {bucket}"),
        (lambda: s3_client.head_object(Bucket=bucket, Key=key),
         f"Problem accessing key named {key} in bucket {bucket}"),
    )
    for probe, context in probes:
        try:
            probe()
        except ClientError as err:
            click.echo(f"{context}: {err}", err=True)
            return False

    return True

430 

431 

def download_file(key: str,
                  bucket: str,
                  s3_client,
                  destination_file_path: Path,
                  decrypt: bool = False,
                  password: str = '',
                  salt: str = '') -> bool:
    """Download ``bucket/key`` to ``destination_file_path``, optionally Fernet-decrypting it.

    The object is first stored in a temporary directory. It is then either decrypted with
    ``password``/``salt`` (via ``decrypt_file``) and written to the destination, or copied there
    unchanged. The destination's parent directories are created as needed, and the final file
    gets mode ``0o700``.

    :return: True on success. False if ``decrypt`` is requested without password/salt, S3 reports
        a ``ClientError``, or decryption yields no data.
    """
    destination_file_path = Path(destination_file_path)
    if decrypt and (not password or not salt):
        # fail before spending time on a download that could never be decrypted
        click.echo("Password or salt not set", err=True)
        return False

    with tempfile.TemporaryDirectory() as tmpdir_name:
        downloaded_path = Path(tmpdir_name) / 's3_download'
        click.echo(f"Downloading {key} from {bucket} to {downloaded_path}")
        try:
            # Close the handle before reading the file back by name: this flushes every buffered
            # byte and avoids reopening an open temporary file (which fails on Windows).
            with open(downloaded_path, 'wb') as file_handle:
                s3_client.download_fileobj(Bucket=bucket, Key=key, Fileobj=file_handle)
        except ClientError as e:
            click.echo(f"S3 download failed with client error: {e}", err=True)
            return False

        destination_file_path.parent.mkdir(parents=True, exist_ok=True)
        if decrypt:
            click.echo("Decrypting downloaded file")
            bytes_from_decrypted_file = decrypt_file(file_path=downloaded_path, password=password, salt=salt)
            if not bytes_from_decrypted_file:
                click.echo("Decryption failed", err=True)
                return False
            destination_file_path.write_bytes(bytes_from_decrypted_file)
        else:
            click.echo(f"Copying {downloaded_path} to {destination_file_path}")
            shutil.copy2(downloaded_path, destination_file_path)

    destination_file_path.chmod(0o700)
    return True

466 

467 

def derive_key(password: str, salt: str) -> bytes:
    """Stretch ``password`` into a Fernet key with PBKDF2-HMAC-SHA256.

    ``salt`` is the KDF salt, 480 000 iterations are used, and the 32 derived bytes are returned
    URL-safe base64 encoded (the form ``Fernet`` accepts as a key).
    """
    raw_key = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt.encode(),
        iterations=480_000,
    ).derive(password.encode())
    return urlsafe_b64encode(raw_key)

472 return key