Coverage for yaptide/admin/simulator_storage.py: 31%

264 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-22 07:31 +0000

1import platform 

2import shutil 

3import tarfile 

4import tempfile 

5import zipfile 

6from base64 import urlsafe_b64encode 

7from enum import IntEnum, auto 

8from pathlib import Path 

9 

10import boto3 

11import click 

12import cryptography 

13import requests 

14from botocore.exceptions import (ClientError, EndpointConnectionError, NoCredentialsError) 

15from cryptography.fernet import Fernet 

16from cryptography.hazmat.primitives import hashes 

17from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC 

18 

19 

class SimulatorType(IntEnum):
    """Integer identifiers of the simulators this module can install."""

    # Explicit values match what `auto()` assigns to an IntEnum (counting from 1).
    shieldhit = 1
    fluka = 2
    topas = 3

26 

27 

def extract_shieldhit_from_tar_gz(archive_path: Path, unpacking_directory: Path, member_name: str,
                                  destination_dir: Path):
    """Install the file called `member_name` found in a `bin` directory of a tar.gz archive.

    Args:
        archive_path: gzip-compressed tar archive to read.
        unpacking_directory: scratch directory the matching member is extracted into.
        member_name: file name (without directories) of the member to install.
        destination_dir: directory the extracted file is moved to, keeping `member_name` as its name.
    """
    with tarfile.open(archive_path, "r:gz") as archive:
        for entry in archive.getmembers():
            entry_path = Path(entry.name)
            # only a member named `member_name` whose direct parent directory is `bin` qualifies
            if entry_path.name != member_name or entry_path.parent.name != 'bin':
                continue
            click.echo(f"Extracting {entry.name}")
            archive.extract(entry, unpacking_directory)
            # move to installation path
            extracted_file = unpacking_directory / entry.name
            click.echo(f"Moving {extracted_file} to {destination_dir}")
            shutil.move(extracted_file, destination_dir / member_name)

41 

42 

def extract_shieldhit_from_zip(archive_path: Path, unpacking_dir: Path, member_name: str, destination_dir: Path):
    """Install the file called `member_name` from a zip archive.

    Every member of the archive is listed on the console. A member whose file name equals
    `member_name` is extracted into `unpacking_dir` and moved to `destination_dir`, unless a file
    of that name is already present there.

    Args:
        archive_path: zip archive to read.
        unpacking_dir: scratch directory the matching member is extracted into.
        member_name: file name (without directories) of the member to install.
        destination_dir: directory receiving the extracted file.
    """
    with zipfile.ZipFile(archive_path) as archive:
        for entry in archive.infolist():
            click.echo(f"Member: {entry.filename}")
            if Path(entry.filename).name != member_name:
                continue
            click.echo(f"Extracting {entry.filename}")
            archive.extract(entry, unpacking_dir)
            # move to installation path
            extracted_file = Path(unpacking_dir) / entry.filename
            target_file = destination_dir / member_name
            click.echo(f"Moving {extracted_file} to {target_file}")
            # an already installed file is left untouched
            if target_file.exists():
                continue
            shutil.move(extracted_file, target_file)

59 

60 

def download_shieldhit_demo_version(destination_dir: Path) -> bool:
    """Download the SHIELD-HIT12A demo version from shieldhit.org and install its executable.

    The Windows zip archive is used when `platform.system()` reports 'Windows', the Linux tar.gz
    archive otherwise. The archive is stored in a temporary directory and only the executable is
    moved into `destination_dir`.

    Args:
        destination_dir: directory that receives the executable; created when missing.

    Returns:
        True when the executable is present in `destination_dir` after the call, False when the
        download failed or the executable could not be found in the archive.
    """
    demo_version_url = 'https://shieldhit.org/download/DEMO/shield_hit12a_x86_64_demo_gfortran_v1.1.0.tar.gz'
    executable_name = 'shieldhit'
    # check if working on Windows
    if platform.system() == 'Windows':
        demo_version_url = 'https://shieldhit.org/download/DEMO/shield_hit12a_win64_demo_v1.1.0.zip'
        executable_name = 'shieldhit.exe'

    # the archive is downloaded into a temporary directory which is removed afterwards
    with tempfile.TemporaryDirectory() as tmpdir_name:
        click.echo(f"Downloading from {demo_version_url} to {tmpdir_name}")
        headers = {'User-Agent': 'Mozilla/5.0 (Windows NT x.y; rv:10.0) Gecko/20100101 Firefox/10.0'}
        try:
            # without a timeout an unresponsive server would block this call forever
            response = requests.get(demo_version_url, headers=headers, timeout=60)
            # an HTTP error page must not be saved and treated as the archive
            response.raise_for_status()
        except requests.RequestException as e:
            click.echo(f"Download from {demo_version_url} failed: {e}", err=True)
            return False
        temp_file_archive = Path(tmpdir_name) / Path(demo_version_url).name
        temp_file_archive.write_bytes(response.content)
        click.echo(f"Saved to {temp_file_archive} with size {temp_file_archive.stat().st_size} bytes")

        # extract
        click.echo(f"Extracting {temp_file_archive} to {destination_dir}")
        destination_dir.mkdir(parents=True, exist_ok=True)
        if temp_file_archive.suffix == '.gz':
            extract_shieldhit_from_tar_gz(temp_file_archive,
                                          Path(tmpdir_name),
                                          executable_name,
                                          destination_dir=destination_dir)
        elif temp_file_archive.suffix == '.zip':
            extract_shieldhit_from_zip(temp_file_archive,
                                       Path(tmpdir_name),
                                       executable_name,
                                       destination_dir=destination_dir)

    # report success only when the executable really ended up in the destination directory
    return (destination_dir / executable_name).exists()

93 

94 

def check_if_s3_connection_is_working(s3_client: boto3.client) -> bool:
    """Probe the S3 connection by listing buckets.

    Args:
        s3_client: client object exposing `list_buckets()`.

    Returns:
        True when `list_buckets()` succeeds; False (after printing a message to stderr) when it
        raises NoCredentialsError, EndpointConnectionError or ClientError.
    """
    try:
        s3_client.list_buckets()
    except (NoCredentialsError, EndpointConnectionError, ClientError) as error:
        if isinstance(error, NoCredentialsError):
            message = f"No credentials found. Check your access key and secret key. {error}"
        elif isinstance(error, EndpointConnectionError):
            message = f"Could not connect to the specified endpoint. {error}"
        else:
            message = f"An error occurred while connecting to S3: {error.response['Error']['Message']}"
        click.echo(message, err=True)
        return False
    return True

109 

110 

def download_shieldhit_from_s3(
    destination_dir: Path,
    endpoint: str,
    access_key: str,
    secret_key: str,
    password: str,
    salt: str,
    bucket: str,
    key: str,
    decrypt: bool = True,
) -> bool:
    """Fetch the SHIELD-HIT12A executable from an S3 bucket into `destination_dir`.

    The file is stored as `shieldhit` (`shieldhit.exe` on Windows). When `decrypt` is set, the
    downloaded object is decrypted using `password` and `salt`.

    Returns:
        True on success, False when the connection data is invalid or the download/decryption fails.
    """
    client = boto3.client("s3",
                          aws_access_key_id=access_key,
                          aws_secret_access_key=secret_key,
                          endpoint_url=endpoint)

    if not validate_connection_data(bucket=bucket, key=key, s3_client=client):
        return False

    # the executable carries the '.exe' extension on Windows
    executable_name = 'shieldhit.exe' if platform.system() == 'Windows' else 'shieldhit'

    succeeded = download_file(key=key,
                              bucket=bucket,
                              s3_client=client,
                              decrypt=decrypt,
                              password=password,
                              salt=salt,
                              destination_file_path=destination_dir / executable_name)
    return bool(succeeded)

148 

149 

def download_shieldhit_from_s3_or_from_website(
    destination_dir: Path,
    endpoint: str,
    access_key: str,
    secret_key: str,
    password: str,
    salt: str,
    bucket: str,
    key: str,
    decrypt: bool = True,
):
    """Install SHIELD-HIT12A from S3, falling back to the demo version from shieldhit.org.

    The S3 download is attempted first; only when it reports failure is the demo version
    downloaded. The outcome of each step is printed; nothing is returned.
    """
    from_s3 = download_shieldhit_from_s3(destination_dir=destination_dir,
                                         endpoint=endpoint,
                                         access_key=access_key,
                                         secret_key=secret_key,
                                         password=password,
                                         salt=salt,
                                         bucket=bucket,
                                         key=key,
                                         decrypt=decrypt)
    if from_s3:
        click.echo('SHIELD-HIT12A downloaded from S3')
        return

    click.echo('SHIELD-HIT12A download failed, trying to download demo version from shieldhit.org website')
    if download_shieldhit_demo_version(destination_dir=destination_dir):
        summary = 'SHIELD-HIT12A demo version downloaded from shieldhit.org website'
    else:
        summary = 'SHIELD-HIT12A demo version download failed'
    click.echo(summary)

180 

181 

182# skipcq: PY-R1000 

def download_topas_from_s3(download_dir: Path, endpoint: str, access_key: str, secret_key: str, bucket: str, key: str,
                           version: str, geant4_bucket: str) -> bool:
    """Download TOPAS and the matching Geant4 data files from S3 and unpack them.

    The TOPAS archive is the stored version of `key` in `bucket` whose `version` tag equals
    `version`. Geant4 archives are all object versions in `geant4_bucket` whose `topas_versions`
    tag (a comma separated list) contains `version`. TOPAS is unpacked into `download_dir`, the
    Geant4 archives into `download_dir / "geant4_files_path"`.

    Args:
        download_dir: directory the archives are unpacked into.
        endpoint: S3 endpoint URL.
        access_key: S3 access key.
        secret_key: S3 secret key.
        bucket: bucket holding the TOPAS archive.
        key: object key of the TOPAS archive.
        version: requested TOPAS version, compared against the object tags.
        geant4_bucket: bucket holding the Geant4 archives.

    Returns:
        True when everything was downloaded and unpacked, False when the connection data is
        invalid, the requested version is not found, an S3 request fails or the Geant4 directory
        cannot be created.
    """
    s3_client = boto3.client("s3",
                             aws_access_key_id=access_key,
                             aws_secret_access_key=secret_key,
                             endpoint_url=endpoint)

    if not validate_connection_data(bucket, key, s3_client):
        return False

    topas_temp_file = tempfile.NamedTemporaryFile()
    geant4_temp_files = []
    try:
        # Download TOPAS tar
        try:
            if not _download_topas_archive(s3_client, bucket, key, version, topas_temp_file):
                click.echo(f"Could not find TOPAS version {version} in bucket {bucket}, file {key}", err=True)
                return False
        except ClientError as e:
            click.echo(f"Failed to download TOPAS from S3 with error: {e.response['Error']['Message']}", err=True)
            return False

        # Download GEANT4 tar files
        try:
            _download_geant4_archives(s3_client, geant4_bucket, version, geant4_temp_files)
        except ClientError as e:
            click.echo(f"Failed to download Geant4 data from S3 with error: {e.response['Error']['Message']}",
                       err=True)
            return False

        _unpack_tar_file(topas_temp_file, download_dir)
        topas_extracted_path = download_dir / "topas" / "bin" / "topas"
        topas_extracted_path.chmod(0o700)
        click.echo(f"Installed TOPAS into {download_dir}")

        geant4_files_path = download_dir / "geant4_files_path"
        if not geant4_files_path.exists():
            try:
                geant4_files_path.mkdir()
            except OSError as e:
                click.echo(f"Could not create directory {geant4_files_path}: {e}", err=True)
                return False
        for geant4_temp_file in geant4_temp_files:
            _unpack_tar_file(geant4_temp_file, geant4_files_path)
        click.echo(f"Installed Geant4 files into {geant4_files_path}")
        return True
    finally:
        # closing a NamedTemporaryFile also removes it; do so on every exit path
        for temp_file in [topas_temp_file, *geant4_temp_files]:
            temp_file.close()


def _iter_version_tags(s3_client, bucket: str, key: str):
    """Yield `(version_id, tags)` for each listed version of `key`, with tags as a dict."""
    response = s3_client.list_object_versions(Bucket=bucket, Prefix=key)
    for entry in response.get("Versions", []):
        version_id = entry["VersionId"]
        tagging = s3_client.get_object_tagging(Bucket=bucket, Key=key, VersionId=version_id)
        yield version_id, {tag["Key"]: tag["Value"] for tag in tagging["TagSet"]}


def _download_topas_archive(s3_client, bucket: str, key: str, version: str, fileobj) -> bool:
    """Download the first version of `key` tagged `version=<version>` into `fileobj`; return whether one was found."""
    for version_id, tags in _iter_version_tags(s3_client, bucket, key):
        if tags.get("version") == version:
            click.echo(f"Downloading {key}, version {version} from {bucket} to {fileobj.name}")
            s3_client.download_fileobj(Bucket=bucket, Key=key, Fileobj=fileobj, ExtraArgs={"VersionId": version_id})
            # stop here: a second match would be appended to the same file and corrupt the archive
            return True
    return False


def _download_geant4_archives(s3_client, geant4_bucket: str, version: str, temp_files: list) -> None:
    """Download every Geant4 archive tagged for TOPAS `version`, appending each temporary file to `temp_files`."""
    objects = s3_client.list_objects_v2(Bucket=geant4_bucket)
    # 'Contents' is looked up defensively so that a listing without it is treated as empty
    for obj in objects.get("Contents", []):
        geant4_key = obj["Key"]
        for version_id, tags in _iter_version_tags(s3_client, geant4_bucket, geant4_key):
            if "topas_versions" not in tags:
                continue
            supported_versions = [item.strip() for item in tags["topas_versions"].split(",")]
            if version not in supported_versions:
                continue
            temp_file = tempfile.NamedTemporaryFile()
            # register before downloading so the caller closes the file even if the download fails
            temp_files.append(temp_file)
            click.echo(f"Downloading {geant4_key} for TOPAS version {version} "
                       f"from {geant4_bucket} to {temp_file.name}")
            s3_client.download_fileobj(Bucket=geant4_bucket,
                                       Key=geant4_key,
                                       Fileobj=temp_file,
                                       ExtraArgs={"VersionId": version_id})


def _unpack_tar_file(temp_file, target_dir: Path) -> None:
    """Unpack the tar archive stored in the open file `temp_file` into `target_dir`."""
    temp_file.seek(0)
    click.echo(f"Unpacking {temp_file.name} to {target_dir}")
    # NOTE(review): extractall() trusts the member paths stored in the archive - confirm that the
    # buckets only contain archives from a trusted source
    with tarfile.TarFile(fileobj=temp_file) as archive:
        archive.extractall(path=target_dir)

285 

286 

def extract_fluka_from_tar_gz(archive_path: Path, unpacking_directory: Path, destination_dir: Path) -> bool:
    """Unpack a tar.gz archive and copy its content to `destination_dir / 'fluka'`.

    When the unpacked archive consists of exactly one top-level entry, that entry is copied;
    when there are several, the whole unpacking directory is copied.

    Args:
        archive_path: gzip-compressed tar archive to unpack.
        unpacking_directory: directory the archive is unpacked into.
        destination_dir: directory in which the `fluka` directory is created or updated.

    Returns:
        True when something was copied, False when the unpacking directory turned out empty.
    """
    with tarfile.open(archive_path, "r:gz") as archive:
        archive.extractall(path=unpacking_directory)

    unpacked_entries = list(unpacking_directory.iterdir())
    if not unpacked_entries:
        return False

    # a lone top-level entry is the payload itself, otherwise the payload is the directory as a whole
    payload = unpacked_entries[0] if len(unpacked_entries) == 1 else unpacking_directory
    shutil.copytree(str(payload), str(destination_dir / 'fluka'), dirs_exist_ok=True)
    return True

299 

300 

def download_fluka_from_s3(download_dir: Path, endpoint: str, access_key: str, secret_key: str, bucket: str,
                           password: str, salt: str, key: str) -> bool:
    """Fetch the encrypted Fluka archive from S3, decrypt it and unpack it below `download_dir`.

    The archive is downloaded and decrypted into a temporary directory, unpacked there and then
    copied to its final location by `extract_fluka_from_tar_gz`.

    Returns:
        True on success, False when the connection data is invalid, the download or decryption
        fails, or the archive yields nothing to copy.
    """
    client = boto3.client("s3",
                          aws_access_key_id=access_key,
                          aws_secret_access_key=secret_key,
                          endpoint_url=endpoint)

    if not validate_connection_data(bucket, key, client):
        return False

    with tempfile.TemporaryDirectory() as scratch_name:
        scratch_dir = Path(scratch_name).resolve()
        archive_file = scratch_dir / 'fluka.tgz'
        fetched = download_file(key=key,
                                bucket=bucket,
                                s3_client=client,
                                decrypt=True,
                                password=password,
                                salt=salt,
                                destination_file_path=archive_file)
        if not fetched:
            return False
        return extract_fluka_from_tar_gz(archive_path=archive_file,
                                         unpacking_directory=scratch_dir / 'fluka',
                                         destination_dir=download_dir)

330 

331 

def upload_file_to_s3(bucket: str,
                      file_path: Path,
                      endpoint: str,
                      access_key: str,
                      secret_key: str,
                      encrypt: bool = False,
                      encryption_password: str = '',
                      encryption_salt: str = '') -> bool:
    """Upload a file to an S3 bucket, optionally encrypting its content first.

    The bucket is created when it does not exist yet. The object key is the file name of
    `file_path`.

    Args:
        bucket: name of the target bucket.
        file_path: local file to upload.
        endpoint: S3 endpoint URL.
        access_key: S3 access key.
        secret_key: S3 secret key.
        encrypt: when True the content is encrypted with `encrypt_file` before the upload.
        encryption_password: password the encryption key is derived from.
        encryption_salt: salt the encryption key is derived from.

    Returns:
        True when the upload succeeded, False when the connection check or the upload failed.
    """
    # Create S3 client
    s3_client = boto3.client(
        "s3",
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
        endpoint_url=endpoint,
    )
    if not check_if_s3_connection_is_working(s3_client):
        click.echo("S3 connection failed", err=True)
        return False

    # Check if bucket exists and create if not
    existing_buckets = {entry["Name"] for entry in s3_client.list_buckets()["Buckets"]}
    if bucket not in existing_buckets:
        click.echo(f"Bucket {bucket} does not exist. Creating.")
        s3_client.create_bucket(Bucket=bucket)

    # Read the file once: encrypt_file() reads it itself, so a separate read is only needed
    # for the unencrypted upload
    if encrypt:
        click.echo(f"Encrypting file {file_path}")
        file_contents = encrypt_file(file_path, encryption_password, encryption_salt)
    else:
        file_contents = file_path.read_bytes()

    try:
        # Upload (possibly encrypted) content to S3 bucket
        click.echo(f"Uploading file {file_path}")
        s3_client.put_object(Body=file_contents, Bucket=bucket, Key=file_path.name)
        return True
    except ClientError as e:
        # the error text must be part of the message: the second positional argument of
        # click.echo is the output file, not additional text
        click.echo(f"Upload failed with error: {e.response['Error']['Message']}", err=True)
        return False

370 

371 

def encrypt_file(file_path: Path, password: str, salt: str) -> bytes:
    """Return the content of `file_path` encrypted with Fernet.

    The Fernet key is derived from `password` and `salt` by `derive_key`. The file itself is
    not modified.
    """
    cipher = Fernet(derive_key(password, salt))
    # skipcq: PTC-W6004
    plaintext = file_path.read_bytes()
    return cipher.encrypt(plaintext)

380 

381 

def decrypt_file(file_path: Path, password: str, salt: str) -> bytes:
    """Return the Fernet-decrypted content of `file_path`.

    The Fernet key is derived from `password` and `salt` by `derive_key`.

    Returns:
        The decrypted bytes, or empty bytes (after printing a message to stderr) when the token
        is rejected as invalid.
    """
    cipher = Fernet(derive_key(password, salt))
    # skipcq: PTC-W6004
    token = file_path.read_bytes()
    try:
        return cipher.decrypt(token)
    except cryptography.fernet.InvalidToken:
        click.echo("Decryption failed - invalid token (password+salt)", err=True)
        return b''

394 

395 

def validate_connection_data(bucket: str, key: str, s3_client) -> bool:
    """Check that the S3 connection works and that `key` is accessible in `bucket`.

    The checks run in this order: connection, non-empty bucket name, non-empty key, bucket
    access (`head_bucket`), object access (`head_object`). The first failing check prints a
    message to stderr and ends the validation.

    Returns:
        True when all checks pass, False otherwise.
    """
    if not check_if_s3_connection_is_working(s3_client):
        click.echo("S3 connection failed", err=True)
        return False

    # both names must be non-empty before any request is made with them
    for value, complaint in ((bucket, "Bucket name is empty"), (key, "Key is empty")):
        if not value:
            click.echo(complaint, err=True)
            return False

    # Check if bucket exists
    try:
        s3_client.head_bucket(Bucket=bucket)
    except ClientError as error:
        click.echo(f"Problem accessing bucket named {bucket}: {error}", err=True)
        return False

    # Check if key exists
    try:
        s3_client.head_object(Bucket=bucket, Key=key)
    except ClientError as error:
        click.echo(f"Problem accessing key named {key} in bucket {bucket}: {error}", err=True)
        return False

    return True

427 

428 

def download_file(key: str,
                  bucket: str,
                  s3_client,
                  destination_file_path: Path,
                  decrypt: bool = False,
                  password: str = '',
                  salt: str = '') -> bool:
    """Download an object from S3 to `destination_file_path`, optionally decrypting it.

    The object is first downloaded into a temporary file. With `decrypt` set, the decrypted
    content is written to the destination; otherwise the temporary file is copied there. The
    destination directory is created when missing and the resulting file gets mode 0o700.

    Args:
        key: object key to download.
        bucket: bucket holding the object.
        s3_client: client object exposing `download_fileobj`.
        destination_file_path: path of the file to create.
        decrypt: whether the downloaded content has to be decrypted.
        password: password the decryption key is derived from; required when `decrypt` is set.
        salt: salt the decryption key is derived from; required when `decrypt` is set.

    Returns:
        True on success, False when the download fails with a ClientError, when password or salt
        is missing, or when decryption yields no data.
    """
    destination_file_path = Path(destination_file_path)
    try:
        with tempfile.NamedTemporaryFile() as temp_file:
            click.echo(f"Downloading {key} from {bucket} to {temp_file.name}")
            s3_client.download_fileobj(Bucket=bucket, Key=key, Fileobj=temp_file)
            # the file is read back through its name below, so buffered data has to reach the
            # disk first - otherwise the tail of the download can be missing
            temp_file.flush()

            if decrypt:
                click.echo("Decrypting downloaded file")
                if not password or not salt:
                    click.echo("Password or salt not set", err=True)
                    return False
                bytes_from_decrypted_file = decrypt_file(file_path=Path(temp_file.name), password=password, salt=salt)
                if not bytes_from_decrypted_file:
                    click.echo("Decryption failed", err=True)
                    return False

                destination_file_path.parent.mkdir(parents=True, exist_ok=True)
                destination_file_path.write_bytes(bytes_from_decrypted_file)
            else:
                click.echo(f"Copying {temp_file.name} to {destination_file_path}")
                # same guarantee as in the decrypt branch: the target directory exists
                destination_file_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(temp_file.name, destination_file_path)
    except ClientError as e:
        click.echo(f"S3 download failed with client error: {e}", err=True)
        return False

    destination_file_path.chmod(0o700)
    return True

463 

464 

def derive_key(password: str, salt: str) -> bytes:
    """Derive a url-safe base64 encoded 32-byte key from `password` and `salt`.

    PBKDF2-HMAC with SHA256 and 480 000 iterations is applied to the encoded password, using the
    encoded salt.
    """
    key_derivation = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt.encode(),
        iterations=480_000,
    )
    raw_key = key_derivation.derive(password.encode())
    return urlsafe_b64encode(raw_key)