def main()

in ec2-spot-sagemaker-managed-spot-training/sagemaker-custom-tensorflow/utils/generate_cifar10_tfrecords.py [0:0]


def main(data_dir):
    """Fetch CIFAR-10, convert each split to a TFRecord file, then clean up.

    Steps, in order:
      1. Print the source URL (``CIFAR_DOWNLOAD_URL``) and call
         ``download_and_extract(data_dir)``.
      2. For every ``(mode, files)`` pair from ``_get_file_names()``, join
         each file name onto ``data_dir/CIFAR_LOCAL_FOLDER``. Delete any
         existing ``data_dir/<mode>.tfrecords``, then pass the inputs and
         that output path to ``convert_to_tfrecord``.
      3. Delete ``data_dir/CIFAR_FILENAME`` and remove the whole
         ``data_dir/CIFAR_LOCAL_FOLDER`` directory tree.

    Args:
        data_dir: Working directory. The dataset is extracted here and the
            ``.tfrecords`` outputs are written here.
    """

    def _fresh_path(path):
        # Best-effort removal of output left behind by an earlier run. Any
        # OSError from the delete (e.g. the file does not exist yet) is
        # intentionally ignored; the path is handed back for writing.
        try:
            os.remove(path)
        except OSError:
            pass
        return path

    print('Download from {} and extract.'.format(CIFAR_DOWNLOAD_URL))
    download_and_extract(data_dir)

    splits = _get_file_names()
    extracted_dir = os.path.join(data_dir, CIFAR_LOCAL_FOLDER)

    for split_name, split_files in splits.items():
        sources = [os.path.join(extracted_dir, name) for name in split_files]
        destination = _fresh_path(
            os.path.join(data_dir, split_name + '.tfrecords'))
        # Turn this split's input files into tf.train.Example records and
        # write them to the split's TFRecord file.
        convert_to_tfrecord(sources, destination)

    # The intermediate inputs are no longer needed once every split is
    # converted. Delete the downloaded file and the extracted folder.
    print('Removing original files.')
    download_path = os.path.join(data_dir, CIFAR_FILENAME)
    os.remove(download_path)
    shutil.rmtree(extracted_dir)
    print('Done!')