def main()

in tf-horovod-inference-pipeline/generate_cifar10_tfrecords.py [0:0]


def main(data_dir):
  """Download CIFAR-10, convert each split to TFRecords, then clean up.

  For every (mode, files) pair returned by `_get_file_names()`, the listed
  batch files inside the extracted CIFAR folder are converted into a single
  `<data_dir>/<mode>/<mode>.tfrecords` file. Any existing file with that
  name is replaced. Afterwards the extracted folder and the downloaded
  `cifar-10-python.tar.gz` archive are deleted from `data_dir`.

  Args:
    data_dir: Directory to download/extract into and to write the
      per-split TFRecords under.

  Raises:
    OSError: if an existing output file cannot be removed for a reason
      other than it not existing, or if directory creation/cleanup fails.
  """
  import errno
  import shutil

  print('Download from {} and extract.'.format(CIFAR_DOWNLOAD_URL))
  download_and_extract(data_dir)
  file_names = _get_file_names()
  input_dir = os.path.join(data_dir, CIFAR_LOCAL_FOLDER)
  for mode, files in file_names.items():
    input_files = [os.path.join(input_dir, f) for f in files]
    output_dir = os.path.join(data_dir, mode)
    output_file = os.path.join(output_dir, mode + '.tfrecords')
    if not os.path.isdir(output_dir):
      os.makedirs(output_dir)
    # Start from a clean output file. Only "file does not exist" is an
    # expected failure here; anything else (e.g. permissions) should surface
    # now rather than as a confusing error from the TFRecord writer.
    try:
      os.remove(output_file)
    except OSError as err:
      if err.errno != errno.ENOENT:
        raise
    # Convert to tf.train.Example and write them to TFRecords.
    convert_to_tfrecord(input_files, output_file)
  # Remove the intermediate download artifacts; only the TFRecords are kept.
  # Reuse input_dir so we delete exactly the folder the inputs came from.
  shutil.rmtree(input_dir)
  os.remove(os.path.join(data_dir, 'cifar-10-python.tar.gz'))
  print('Done!')