examples/vae/utils/mnist_cached.py

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import errno
import os
from functools import reduce

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

from pyro.contrib.examples.util import get_data_directory

# This file contains utilities for caching, transforming and splitting MNIST data
# efficiently. By default, a PyTorch DataLoader will apply the transform every epoch;
# we avoid this by caching the data early on in the MNISTCached class.


# transformations for MNIST data
def fn_x_mnist(x, use_cuda):
    # normalize pixel values of the image to be in [0,1] instead of [0,255]
    xp = x * (1. / 255)

    # transform x to a linear tensor from bs * a1 * a2 * ... --> bs * A
    xp_1d_size = reduce(lambda a, b: a * b, xp.size()[1:])
    xp = xp.view(-1, xp_1d_size)

    # send the data to GPU(s)
    if use_cuda:
        xp = xp.cuda()

    return xp


def fn_y_mnist(y, use_cuda):
    yp = torch.zeros(y.size(0), 10)

    # send the data to GPU(s)
    if use_cuda:
        yp = yp.cuda()
        y = y.cuda()

    # transform the label y (integer between 0 and 9) to a one-hot
    yp = yp.scatter_(1, y.view(-1, 1), 1.0)
    return yp


def get_ss_indices_per_class(y, sup_per_class):
    # number of indices to consider
    n_idxs = y.size()[0]

    # calculate the indices per class
    idxs_per_class = {j: [] for j in range(10)}

    # for each index identify the class and add the index to the right class
    for i in range(n_idxs):
        curr_y = y[i]
        for j in range(10):
            if curr_y[j] == 1:
                idxs_per_class[j].append(i)
                break

    idxs_sup = []
    idxs_unsup = []
    for j in range(10):
        np.random.shuffle(idxs_per_class[j])
        idxs_sup.extend(idxs_per_class[j][:sup_per_class])
        idxs_unsup.extend(idxs_per_class[j][sup_per_class:])

    return idxs_sup, idxs_unsup


def split_sup_unsup_valid(X, y, sup_num, validation_num=10000):
    """
    helper function for splitting the data into supervised, un-supervised and validation parts
    :param X: images
    :param y: labels (digits)
    :param sup_num: what number of examples is supervised
    :param validation_num: what number of last examples to use for validation
    :return: splits of data by sup_num number of supervised examples
    """
    # validation set is the last `validation_num` examples (10,000 by default)
    X_valid = X[-validation_num:]
    y_valid = y[-validation_num:]

    X = X[0:-validation_num]
    y = y[0:-validation_num]

    assert sup_num % 10 == 0, "unable to have equal number of images per class"

    # number of supervised examples per class
    sup_per_class = int(sup_num / 10)

    idxs_sup, idxs_unsup = get_ss_indices_per_class(y, sup_per_class)
    X_sup = X[idxs_sup]
    y_sup = y[idxs_sup]
    X_unsup = X[idxs_unsup]
    y_unsup = y[idxs_unsup]

    return X_sup, y_sup, X_unsup, y_unsup, X_valid, y_valid


def print_distribution_labels(y):
    """
    helper function for printing the distribution of class labels in a dataset
    :param y: tensor of class labels given as one-hots
    :return: a dictionary of counts for each label from y
    """
    counts = {j: 0 for j in range(10)}
    for i in range(y.size()[0]):
        for j in range(10):
            if y[i][j] == 1:
                counts[j] += 1
                break
    print(counts)
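

# A minimal sketch (added for illustration; not part of the original example)
# showing what the transforms above produce. The fake batch below and the
# name `_demo_transforms` are assumptions made only for this demo.
def _demo_transforms():
    # four fake 28x28 "images" with raw pixel values in [0, 255]
    x = torch.randint(0, 256, (4, 28, 28)).float()
    y = torch.tensor([3, 1, 4, 1])
    xp = fn_x_mnist(x, use_cuda=False)  # flattened to shape (4, 784), scaled to [0, 1]
    yp = fn_y_mnist(y, use_cuda=False)  # one-hot rows of shape (4, 10)
    assert xp.shape == (4, 784)
    assert yp[0, 3] == 1.0  # label 3 maps to a 1 in column 3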


class MNISTCached(MNIST):
    """
    a wrapper around MNIST to load and cache the transformed data
    once at the beginning of the inference
    """

    # static class variables for caching training data
    train_data_size = 50000
    train_data_sup, train_labels_sup = None, None
    train_data_unsup, train_labels_unsup = None, None
    validation_size = 10000
    data_valid, labels_valid = None, None
    test_size = 10000

    def __init__(self, mode, sup_num, use_cuda=True, *args, **kwargs):
        super().__init__(train=mode in ["sup", "unsup", "valid"], *args, **kwargs)

        # transformations on MNIST data (normalization and one-hot conversion for labels)
        def transform(x):
            return fn_x_mnist(x, use_cuda)

        def target_transform(y):
            return fn_y_mnist(y, use_cuda)

        self.mode = mode

        assert mode in ["sup", "unsup", "test", "valid"], "invalid train/test option values"

        if mode in ["sup", "unsup", "valid"]:
            # transform the training data if transformations are provided
            if transform is not None:
                self.data = transform(self.data.float())
            if target_transform is not None:
                self.targets = target_transform(self.targets)

            if MNISTCached.train_data_sup is None:
                if sup_num is None:
                    assert mode == "unsup"
                    MNISTCached.train_data_unsup, MNISTCached.train_labels_unsup = \
                        self.data, self.targets
                else:
                    MNISTCached.train_data_sup, MNISTCached.train_labels_sup, \
                        MNISTCached.train_data_unsup, MNISTCached.train_labels_unsup, \
                        MNISTCached.data_valid, MNISTCached.labels_valid = \
                        split_sup_unsup_valid(self.data, self.targets, sup_num)

            if mode == "sup":
                self.data, self.targets = MNISTCached.train_data_sup, MNISTCached.train_labels_sup
            elif mode == "unsup":
                self.data = MNISTCached.train_data_unsup
                # making sure that the unsupervised labels are not available to inference
                self.targets = torch.Tensor(
                    MNISTCached.train_labels_unsup.shape[0]).view(-1, 1) * np.nan
            else:
                self.data, self.targets = MNISTCached.data_valid, MNISTCached.labels_valid
        else:
            # transform the testing data if transformations are provided
            if transform is not None:
                self.data = transform(self.data.float())
            if target_transform is not None:
                self.targets = target_transform(self.targets)

    def __getitem__(self, index):
        """
        :param index: Index or slice object
        :returns tuple: (image, target) where target is index of the target class.
        """
        if self.mode in ["sup", "unsup", "valid"]:
            img, target = self.data[index], self.targets[index]
        elif self.mode == "test":
            img, target = self.data[index], self.targets[index]
        else:
            assert False, "invalid mode: {}".format(self.mode)
        return img, target


def setup_data_loaders(dataset, use_cuda, batch_size, sup_num=None, root=None,
                       download=True, **kwargs):
    """
    helper function for setting up pytorch data loaders for a semi-supervised dataset

    :param dataset: the data to use
    :param use_cuda: use GPU(s) for training
    :param batch_size: size of a batch of data to output when iterating over the data loaders
    :param sup_num: number of supervised data examples
    :param root: where on the filesystem to cache the dataset (defaults to the Pyro data directory)
    :param download: download the dataset (if it doesn't exist already)
    :param kwargs: other params for the pytorch data loader
    :return: three data loaders: (supervised data for training, un-supervised data for training,
                                  supervised data for testing)
    """
    # instantiate the dataset as training/testing sets
    if root is None:
        root = get_data_directory(__file__)
    if 'num_workers' not in kwargs:
        kwargs = {'num_workers': 0, 'pin_memory': False}

    cached_data = {}
    loaders = {}
    for mode in ["unsup", "test", "sup", "valid"]:
        if sup_num is None and mode == "sup":
            # in this special case, we do not want "sup" and "valid" data loaders
            return loaders["unsup"], loaders["test"]
        cached_data[mode] = dataset(root=root, mode=mode, download=download,
                                    sup_num=sup_num, use_cuda=use_cuda)
        loaders[mode] = DataLoader(cached_data[mode], batch_size=batch_size,
                                   shuffle=True, **kwargs)

    return loaders


def mkdir_p(path):
    try:
        os.makedirs(path)
    except OSError as exc:  # Python >2.5
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            pass
        else:
            raise


EXAMPLE_DIR = os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir)))
DATA_DIR = os.path.join(EXAMPLE_DIR, 'data')
RESULTS_DIR = os.path.join(EXAMPLE_DIR, 'results')
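

# Minimal usage sketch (added for illustration; not in the original file).
# With sup_num=3000, the 50,000 training images are split into 3,000 labeled
# and 47,000 unlabeled examples; running this downloads MNIST into DATA_DIR
# on first use.
if __name__ == "__main__":
    loaders = setup_data_loaders(MNISTCached, use_cuda=False, batch_size=128,
                                 sup_num=3000, root=DATA_DIR, download=True)
    xs, ys = next(iter(loaders["sup"]))
    print(xs.shape, ys.shape)  # torch.Size([128, 784]) torch.Size([128, 10])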