def downloadfile()

in ec2-spot-tensorflow-checkpoint/tensorflow_checkpoint.py [0:0]


def downloadfile(bucket, max_to_keep):
    """Download the most recent checkpoint files from an S3 bucket.

    Every object whose file name is exactly ``checkpoint`` is downloaded, and
    the first run of digits on its first line is taken as the latest
    checkpoint number ``N``. Every other object whose file name's first run
    of digits falls in ``N - max_to_keep`` .. ``N`` (inclusive, for any
    checkpoint file found) is then downloaded too. Each object is saved
    relative to the current working directory at a path equal to its S3 key.
    Missing parent directories are created.

    :param bucket: name of the S3 bucket to download from
    :param max_to_keep: how many checkpoint numbers below the latest one
        should also be downloaded
    """
    digits = re.compile(r'\d+')
    my_bucket = s3.Bucket(bucket)

    def _fetch(key):
        # Download ``key`` to the same relative local path, creating any
        # intermediate directories so nested keys do not fail to open.
        path = os.path.dirname(key)
        if path:
            os.makedirs(path, exist_ok=True)
        my_bucket.download_file(key, key)

    # Materialize the listing once so it can be scanned twice. S3 returns
    # keys in lexicographic order, so data files may appear before the
    # ``checkpoint`` index that decides whether they are wanted.
    objects = list(my_bucket.objects.all())

    # Pass 1: fetch the checkpoint index file(s) and collect wanted numbers.
    wanted = set()
    for s3_object in objects:
        if os.path.basename(s3_object.key) != 'checkpoint':
            continue
        _fetch(s3_object.key)
        with open(s3_object.key, encoding='utf-8') as f:
            line = f.readline()
        match = digits.search(line)
        if match is None:
            # No number to anchor on, so this index selects nothing.
            continue
        batch = int(match.group())
        wanted.update(range(batch - max_to_keep, batch + 1))

    # Pass 2: fetch every other object whose first number is a wanted one.
    for s3_object in objects:
        filename = os.path.basename(s3_object.key)
        if filename == 'checkpoint':
            continue
        match = digits.search(filename)
        # Names without digits (e.g. directory markers ending in '/', or
        # auxiliary files) are skipped instead of raising IndexError.
        if match is not None and int(match.group()) in wanted:
            _fetch(s3_object.key)