def main()

in tensorflow_managed_spot_training_checkpointing/generate_cifar10_tfrecords.py [0:0]


def main(data_dir):
    """Download CIFAR-10 into ``data_dir`` and write one TFRecords file per mode.

    For every ``mode -> file names`` entry returned by ``_get_file_names()``,
    the listed files are looked up under ``data_dir/CIFAR_LOCAL_FOLDER`` and
    converted into ``data_dir/<mode>/<mode>.tfrecords``.  Any stale output
    file for a mode is removed first.  After all modes are converted, the
    extracted batch directory and the downloaded archive are deleted.

    Args:
        data_dir: Directory used for the download, the extracted batches and
            the generated per-mode output directories.
    """
    print('Download from {} and extract.'.format(CIFAR_DOWNLOAD_URL))
    download_and_extract(data_dir)

    batches_by_mode = _get_file_names()
    extracted_dir = os.path.join(data_dir, CIFAR_LOCAL_FOLDER)

    for mode, batch_names in batches_by_mode.items():
        batch_paths = [os.path.join(extracted_dir, name) for name in batch_names]

        # Each mode gets its own sub-directory holding a single record file.
        target_dir = os.path.join(data_dir, mode)
        record_path = os.path.join(target_dir, mode + '.tfrecords')
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)

        # Best-effort removal of a previous output; a failure here (e.g. the
        # file does not exist yet) is deliberately ignored.
        try:
            os.remove(record_path)
        except OSError:
            pass

        # Convert to tf.train.Example and write them to TFRecords.
        convert_to_tfrecord(batch_paths, record_path)

    print('Done!')

    # Clean up the intermediate artifacts: the extracted batch directory and
    # the original .tar.gz archive.
    leftover_batches = os.path.join(data_dir, 'cifar-10-batches-py')
    leftover_archive = os.path.join(data_dir, 'cifar-10-python.tar.gz')
    shutil.rmtree(leftover_batches)
    os.remove(leftover_archive)