def wait_for_s3_object()

in Bin Packing/common/misc.py [0:0]


def wait_for_s3_object(s3_bucket, key, local_dir, local_prefix='', 
                       aws_account=None, aws_region=None, timeout=1200, limit=20,
                       fetch_only=None, training_job_name=None):
    """
    Keep polling s3 object until it is generated.
    Pulling down latest data to local directory with short key

    Arguments:
        s3_bucket (string): s3 bucket name
        key (string): key for s3 object
        local_dir (string): local directory path to save s3 object
        local_prefix (string): local prefix path append to the local directory
        aws_account (string): aws account of the s3 bucket
        aws_region (string): aws region where the repo is located
        timeout (int): how long to wait for the object to appear before giving up
        limit (int): maximum number of files to download
        fetch_only (lambda): a function to decide if this object should be fetched or not
        training_job_name (string): training job name to query job status

    Returns:
        A list of all downloaded files, as local filenames
    """
    session = boto3.Session()
    aws_account = aws_account or session.client("sts").get_caller_identity()['Account']
    aws_region = aws_region or session.region_name

    s3 = session.resource('s3')
    sagemaker = session.client('sagemaker')
    bucket = s3.Bucket(s3_bucket)
    objects = []

    print("Waiting for s3://%s/%s..." % (s3_bucket, key), end='', flush=True)
    start_time = time.time()
    cnt = 0
    while len(objects) == 0:
        objects = list(bucket.objects.filter(Prefix=key))
        if fetch_only:
            objects = list(filter(fetch_only, objects))
        if objects:
            continue
        print('.', end='', flush=True)
        time.sleep(5)
        cnt += 1
        if cnt % 80 == 0:
            print("")
        if time.time() > start_time + timeout:
            raise FileNotFoundError("S3 object s3://%s/%s never appeared after %d seconds" % (s3_bucket, key, timeout))
        if training_job_name:
            training_job_status = sagemaker.describe_training_job(TrainingJobName=training_job_name)['TrainingJobStatus']
            if training_job_status == 'Failed':
                raise RuntimeError("Training job {} failed while waiting for S3 object s3://{}/{}"
                                   .format(training_job_name, s3_bucket, key))

    print('\n', end='', flush=True)

    if len(objects) > limit:
        print("Only downloading %d of %d files" % (limit, len(objects)))
        objects = objects[-limit:]

    fetched_files = []
    for obj in objects:
        print("Downloading %s" % obj.key)
        local_path = os.path.join(local_dir, local_prefix, obj.key.split('/')[-1])
        obj.Object().download_file(local_path)
        fetched_files.append(local_path)

    return fetched_files