in ec2-spot-sagemaker-managed-spot-training/sagemaker-custom-tensorflow/utils/generate_cifar10_tfrecords.py [0:0]
def main(data_dir):
    """Download and extract the CIFAR archive, write TFRecord files, then clean up.

    ``_get_file_names()`` yields ``(mode, files)`` pairs. For each pair, the
    listed files inside the extracted ``CIFAR_LOCAL_FOLDER`` are converted
    into a single ``<data_dir>/<mode>.tfrecords`` file. Any file left over
    from a previous run is replaced. Afterwards the downloaded archive
    (``CIFAR_FILENAME``) and the extracted folder are deleted from
    ``data_dir``.

    Args:
        data_dir: Directory used both as the download/extraction location and
            as the destination for the generated ``.tfrecords`` files.

    Raises:
        OSError: If a stale output file exists but cannot be removed, or if
            deleting the downloaded archive / extracted folder fails.
    """
    import errno  # local import so this function does not depend on file-level imports

    print('Download from {} and extract.'.format(CIFAR_DOWNLOAD_URL))
    download_and_extract(data_dir)
    file_names = _get_file_names()
    input_dir = os.path.join(data_dir, CIFAR_LOCAL_FOLDER)
    for mode, files in file_names.items():
        input_files = [os.path.join(input_dir, f) for f in files]
        output_file = os.path.join(data_dir, mode + '.tfrecords')
        # Remove output left over from a previous run. A missing file is the
        # normal case and is ignored. Any other failure (e.g. permission
        # denied) is re-raised rather than silently swallowed.
        try:
            os.remove(output_file)
        except OSError as err:
            if err.errno != errno.ENOENT:
                raise
        # Convert to tf.train.Example and write them to TFRecords.
        convert_to_tfrecord(input_files, output_file)
    print('Removing original files.')
    os.remove(os.path.join(data_dir, CIFAR_FILENAME))
    shutil.rmtree(input_dir)
    print('Done!')