def extract_cifar10()

in research/pate_2017/input.py


import tarfile

import numpy as np
import tensorflow as tf  # TF 1.x API (tf.gfile); use tensorflow.compat.v1 on TF 2.x


def extract_cifar10(local_url, data_dir):
  """Extracts CIFAR-10 and returns numpy arrays with the different sets."""

  # These numpy dumps can be reloaded to avoid performing the pre-processing
  # if they exist in the working directory.
  # Changing the order of this list will ruin the indices below.
  preprocessed_files = ['/cifar10_train.npy',
                        '/cifar10_train_labels.npy',
                        '/cifar10_test.npy',
                        '/cifar10_test_labels.npy']

  all_preprocessed = True
  for file_name in preprocessed_files:
    if not tf.gfile.Exists(data_dir + file_name):
      all_preprocessed = False
      break

  if all_preprocessed:
    # Reload pre-processed training data from numpy dumps
    with tf.gfile.Open(data_dir + preprocessed_files[0], mode='rb') as file_obj:
      train_data = np.load(file_obj)
    with tf.gfile.Open(data_dir + preprocessed_files[1], mode='rb') as file_obj:
      train_labels = np.load(file_obj)

    # Reload pre-processed testing data from numpy dumps
    with tf.gfile.Open(data_dir + preprocessed_files[2], mode='rb') as file_obj:
      test_data = np.load(file_obj)
    with tf.gfile.Open(data_dir + preprocessed_files[3], mode='rb') as file_obj:
      test_labels = np.load(file_obj)

  else:
    # Do everything from scratch
    # Define lists of all files we should extract
    train_files = ['data_batch_' + str(i) for i in range(1, 6)]
    test_file = ['test_batch']
    cifar10_files = train_files + test_file

    # Check if all files have already been extracted; the archive expands
    # into data_dir/cifar-10-batches-py/
    need_to_unpack = False
    for file_name in cifar10_files:
      if not tf.gfile.Exists(data_dir + '/cifar-10-batches-py/' + file_name):
        need_to_unpack = True
        break

    # Unpack the archive if any batch file is missing
    if need_to_unpack:
      tarfile.open(local_url, 'r:gz').extractall(data_dir)

    # Load training images and labels
    images = []
    labels = []
    for train_file in train_files:
      # Construct filename
      filename = data_dir + '/cifar-10-batches-py/' + train_file

      # Unpickle dictionary and extract images and labels
      images_tmp, labels_tmp = unpickle_cifar_dic(filename)

      # Append to lists
      images.append(images_tmp)
      labels.append(labels_tmp)

    # Convert to numpy arrays and reshape in the expected format
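    # (each CIFAR-10 row is a flat 3072-byte vector in channel-major order;
    # swapping axes 1 and 3 gives channels-last arrays of shape (N, 32, 32, 3))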
    train_data = np.asarray(images, dtype=np.float32)
    train_data = train_data.reshape((50000, 3, 32, 32))
    train_data = np.swapaxes(train_data, 1, 3)
    train_labels = np.asarray(labels, dtype=np.int32).reshape(50000)

    # Save so we don't have to do this again
    np.save(data_dir + preprocessed_files[0], train_data)
    np.save(data_dir + preprocessed_files[1], train_labels)

    # Construct filename for test file
    filename = data_dir + '/cifar-10-batches-py/' + test_file[0]

    # Load test images and labels
    test_data, test_labels = unpickle_cifar_dic(filename)

    # Convert to numpy arrays and reshape in the expected format
    test_data = np.asarray(test_data, dtype=np.float32)
    test_data = test_data.reshape((10000, 3, 32, 32))
    test_data = np.swapaxes(test_data, 1, 3)
    test_labels = np.asarray(test_labels, dtype=np.int32).reshape(10000)

    # Save so we don't have to do this again
    np.save(data_dir + preprocessed_files[2], test_data)
    np.save(data_dir + preprocessed_files[3], test_labels)

  return train_data, train_labels, test_data, test_labels
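
The function above calls unpickle_cifar_dic, which is defined elsewhere in
input.py and not shown in this section. A minimal sketch of such a helper,
assuming the standard CIFAR-10 python-version batch format (a Python 2 pickle
holding 'data' and 'labels' keys):

import pickle


def unpickle_cifar_dic(filename):
  """Unpickles one CIFAR-10 batch and returns its images and labels."""
  with open(filename, 'rb') as file_obj:
    # CIFAR-10 batches were pickled under Python 2; latin1 decodes them safely
    batch = pickle.load(file_obj, encoding='latin1')
  # 'data' is a (10000, 3072) uint8 array, 'labels' a list of 10000 ints
  return batch['data'], batch['labels']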
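
A hypothetical call, assuming cifar-10-python.tar.gz has already been
downloaded to /tmp:

train_data, train_labels, test_data, test_labels = extract_cifar10(
    '/tmp/cifar-10-python.tar.gz', '/tmp/cifar10')
print(train_data.shape, train_labels.shape)  # (50000, 32, 32, 3) (50000,)
print(test_data.shape, test_labels.shape)    # (10000, 32, 32, 3) (10000,)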