Coverage for yaptide/admin/simulator_storage.py: 18%
266 statements
coverage.py v7.6.10, created at 2025-01-04 00:31 +0000
import platform
import shutil
import tarfile
import tempfile
import zipfile
from base64 import urlsafe_b64encode
from enum import IntEnum, auto
from pathlib import Path

import boto3
import click
import cryptography
import requests
from botocore.exceptions import (ClientError, EndpointConnectionError, NoCredentialsError)
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC


class SimulatorType(IntEnum):
    """Simulator types"""

    shieldhit = auto()
    fluka = auto()
    topas = auto()


def extract_shieldhit_from_tar_gz(archive_path: Path, unpacking_directory: Path, member_name: str,
                                  destination_dir: Path):
    """Extracts a single file from a tar.gz archive"""
    with tarfile.open(archive_path, "r:gz") as tar:
        # look for the requested binary inside the archive's 'bin' directory
        for member in tar.getmembers():
            if Path(member.name).name == member_name and Path(member.name).parent.name == 'bin':
                click.echo(f"Extracting {member.name}")
                tar.extract(member, unpacking_directory)
                # move to installation path
                local_file_path = unpacking_directory / member.name
                click.echo(f"Moving {local_file_path} to {destination_dir}")
                shutil.move(local_file_path, destination_dir / member_name)


def extract_shieldhit_from_zip(archive_path: Path, unpacking_dir: Path, member_name: str, destination_dir: Path):
    """Extracts a single file from a zip archive"""
    with zipfile.ZipFile(archive_path) as zip_handle:
        # print all members
        for member in zip_handle.infolist():
            click.echo(f"Member: {member.filename}")
            if Path(member.filename).name == member_name:
                click.echo(f"Extracting {member.filename}")
                zip_handle.extract(member, unpacking_dir)
                # move to installation path
                local_file_path = Path(unpacking_dir) / member.filename
                destination_file_path = destination_dir / member_name
                click.echo(f"Moving {local_file_path} to {destination_file_path}")
                # move file from the temporary directory to the installation path using shutil
                if not destination_file_path.exists():
                    shutil.move(local_file_path, destination_file_path)


def download_shieldhit_demo_version(destination_dir: Path) -> bool:
    """Download SHIELD-HIT12A demo version from shieldhit.org"""
    demo_version_url = 'https://shieldhit.org/download/DEMO/shield_hit12a_x86_64_demo_gfortran_v1.1.0.tar.gz'
    # check if working on Windows
    if platform.system() == 'Windows':
        demo_version_url = 'https://shieldhit.org/download/DEMO/shield_hit12a_win64_demo_v1.1.0.zip'

    # create a temporary directory and download the archive into it
    with tempfile.TemporaryDirectory() as tmpdir_name:
        click.echo(f"Downloading from {demo_version_url} to {tmpdir_name}")
        headers = {'User-Agent': 'Mozilla/5.0 (Windows NT x.y; rv:10.0) Gecko/20100101 Firefox/10.0'}
        response = requests.get(demo_version_url, headers=headers)
        temp_file_archive = Path(tmpdir_name) / Path(demo_version_url).name
        with open(temp_file_archive, 'wb') as file_handle:
            file_handle.write(response.content)
        click.echo(f"Saved to {temp_file_archive} with size {temp_file_archive.stat().st_size} bytes")

        # extract
        click.echo(f"Extracting {temp_file_archive} to {destination_dir}")
        destination_dir.mkdir(parents=True, exist_ok=True)
        if temp_file_archive.suffix == '.gz':
            extract_shieldhit_from_tar_gz(temp_file_archive,
                                          Path(tmpdir_name),
                                          'shieldhit',
                                          destination_dir=destination_dir)
        elif temp_file_archive.suffix == '.zip':
            extract_shieldhit_from_zip(temp_file_archive,
                                       Path(tmpdir_name),
                                       'shieldhit.exe',
                                       destination_dir=destination_dir)
    return True
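

# Illustrative usage sketch (not part of the original module): fetch the demo
# binary into a local directory; the 'bin' destination path below is a
# hypothetical placeholder.
def _example_download_demo_binary():
    """Example: install the SHIELD-HIT12A demo binary into ./bin"""
    ok = download_shieldhit_demo_version(destination_dir=Path('bin'))
    click.echo('demo installed' if ok else 'demo download failed')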


def check_if_s3_connection_is_working(s3_client: boto3.client) -> bool:
    """Check if connection to S3 is possible"""
    try:
        s3_client.list_buckets()
    except NoCredentialsError as e:
        click.echo(f"No credentials found. Check your access key and secret key. {e}", err=True)
        return False
    except EndpointConnectionError as e:
        click.echo(f"Could not connect to the specified endpoint. {e}", err=True)
        return False
    except ClientError as e:
        click.echo(f"An error occurred while connecting to S3: {e.response['Error']['Message']}", err=True)
        return False
    return True
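

# Illustrative sketch (not part of the original module): build a boto3 client
# and verify connectivity before attempting any transfer; the endpoint URL and
# credentials below are hypothetical placeholders.
def _example_check_s3_connection() -> bool:
    """Example: verify S3 credentials and endpoint before downloading"""
    client = boto3.client('s3',
                          aws_access_key_id='example-access-key',
                          aws_secret_access_key='example-secret-key',
                          endpoint_url='https://s3.example.com')
    return check_if_s3_connection_is_working(client)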


def download_shieldhit_from_s3(
    destination_dir: Path,
    endpoint: str,
    access_key: str,
    secret_key: str,
    password: str,
    salt: str,
    bucket: str,
    key: str,
    decrypt: bool = True,
) -> bool:
    """Download SHIELD-HIT12A from S3 bucket"""
    s3_client = boto3.client("s3",
                             aws_access_key_id=access_key,
                             aws_secret_access_key=secret_key,
                             endpoint_url=endpoint)

    if not validate_connection_data(bucket=bucket, key=key, s3_client=s3_client):
        return False

    if not destination_dir.exists():
        destination_dir.mkdir(parents=True, exist_ok=True)

    destination_file_path = destination_dir / 'shieldhit'
    # append '.exe' to file name if working on Windows
    if platform.system() == 'Windows':
        destination_file_path = destination_dir / 'shieldhit.exe'

    download_and_decrypt_status = download_file(key=key,
                                                bucket=bucket,
                                                s3_client=s3_client,
                                                decrypt=decrypt,
                                                password=password,
                                                salt=salt,
                                                destination_file_path=destination_file_path)

    if not download_and_decrypt_status:
        return False

    return True


def download_shieldhit_from_s3_or_from_website(
    destination_dir: Path,
    endpoint: str,
    access_key: str,
    secret_key: str,
    password: str,
    salt: str,
    bucket: str,
    key: str,
    decrypt: bool = True,
):
    """Download SHIELD-HIT12A from S3 bucket; if not available, download the demo version from the shieldhit.org website"""
    download_ok = download_shieldhit_from_s3(destination_dir=destination_dir,
                                             endpoint=endpoint,
                                             access_key=access_key,
                                             secret_key=secret_key,
                                             password=password,
                                             salt=salt,
                                             bucket=bucket,
                                             key=key,
                                             decrypt=decrypt)
    if download_ok:
        click.echo('SHIELD-HIT12A downloaded from S3')
    else:
        click.echo('SHIELD-HIT12A download failed, trying to download demo version from shieldhit.org website')
        demo_download_ok = download_shieldhit_demo_version(destination_dir=destination_dir)
        if demo_download_ok:
            click.echo('SHIELD-HIT12A demo version downloaded from shieldhit.org website')
        else:
            click.echo('SHIELD-HIT12A demo version download failed')
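

# Illustrative sketch (not part of the original module): a typical call that
# combines the S3 download with the website fallback; every connection value
# below (endpoint, credentials, bucket, key, secrets) is a hypothetical
# placeholder.
def _example_install_shieldhit():
    """Example: install SHIELD-HIT12A, falling back to the demo version"""
    download_shieldhit_from_s3_or_from_website(destination_dir=Path('bin'),
                                               endpoint='https://s3.example.com',
                                               access_key='example-access-key',
                                               secret_key='example-secret-key',
                                               password='example-password',
                                               salt='example-salt',
                                               bucket='shieldhit-bucket',
                                               key='shieldhit.tar.gz')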


# skipcq: PY-R1000
def download_topas_from_s3(download_dir: Path, endpoint: str, access_key: str, secret_key: str, bucket: str, key: str,
                           version: str, geant4_bucket: str) -> bool:
    """Download TOPAS from S3 bucket"""
    s3_client = boto3.client("s3",
                             aws_access_key_id=access_key,
                             aws_secret_access_key=secret_key,
                             endpoint_url=endpoint)

    if not validate_connection_data(bucket, key, s3_client):
        return False

    # Download TOPAS tar
    topas_temp_file = tempfile.NamedTemporaryFile()
    try:
        response = s3_client.list_object_versions(
            Bucket=bucket,
            Prefix=key,
        )
        topas_file_downloaded = False
        for curr_version in response["Versions"]:
            version_id = curr_version["VersionId"]

            tags = s3_client.get_object_tagging(
                Bucket=bucket,
                Key=key,
                VersionId=version_id,
            )
            for tag in tags["TagSet"]:
                if tag["Key"] == "version" and tag["Value"] == version:
                    click.echo(f"Downloading {key}, version {version} from {bucket} to {topas_temp_file.name}")
                    s3_client.download_fileobj(Bucket=bucket,
                                               Key=key,
                                               Fileobj=topas_temp_file,
                                               ExtraArgs={"VersionId": version_id})
                    topas_file_downloaded = True
        if not topas_file_downloaded:
            click.echo(f"Could not find TOPAS version {version} in bucket {bucket}, file {key}", err=True)
            return False

    except ClientError as e:
        click.echo(f"Failed to download TOPAS from S3 with error: {e.response['Error']['Message']}", err=True)
        return False

    # Download GEANT4 tar files
    geant4_temp_files = []

    objects = s3_client.list_objects_v2(Bucket=geant4_bucket)

    try:
        for obj in objects['Contents']:
            key = obj['Key']
            response = s3_client.list_object_versions(
                Bucket=geant4_bucket,
                Prefix=key,
            )
            for curr_version in response["Versions"]:
                version_id = curr_version["VersionId"]
                tags = s3_client.get_object_tagging(
                    Bucket=geant4_bucket,
                    Key=key,
                    VersionId=version_id,
                )
                for tag in tags["TagSet"]:
                    if tag["Key"] == "topas_versions":
                        topas_versions = tag["Value"].split(",")
                        topas_versions = [version.strip() for version in topas_versions]
                        if version in topas_versions:
                            temp_file = tempfile.NamedTemporaryFile()
                            click.echo(f"Downloading {key} for TOPAS version {version} "
                                       f"from {geant4_bucket} to {temp_file.name}")
                            s3_client.download_fileobj(Bucket=geant4_bucket,
                                                       Key=key,
                                                       Fileobj=temp_file,
                                                       ExtraArgs={"VersionId": version_id})
                            geant4_temp_files.append(temp_file)

    except ClientError as e:
        click.echo(f"Failed to download Geant4 data from S3 with error: {e.response['Error']['Message']}", err=True)
        return False

    topas_temp_file.seek(0)
    topas_file_contents = tarfile.TarFile(fileobj=topas_temp_file)
    click.echo(f"Unpacking {topas_temp_file.name} to {download_dir}")
    topas_file_contents.extractall(path=download_dir)
    topas_extracted_path = download_dir / "topas" / "bin" / "topas"
    topas_extracted_path.chmod(0o700)
    click.echo(f"Installed TOPAS into {download_dir}")

    geant4_files_path = download_dir / "geant4_files_path"
    if not geant4_files_path.exists():
        try:
            geant4_files_path.mkdir()
        except OSError as e:
            click.echo(f"Could not create directory {geant4_files_path}: {e}", err=True)
            return False
    for file in geant4_temp_files:
        file.seek(0)
        file_contents = tarfile.TarFile(fileobj=file)
        click.echo(f"Unpacking {file.name} to {geant4_files_path}")
        file_contents.extractall(path=geant4_files_path)
    click.echo(f"Installed Geant4 files into {geant4_files_path}")
    return True


def extract_fluka_from_tar_gz(archive_path: Path, unpacking_directory: Path, destination_dir: Path) -> bool:
    """Extracts a single directory from a tar.gz archive"""
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(path=unpacking_directory)
        content = list(unpacking_directory.iterdir())
        if len(content) == 1:
            shutil.copytree(str(content[0]), str(destination_dir / 'fluka'), dirs_exist_ok=True)
            return True
        if len(content) > 1:
            shutil.copytree(str(unpacking_directory), str(destination_dir / 'fluka'), dirs_exist_ok=True)
            return True
    return False


def download_fluka_from_s3(download_dir: Path, endpoint: str, access_key: str, secret_key: str, bucket: str,
                           password: str, salt: str, key: str) -> bool:
    """Download (and decrypt) Fluka from S3 bucket"""
    s3_client = boto3.client("s3",
                             aws_access_key_id=access_key,
                             aws_secret_access_key=secret_key,
                             endpoint_url=endpoint)

    if not validate_connection_data(bucket, key, s3_client):
        return False

    with tempfile.TemporaryDirectory() as tmpdir_name:
        tmp_dir = Path(tmpdir_name).resolve()
        tmp_archive = tmp_dir / 'fluka.tgz'
        tmp_dir_path = tmp_dir / 'fluka'
        download_and_decrypt_status = download_file(key=key,
                                                    bucket=bucket,
                                                    s3_client=s3_client,
                                                    decrypt=True,
                                                    password=password,
                                                    salt=salt,
                                                    destination_file_path=tmp_archive)
        if not download_and_decrypt_status:
            return False
        download_and_decrypt_status = extract_fluka_from_tar_gz(archive_path=tmp_archive,
                                                                unpacking_directory=tmp_dir_path,
                                                                destination_dir=download_dir)

    return download_and_decrypt_status


def upload_file_to_s3(bucket: str,
                      file_path: Path,
                      endpoint: str,
                      access_key: str,
                      secret_key: str,
                      encrypt: bool = False,
                      encryption_password: str = '',
                      encryption_salt: str = '') -> bool:
    """Upload file to S3 bucket"""
    # Create S3 client
    s3_client = boto3.client(
        "s3",
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
        endpoint_url=endpoint,
    )
    if not check_if_s3_connection_is_working(s3_client):
        click.echo("S3 connection failed", err=True)
        return False

    # Check if the bucket exists and create it if not
    if bucket not in [bucket["Name"] for bucket in s3_client.list_buckets()["Buckets"]]:
        click.echo(f"Bucket {bucket} does not exist. Creating.")
        s3_client.create_bucket(Bucket=bucket)

    # Encrypt file
    file_contents = file_path.read_bytes()
    if encrypt:
        click.echo(f"Encrypting file {file_path}")
        file_contents = encrypt_file(file_path, encryption_password, encryption_salt)
    try:
        # Upload (possibly encrypted) file to the S3 bucket
        click.echo(f"Uploading file {file_path}")
        s3_client.put_object(Body=file_contents, Bucket=bucket, Key=file_path.name)
        return True
    except ClientError as e:
        click.echo(f"Upload failed with error: {e.response['Error']['Message']}", err=True)
        return False
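

# Illustrative sketch (not part of the original module): encrypt and upload a
# local binary; the bucket name, file path, endpoint, credentials and secrets
# below are hypothetical placeholders.
def _example_upload_encrypted_binary() -> bool:
    """Example: encrypt and upload a local file to an S3 bucket"""
    return upload_file_to_s3(bucket='simulators-bucket',
                             file_path=Path('bin/shieldhit'),
                             endpoint='https://s3.example.com',
                             access_key='example-access-key',
                             secret_key='example-secret-key',
                             encrypt=True,
                             encryption_password='example-password',
                             encryption_salt='example-salt')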


def encrypt_file(file_path: Path, password: str, salt: str) -> bytes:
    """Encrypts a file using Fernet"""
    encryption_key = derive_key(password, salt)
    # skipcq: PTC-W6004
    bytes_from_file = file_path.read_bytes()
    fernet = Fernet(encryption_key)
    encrypted = fernet.encrypt(bytes_from_file)
    return encrypted


def decrypt_file(file_path: Path, password: str, salt: str) -> bytes:
    """Decrypts a file using Fernet"""
    encryption_key = derive_key(password, salt)
    # skipcq: PTC-W6004
    bytes_from_file = file_path.read_bytes()
    fernet = Fernet(encryption_key)
    try:
        decrypted = fernet.decrypt(bytes_from_file)
    except cryptography.fernet.InvalidToken:
        click.echo("Decryption failed - invalid token (password+salt)", err=True)
        return b''
    return decrypted
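

# Illustrative sketch (not part of the original module): encrypt a file and
# decrypt it again, checking that the round trip preserves the content; the
# input path and the password/salt values below are hypothetical placeholders.
def _example_encrypt_decrypt_roundtrip() -> bool:
    """Example: verify that encrypt_file and decrypt_file are inverses"""
    original = Path('bin/shieldhit')
    encrypted_bytes = encrypt_file(original, password='example-password', salt='example-salt')
    with tempfile.NamedTemporaryFile() as encrypted_file:
        encrypted_file.write(encrypted_bytes)
        encrypted_file.flush()
        decrypted_bytes = decrypt_file(Path(encrypted_file.name),
                                       password='example-password',
                                       salt='example-salt')
    return decrypted_bytes == original.read_bytes()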


def validate_connection_data(bucket: str, key: str, s3_client) -> bool:
    """Validate S3 connection"""
    if not check_if_s3_connection_is_working(s3_client):
        click.echo("S3 connection failed", err=True)
        return False

    # Check if bucket name is valid
    if not bucket:
        click.echo("Bucket name is empty", err=True)
        return False

    # Check if key is valid
    if not key:
        click.echo("Key is empty", err=True)
        return False

    # Check if bucket exists
    try:
        s3_client.head_bucket(Bucket=bucket)
    except ClientError as e:
        click.echo(f"Problem accessing bucket named {bucket}: {e}", err=True)
        return False

    # Check if key exists
    try:
        s3_client.head_object(Bucket=bucket, Key=key)
    except ClientError as e:
        click.echo(f"Problem accessing key named {key} in bucket {bucket}: {e}", err=True)
        return False

    return True


def download_file(key: str,
                  bucket: str,
                  s3_client,
                  destination_file_path: Path,
                  decrypt: bool = False,
                  password: str = '',
                  salt: str = ''):
    """Handle download with encryption"""
    try:
        with tempfile.NamedTemporaryFile() as temp_file:
            click.echo(f"Downloading {key} from {bucket} to {temp_file.name}")
            s3_client.download_fileobj(Bucket=bucket, Key=key, Fileobj=temp_file)

            if decrypt:
                click.echo("Decrypting downloaded file")
                if not password or not salt:
                    click.echo("Password or salt not set", err=True)
                    return False
                bytes_from_decrypted_file = decrypt_file(file_path=Path(temp_file.name), password=password, salt=salt)
                if not bytes_from_decrypted_file:
                    click.echo("Decryption failed", err=True)
                    return False

                Path(destination_file_path).parent.mkdir(parents=True, exist_ok=True)
                Path(destination_file_path).write_bytes(bytes_from_decrypted_file)
            else:
                click.echo(f"Copying {temp_file.name} to {destination_file_path}")
                shutil.copy2(temp_file.name, destination_file_path)
    except ClientError as e:
        click.echo(f"S3 download failed with client error: {e}", err=True)
        return False

    destination_file_path.chmod(0o700)
    return True


def derive_key(password: str, salt: str) -> bytes:
    """Derives a key from the password and salt"""
    kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=32, salt=salt.encode(), iterations=480_000)
    key = urlsafe_b64encode(kdf.derive(password.encode()))
    return key
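

# Illustrative sketch (not part of the original module): the derived key is the
# urlsafe base64 encoding of 32 bytes, which is exactly the format Fernet
# expects; the password and salt below are hypothetical placeholders.
def _example_derive_fernet_key() -> Fernet:
    """Example: derive a Fernet-compatible key from a password and salt"""
    key = derive_key('example-password', 'example-salt')
    return Fernet(key)  # would raise ValueError if the key were not Fernet-compatible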