in tensorflow_managed_spot_training_checkpointing/generate_cifar10_tfrecords.py [0:0]
def main(data_dir):
    """Build one TFRecords file per dataset mode from the CIFAR-10 download.

    Calls ``download_and_extract(data_dir)``, then for every
    ``(mode, files)`` entry returned by ``_get_file_names()`` writes
    ``<data_dir>/<mode>/<mode>.tfrecords`` through ``convert_to_tfrecord``.
    Afterwards the extracted batch folder and the downloaded tarball are
    deleted from ``data_dir``.

    Args:
        data_dir: Directory handed to ``download_and_extract``; the per-mode
            output folders are created beneath it.
    """
    print('Download from {} and extract.'.format(CIFAR_DOWNLOAD_URL))
    download_and_extract(data_dir)

    batches_by_mode = _get_file_names()
    extracted_root = os.path.join(data_dir, CIFAR_LOCAL_FOLDER)

    for mode, batch_names in batches_by_mode.items():
        source_paths = [
            os.path.join(extracted_root, batch_name)
            for batch_name in batch_names
        ]
        target_dir = os.path.join(data_dir, mode)
        target_path = os.path.join(target_dir, mode + '.tfrecords')

        # Make sure the per-mode output folder is there before writing.
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)

        # Best-effort removal of a stale output from an earlier run; a
        # missing file (or any other OSError) is deliberately ignored.
        try:
            os.remove(target_path)
        except OSError:
            pass

        # Convert the batches to tf.train.Example records and write them
        # to the TFRecords file.
        convert_to_tfrecord(source_paths, target_path)

    print('Done!')

    # Clean up the download artifacts: first the extracted batch folder,
    # then the original .tar.gz archive.
    extracted_batches = os.path.join(data_dir, 'cifar-10-batches-py')
    shutil.rmtree(extracted_batches)
    downloaded_archive = os.path.join(data_dir, 'cifar-10-python.tar.gz')
    os.remove(downloaded_archive)