data.py (136 lines of code) (raw):

"""Dataset loading and input preprocessing helpers.

Each loader returns train / validation / test splits; ``set_up_data`` picks the
loader named by ``H.dataset`` and builds the matching preprocessing closure.
"""
import os
import pickle

import numpy as np
import torch
from torch.utils.data import TensorDataset
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split


def set_up_data(H):
    """Load the dataset named by ``H.dataset`` and build its preprocessing function.

    Reads ``H.dataset``, ``H.data_root`` and ``H.test_eval``; writes
    ``H.image_size`` and ``H.image_channels``.  The normalization constants are
    moved to the GPU with ``.cuda()``, so a CUDA device is required.

    Returns:
        A tuple ``(H, train_data, valid_data, preprocess_func)`` where
        ``valid_data`` wraps the test split when ``H.test_eval`` is truthy and
        the validation split otherwise.

    Raises:
        ValueError: if ``H.dataset`` is not one of the supported names.
    """
    # Default constants for the loss target: (x - 127.5) / 127.5.
    shift_loss = -127.5
    scale_loss = 1. / 127.5
    if H.dataset == 'imagenet32':
        trX, vaX, teX = imagenet32(H.data_root)
        H.image_size = 32
        H.image_channels = 3
        shift = -116.2373
        scale = 1. / 69.37404
    elif H.dataset == 'imagenet64':
        trX, vaX, teX = imagenet64(H.data_root)
        H.image_size = 64
        H.image_channels = 3
        shift = -115.92961967
        scale = 1. / 69.37404
    elif H.dataset == 'ffhq_256':
        trX, vaX, teX = ffhq256(H.data_root)
        H.image_size = 256
        H.image_channels = 3
        shift = -112.8666757481
        scale = 1. / 69.84780273
    elif H.dataset == 'ffhq_1024':
        # For this dataset trX / vaX / teX are directory paths, not arrays.
        trX, vaX, teX = ffhq1024(H.data_root)
        H.image_size = 1024
        H.image_channels = 3
        shift = -0.4387
        scale = 1.0 / 0.2743
        # Loss target becomes (x - 0.5) * 2 for this dataset.
        shift_loss = -0.5
        scale_loss = 2.0
    elif H.dataset == 'cifar10':
        (trX, _), (vaX, _), (teX, _) = cifar10(H.data_root, one_hot=False)
        H.image_size = 32
        H.image_channels = 3
        shift = -120.63838
        scale = 1. / 64.16736
    else:
        # Format a single message instead of passing a tuple of args.
        raise ValueError('unknown dataset: %s' % (H.dataset,))

    do_low_bit = H.dataset in ['ffhq_256']

    if H.test_eval:
        print('DOING TEST')
        eval_dataset = teX
    else:
        eval_dataset = vaX

    # Shape (1, 1, 1, 1) so the constants broadcast against 4-D image batches.
    shift = torch.tensor([shift]).cuda().view(1, 1, 1, 1)
    scale = torch.tensor([scale]).cuda().view(1, 1, 1, 1)
    shift_loss = torch.tensor([shift_loss]).cuda().view(1, 1, 1, 1)
    scale_loss = torch.tensor([scale_loss]).cuda().view(1, 1, 1, 1)

    if H.dataset == 'ffhq_1024':
        train_data = ImageFolder(trX, transforms.ToTensor())
        valid_data = ImageFolder(eval_dataset, transforms.ToTensor())
        # NOTE(review): ToTensor output is presumably channels-first, hence the
        # permute in preprocess_func -- confirm against torchvision docs.
        untranspose = True
    else:
        train_data = TensorDataset(torch.as_tensor(trX))
        valid_data = TensorDataset(torch.as_tensor(eval_dataset))
        untranspose = False

    def preprocess_func(x):
        """Take a data example and return the preprocessed input as well as
        the input processed for the loss.

        ``x`` is indexed as ``x[0]`` for the image batch.  When ``untranspose``
        is set, ``x[0]`` is replaced in place by its ``permute(0, 2, 3, 1)``
        view, so ``x`` must support item assignment in that case.
        """
        # The closed-over names are only read, so no ``nonlocal`` is needed.
        if untranspose:
            x[0] = x[0].permute(0, 2, 3, 1)
        inp = x[0].cuda(non_blocking=True).float()
        out = inp.clone()
        inp.add_(shift).mul_(scale)
        if do_low_bit:
            # 5 bits of precision: floor(v / 8) * 8 drops the low 3 bits.
            out.mul_(1. / 8.).floor_().mul_(8.)
        out.add_(shift_loss).mul_(scale_loss)
        return inp, out

    return H, train_data, valid_data, preprocess_func


def mkdir_p(path):
    """Create ``path`` (and any missing parents); do nothing if it exists."""
    os.makedirs(path, exist_ok=True)


def flatten(outer):
    """Concatenate an iterable of iterables into one flat list."""
    return [el for inner in outer for el in inner]


def unpickle_cifar10(file):
    """Load one CIFAR-10 batch file and return it as a dict with ``str`` keys.

    The batch is unpickled with ``encoding='bytes'``, so its keys arrive as
    ``bytes``; they are decoded here while the values are left untouched.

    NOTE: ``pickle.load`` can execute arbitrary code -- only use this on
    trusted dataset files.
    """
    # Context manager guarantees the handle is closed even if loading fails.
    with open(file, 'rb') as fo:
        data = pickle.load(fo, encoding='bytes')
    return dict(zip([k.decode() for k in data.keys()], data.values()))


def _seeded_holdout_split(data, seed, num_valid):
    """Shuffle row indices with a fixed seed and hold out the last ``num_valid``.

    Seeds the *global* NumPy RNG (as the original per-dataset code did), so the
    split is reproducible and later ``np.random`` calls see the same state.

    Returns:
        ``(train, valid)`` arrays selected from ``data`` by fancy indexing.
    """
    np.random.seed(seed)
    order = np.random.permutation(data.shape[0])
    train = data[order[:-num_valid]]
    valid = data[order[-num_valid:]]
    return train, valid


def _load_imagenet(data_root, resolution):
    """Load ``imagenet<resolution>`` splits from ``.npy`` files under ``data_root``."""
    trX = np.load(os.path.join(data_root, 'imagenet%d-train.npy' % resolution), mmap_mode='r')
    train, valid = _seeded_holdout_split(trX, seed=42, num_valid=5000)
    # The file named "valid" on disk serves as the test split.
    test = np.load(os.path.join(data_root, 'imagenet%d-valid.npy' % resolution), mmap_mode='r')
    return train, valid, test


def imagenet32(data_root):
    """Return ``(train, valid, test)`` for ImageNet 32x32.

    5000 shuffled training rows (seed 42) are held out as validation; the test
    split is ``imagenet32-valid.npy``.
    """
    return _load_imagenet(data_root, 32)


def imagenet64(data_root):
    """Return ``(train, valid, test)`` for ImageNet 64x64.

    5000 shuffled training rows (seed 42) are held out as validation; the test
    split is ``imagenet64-valid.npy``.
    """
    return _load_imagenet(data_root, 64)


def ffhq1024(data_root):
    """Return the ``(train, valid, test)`` directory paths for FFHQ-1024.

    The validation directory is returned twice, once as "valid" and once as
    "test".
    """
    # we did not significantly tune hyperparameters on ffhq-1024, and so simply evaluate on the test set
    return (os.path.join(data_root, 'ffhq1024/train'),
            os.path.join(data_root, 'ffhq1024/valid'),
            os.path.join(data_root, 'ffhq1024/valid'))


def ffhq256(data_root):
    """Return ``(train, valid, test)`` for FFHQ-256 from ``ffhq-256.npy``.

    7000 shuffled rows (seed 5) are held out; that held-out set is returned as
    both the validation and the test split.
    """
    trX = np.load(os.path.join(data_root, 'ffhq-256.npy'), mmap_mode='r')
    train, valid = _seeded_holdout_split(trX, seed=5, num_valid=7000)
    # we did not significantly tune hyperparameters on ffhq-256, and so simply evaluate on the test set
    return train, valid, valid


def cifar10(data_root, one_hot=True):
    """Load CIFAR-10 from ``<data_root>/cifar-10-batches-py/``.

    Images are reshaped to ``(-1, 3, 32, 32)`` and transposed to channels-last.
    5000 training examples are split off as validation with a fixed
    ``random_state``.

    Args:
        data_root: directory containing ``cifar-10-batches-py/``.
        one_hot: if true, labels are float32 one-hot rows of width 10;
            otherwise they are reshaped to a column of shape ``(-1, 1)``.

    Returns:
        ``(trX, trY), (vaX, vaY), (teX, teY)``.
    """
    tr_data = [unpickle_cifar10(os.path.join(data_root, 'cifar-10-batches-py/', 'data_batch_%d' % i))
               for i in range(1, 6)]
    # np.vstack needs a sequence; a bare generator is rejected by current NumPy.
    trX = np.vstack([data['data'] for data in tr_data])
    trY = np.asarray(flatten([data['labels'] for data in tr_data]))
    te_data = unpickle_cifar10(os.path.join(data_root, 'cifar-10-batches-py/', 'test_batch'))
    teX = np.asarray(te_data['data'])
    teY = np.asarray(te_data['labels'])
    trX = trX.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    teX = teX.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    trX, vaX, trY, vaY = train_test_split(trX, trY, test_size=5000, random_state=11172018)
    if one_hot:
        trY = np.eye(10, dtype=np.float32)[trY]
        vaY = np.eye(10, dtype=np.float32)[vaY]
        teY = np.eye(10, dtype=np.float32)[teY]
    else:
        trY = np.reshape(trY, [-1, 1])
        vaY = np.reshape(vaY, [-1, 1])
        teY = np.reshape(teY, [-1, 1])
    return (trX, trY), (vaX, vaY), (teX, teY)