in src/sagemaker/session.py [0:0]
def download_data(self, path, bucket, key_prefix="", extra_args=None):
    """Download file or directory from S3.

    Every object whose key starts with ``key_prefix`` is listed (following
    continuation tokens) and downloaded beneath ``path``. Zero-byte objects
    whose key ends with ``/`` are treated as directory placeholders: a local
    directory is created for them and nothing is downloaded.

    Args:
        path (str): Local path where the file or directory should be downloaded to.
        bucket (str): Name of the S3 Bucket to download from.
        key_prefix (str): Optional S3 object key name prefix.
        extra_args (dict): Optional extra arguments that may be passed to the
            download operation. Please refer to the ExtraArgs parameter in the boto3
            documentation here:
            https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-example-download-file.html

    Returns:
        list[str]: List of local paths of downloaded files. Empty when the
            listing contains no downloadable objects.
    """
    # Initialize the S3 client.
    if self.s3_client is None:
        s3 = self.boto_session.client("s3", region_name=self.boto_region_name)
    else:
        s3 = self.s3_client

    # Initialize the variables used to loop through the contents of the S3 bucket.
    keys = []
    directories = []
    next_token = ""
    base_parameters = {"Bucket": bucket, "Prefix": key_prefix}

    # Loop through the contents of the bucket one page at a time, gathering all keys into
    # a "keys" list. A page without "Contents" must not abort the listing: keys gathered
    # from earlier pages would be thrown away and later pages never requested.
    while next_token is not None:
        request_parameters = base_parameters.copy()
        if next_token != "":
            request_parameters["ContinuationToken"] = next_token
        response = s3.list_objects_v2(**request_parameters)

        # For each object, save its key or directory.
        for s3_object in response.get("Contents") or []:
            key: str = s3_object.get("Key")
            obj_size = s3_object.get("Size")
            if key.endswith("/") and int(obj_size) == 0:
                # NOTE(review): placeholders are created under the full key, whereas
                # files are placed relative to key_prefix — confirm this is intended.
                directories.append(os.path.join(path, key))
            else:
                keys.append(key)
        next_token = response.get("NextContinuationToken")

    if not keys and not directories:
        logger.info(
            "Nothing to download from bucket: %s, key_prefix: %s.", bucket, key_prefix
        )
        return []

    for dir_path in directories:
        os.makedirs(os.path.dirname(dir_path), exist_ok=True)

    # A prefix that carries a file extension is taken to name a single file, so matching
    # objects are stored under their base name. Evaluated once: it does not depend on the key.
    prefix_has_extension = bool(os.path.splitext(key_prefix)[1])

    # For each object key, create the directory on the local machine if needed, and then
    # download the file.
    downloaded_paths = []
    for key in keys:
        if prefix_has_extension:
            tail_s3_uri_path = os.path.basename(key)
        else:
            # NOTE(review): when key_prefix ends in the middle of a path component
            # (e.g. "da" matching "data/x"), relpath starts with ".." and the file
            # lands outside `path` — confirm whether that should be rejected.
            tail_s3_uri_path = os.path.relpath(key, key_prefix)
            if tail_s3_uri_path == os.curdir:
                # The prefix is exactly this object's key (a file without an
                # extension); "." would make the destination a directory.
                tail_s3_uri_path = os.path.basename(key)

        destination_path = os.path.join(path, tail_s3_uri_path)
        destination_dir = os.path.dirname(destination_path)
        # An empty directory name means the current directory; os.makedirs("") would raise.
        if destination_dir and not os.path.exists(destination_dir):
            os.makedirs(destination_dir, exist_ok=True)

        s3.download_file(
            Bucket=bucket, Key=key, Filename=destination_path, ExtraArgs=extra_args
        )
        downloaded_paths.append(destination_path)
    return downloaded_paths