datasets.py (270 lines of code) (raw):

import os
import pickle

import imageio
import numpy as np
try:
    # Old scikit-learn releases expose train_test_split here.
    from sklearn.cross_validation import train_test_split
except ModuleNotFoundError:
    from sklearn.model_selection import train_test_split

from mpi_utils import mpi_size, mpi_rank
from janky_stuff import JankySubsampler

# World size and rank of this process, as reported by mpi_utils at import time.
mpisize = mpi_size()
mpirank = mpi_rank()


def get_dataset(name):
    """Return the dataset class registered under *name*.

    Known names are 'cifar10', 'imagenet64' and 'imagenet32'; any other
    name raises KeyError.
    """
    return {
        'cifar10': Cifar10,
        'imagenet64': Imagenet64,
        'imagenet32': Imagenet32,
    }[name]


def tile_images(images, d1=4, d2=4, border=1):
    """Lay out d1 * d2 equally-shaped images on a white grid.

    Args:
        images: sequence of exactly d1 * d2 arrays; the shape of the first
            one (rows, cols, channels) is used for every cell.
        d1: number of grid rows.
        d2: number of grid columns.
        border: width in pixels of the white (255) margin around each cell.

    Returns:
        A uint8 array of shape
        [d1 * rows + border * (d1 + 1), d2 * cols + border * (d2 + 1), channels].

    Raises:
        ValueError: if len(images) != d1 * d2.
    """
    # Validate before touching images[0] or allocating the canvas.
    if len(images) != d1 * d2:
        raise ValueError('Wrong num of images')
    id1, id2, c = images[0].shape
    out = np.full(
        [d1 * id1 + border * (d1 + 1), d2 * id2 + border * (d2 + 1), c],
        255,
        dtype=np.uint8,
    )
    for imgnum, im in enumerate(images):
        # Row-major placement: the column index varies fastest.
        num_d1, num_d2 = divmod(imgnum, d2)
        start_d1 = num_d1 * id1 + border * (num_d1 + 1)
        start_d2 = num_d2 * id2 + border * (num_d2 + 1)
        out[start_d1:start_d1 + id1, start_d2:start_d2 + id2, :] = im
    return out


def iter_data_mpi(*args, n_batch, log, shuffle=False, iters=None, seed=None, split_by_rank=True):
    """Yield minibatches from the arrays in *args*.

    When split_by_rank is true, every global batch of mpisize * n_batch
    examples is divided across ranks and this process yields only the slice
    belonging to mpirank; otherwise the data is iterated as if there were a
    single rank.

    Args:
        *args: arrays sharing the same first dimension.
        n_batch: number of examples per yielded minibatch.
        log: callable taking a message string.
        shuffle: if true, iterate in a random order drawn from np.random.
        iters: optional cap on the number of minibatches (see note below).
        seed: if not None, seeds the global np.random state before shuffling.
        split_by_rank: whether to partition batches across MPI ranks.

    Yields:
        A list with one indexed array per element of *args*.

    Raises:
        ValueError: if no arrays are given or their first dimensions differ.
    """
    if not args:
        raise ValueError('iter_data_mpi requires at least one array')
    size = args[0].shape[0]
    for idx, arr in enumerate(args[1:], start=1):
        if arr.shape[0] != size:
            raise ValueError(f'mismatch in arg {idx}, shape {arr.shape[0]} vs {size}')
    # `is not None` rather than truthiness, so that seed=0 is honoured.
    if seed is not None:
        np.random.seed(seed)
    if shuffle:
        idxs = np.random.permutation(np.arange(size))
    else:
        idxs = np.arange(size)
    if split_by_rank:
        ms, mr = mpisize, mpirank
    else:
        ms, mr = 1, 0
    # Truncate the data if it does not divide evenly
    sequences_per_batch = ms * n_batch
    length = (idxs.size // sequences_per_batch) * sequences_per_batch
    if length != idxs.size:
        log('Truncating {}/{} sequences'.format(idxs.size - length, idxs.size))
    idxs = idxs[:length]
    # Reshape starting indices to K*mpi_size*n_batch
    idxs = idxs.reshape([-1, ms, n_batch])
    log(f'Number of minibatches in this dataset: {len(idxs)}')
    for mb_idx, per_rank in enumerate(idxs):
        indices = per_rank[mr]
        yield [t[indices] for t in args]
        # NOTE(review): the check runs after the yield and uses `>`, so up to
        # iters + 2 minibatches are produced. Left unchanged to preserve the
        # behaviour existing callers see; confirm whether this is intended.
        if iters and mb_idx > iters:
            break


class ImageDataset:
    """Non-jpeg images.

    Base class with helpers shared by the image datasets below. Subclasses
    are expected to set H, logprint, orig_shape, embedding_sizes and ctx.
    """

    def decode(self, samples, logname):
        """Tile *samples* into one PNG written to H.model_dir and log the path.

        The grid is at most H.sample_grid_dim columns wide. Samples that do
        not fill a complete final row are dropped so the grid is rectangular.

        Raises:
            ValueError: if there are no samples.
        """
        H = self.H
        out_samples = self.samples_to_image(samples)
        n_examples = out_samples.shape[0]
        if n_examples == 0:
            raise ValueError('No samples to decode')
        d2 = min(H.sample_grid_dim, n_examples)
        d1 = n_examples // d2
        # tile_images requires exactly d1 * d2 images; drop the remainder.
        tiled_image = tile_images(out_samples[:d1 * d2], d1=d1, d2=d2)
        imname = f'{H.desc}-samples-{logname}.png'
        out_path = os.path.join(H.model_dir, imname)
        imageio.imwrite(out_path, tiled_image)
        self.logprint(f'Saved samples in file {out_path}')

    def initialize_image_embedding(self):
        """Set self.x_emb to the coordinates of every flattened position.

        The result has shape [1, 3, self.ctx]; along axis 1 it holds the
        index into each of the three axes of self.embedding_sizes, with the
        last axis varying fastest.
        """
        w, h, c = self.embedding_sizes
        # np.indices enumerates (i, j, k) in the same order as three nested
        # loops over w, h and c, already laid out as one row per axis.
        self.x_emb = np.indices((w, h, c)).reshape([1, 3, self.ctx])

    def samples_to_image(self, samples):
        """Reshape flat samples to self.orig_shape."""
        return samples.reshape(self.orig_shape)


class JankySubsampledDataset(ImageDataset):
    """Training data drawn from several datasets according to a pmf.

    The first dataset is the main one: its validation/test data and
    bookkeeping attributes are exposed on this object. All datasets must
    agree on their shape and embedding attributes.
    """

    def __init__(self, datasets, pmf, seed=42):
        """
        Args:
            datasets: sequence of dataset objects; datasets[0] is the main one.
            pmf: one entry per dataset, passed on to JankySubsampler.
            seed: seed passed on to JankySubsampler; must not be None.

        Raises:
            ValueError: if seed is None.
            AssertionError: if a sanity check on the datasets fails.
        """
        assert len(pmf) == len(datasets)
        if seed is None:
            raise ValueError("seed can't be None")
        self.datasets = datasets
        self.pmf = pmf
        # Some basic sanity-checks.
        attrs = (
            "orig_shape",
            "shape",
            "ctx",
            "num_embeddings",
            "embedding_sizes",
            "n_vocab",
            "x_emb",
        )
        for attr in attrs:
            assert hasattr(self.ref, attr), f"{attr} is missing in the main dataset."
            ref_attr = getattr(self.ref, attr)
            setattr(self, attr, ref_attr)
            for oth in self.oth:
                assert hasattr(oth, attr), f"{attr} is missing in the auxiliary dataset"
                oth_attr = getattr(oth, attr)
                assert type(ref_attr) is type(oth_attr)
                if isinstance(ref_attr, np.ndarray):
                    assert (ref_attr == oth_attr).all(), f"expected {attr} to be the same."
                else:
                    assert ref_attr == oth_attr, f"expected {attr} to be the same."
        # Perform model selection and evaluation using the main dataset.
        attrs = (
            "H",
            "logprint",
            "vaX",
            "vaY",
            "teX",
            "teY",
            "n_classes",
            "full_dataset_valid",
            "full_dataset_train",
            "iters_per_epoch",
        )
        for attr in attrs:
            # Attributes absent on the main dataset are copied as None.
            setattr(self, attr, getattr(self.ref, attr, None))
        trX = [ds.trX for ds in datasets]
        # One column per example holding the index of the dataset it came from.
        auxX = [np.zeros_like(tr[:, 0:1]) + idx for idx, tr in enumerate(trX)]
        # Both samplers receive the same pmf and seed.
        self.trX = JankySubsampler(trX, pmf, seed=seed)
        self.auxX = JankySubsampler(auxX, pmf, seed=seed)

    @property
    def ref(self):
        """The main (first) dataset."""
        return self.datasets[0]

    @property
    def oth(self):
        """The auxiliary datasets (all but the first)."""
        return self.datasets[1:]


class Imagenet64(ImageDataset):
    """To download, if your data dir is /root/data:

    mkdir -p /root/data
    cd /root/data
    wget https://openaipublic.blob.core.windows.net/distribution-augmentation-assets/imagenet64-train.npy
    wget https://openaipublic.blob.core.windows.net/distribution-augmentation-assets/imagenet64-valid.npy
    """

    def __init__(self, H, logprint):
        """Memory-map the imagenet64 arrays from /root/data and set metadata.

        Args:
            H: hyperparameter object; n_batch and n_ctx are read here.
            logprint: logging callable stored on the instance.

        Raises:
            AssertionError: if H.n_ctx is not 12288.
        """
        self.logprint = logprint
        self.H = H
        # Whether the full dataset is loaded on each rank, or just its own partition
        self.full_dataset_train = True
        self.full_dataset_valid = True
        n_train = 1231149
        self.n_batch = H.n_batch
        self.orig_shape = [-1, 64, 64, 3]
        self.orig_pixels = 64 * 64 * 3
        self.num_embeddings = 3
        self.n_vocab = 256
        self.embedding_sizes = [64, 64, 3]
        self.iters_per_epoch = n_train // (mpisize * self.n_batch)
        tr = np.load('/root/data/imagenet64-train.npy', mmap_mode='r').reshape([-1, 12288])
        # The first n_train rows are used for training, the rest for validation.
        self.trX = tr[:n_train]
        self.vaX = tr[n_train:]
        self.teX = np.load('/root/data/imagenet64-valid.npy', mmap_mode='r').reshape([-1, 12288])
        # No labels are provided for this dataset.
        self.trY = None
        self.vaY = None
        self.teY = None
        self.n_classes = None
        self.ctx = 12288
        self.shape = [-1, self.ctx]
        assert self.ctx == H.n_ctx, f'n_ctx should be {self.ctx}'
        self.initialize_image_embedding()


class Imagenet32(Imagenet64):
    """To download, if your data dir is /root/data:

    mkdir -p /root/data
    cd /root/data
    wget https://openaipublic.blob.core.windows.net/distribution-augmentation-assets/imagenet32-train.npy
    wget https://openaipublic.blob.core.windows.net/distribution-augmentation-assets/imagenet32-valid.npy
    """

    def __init__(self, H, logprint):
        """Load the imagenet32 arrays from /root/data and set metadata.

        The parent initializer is deliberately not called; every attribute
        is set here for the 32x32 data.

        Args:
            H: hyperparameter object; n_batch and n_ctx are read here.
            logprint: logging callable stored on the instance.

        Raises:
            AssertionError: if H.n_ctx is not 3072.
        """
        self.logprint = logprint
        self.H = H
        # 1281167 << dataset has this many examples
        # We will use 10k examples for dev
        n_train = 1281167 - 10000
        self.full_dataset_train = True
        self.full_dataset_valid = True
        self.n_batch = H.n_batch
        self.orig_shape = [-1, 32, 32, 3]
        # No labels are provided for this dataset.
        self.trY = None
        self.vaY = None
        self.teY = None
        self.n_classes = None
        self.orig_pixels = 32 * 32 * 3
        self.num_embeddings = 3
        self.n_vocab = 256
        self.embedding_sizes = [32, 32, 3]
        self.iters_per_epoch = n_train // (mpisize * self.n_batch)
        # imagenet32 was saved as 3x32x32, unlike imagenet64 which was saved
        # in transposed format, so it is transposed to 32x32x3 here.
        tr = np.load('/root/data/imagenet32-train.npy').reshape([-1, 3, 32, 32]).transpose(
            [0, 2, 3, 1]).reshape([-1, 3072])
        self.trX = tr[:n_train]
        self.vaX = tr[n_train:]
        self.teX = np.load('/root/data/imagenet32-valid.npy').reshape([-1, 3, 32, 32]).transpose(
            [0, 2, 3, 1]).reshape([-1, 3072])
        self.ctx = 3072
        self.shape = [-1, self.ctx]
        assert self.ctx == H.n_ctx, f'n_ctx should be {self.ctx}'
        self.initialize_image_embedding()


def flatten(outer):
    """Concatenate the elements of an iterable of iterables into one list."""
    return [el for inner in outer for el in inner]


def unpickle_cifar10(file):
    """Load one pickled CIFAR-10 batch and return it with str keys.

    The pickle is read with encoding='bytes', so its keys are bytes; they
    are decoded to str here while the values are left untouched.
    """
    # NOTE: pickle.load can execute arbitrary code; only use on trusted files.
    with open(file, 'rb') as fo:
        data = pickle.load(fo, encoding='bytes')
    return {k.decode(): v for k, v in data.items()}


def cifar10(data_dir, one_hot=True, test_size=None):
    """Load CIFAR-10 from *data_dir* and split off a validation set.

    Args:
        data_dir: directory containing data_batch_1..5 and test_batch.
        one_hot: if true, labels are float32 one-hot rows of width 10;
            otherwise they are reshaped to a single column.
        test_size: validation size handed to train_test_split; a falsy
            value means 5000.

    Returns:
        ((trX, trY), (vaX, vaY), (teX, teY)) where each X has rows of 3072
        values obtained by moving the channel axis last.
    """
    test_size = test_size or 5000
    tr_data = [unpickle_cifar10(os.path.join(data_dir, 'data_batch_%d' % i)) for i in range(1, 6)]
    # np.vstack needs a real sequence; newer NumPy rejects generators.
    trX = np.vstack([data['data'] for data in tr_data])
    trY = np.asarray(flatten([data['labels'] for data in tr_data]))
    te_data = unpickle_cifar10(os.path.join(data_dir, 'test_batch'))
    teX = np.asarray(te_data['data'])
    teY = np.asarray(te_data['labels'])
    # Rows are reshaped as 3x32x32 and transposed to 32x32x3 before flattening.
    trX = trX.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).reshape([-1, 3072])
    teX = teX.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).reshape([-1, 3072])
    # Fixed random_state keeps the train/validation split reproducible.
    trX, vaX, trY, vaY = train_test_split(trX, trY, test_size=test_size, random_state=11172018)
    if one_hot:
        eye = np.eye(10, dtype=np.float32)
        trY, vaY, teY = eye[trY], eye[vaY], eye[teY]
    else:
        trY = np.reshape(trY, [-1, 1])
        vaY = np.reshape(vaY, [-1, 1])
        teY = np.reshape(teY, [-1, 1])
    return (trX, trY), (vaX, vaY), (teX, teY)


class Cifar10(ImageDataset):
    """CIFAR-10 loaded from /root/data/cifar10/ with integer column labels."""

    def __init__(self, H, logprint):
        """
        Args:
            H: hyperparameter object; datapoints, n_batch, test_size and
                n_ctx are read here.
            logprint: logging callable stored on the instance.

        Raises:
            AssertionError: if H.n_ctx is not 3072.
        """
        self.logprint = logprint
        self.H = H
        self.full_dataset_train = True
        self.full_dataset_valid = True
        # 5k examples for valid
        # NOTE(review): 45000 matches the default test_size of 5000; it is
        # not adjusted when H.test_size differs - confirm this is intended.
        n_train = 45000
        if H.datapoints:
            n_train = H.datapoints
        self.n_batch = H.n_batch
        self.iters_per_epoch = n_train // (mpisize * self.n_batch)
        self.orig_shape = [-1, 32, 32, 3]
        self.n_classes = 10
        self.orig_pixels = 32 * 32 * 3
        self.num_embeddings = 3
        self.n_vocab = 256
        self.embedding_sizes = [32, 32, 3]
        (self.trX, self.trY), (self.vaX, self.vaY), (self.teX, self.teY) = cifar10(
            '/root/data/cifar10/', one_hot=False, test_size=H.test_size)
        if H.datapoints:
            logprint(f'Only using {H.datapoints} examples')
            self.trX = self.trX[:n_train]
            self.trY = self.trY[:n_train]
        self.shape = [-1, 3072]
        self.ctx = 32 * 32 * 3
        assert self.ctx == H.n_ctx, f'n_ctx should be {self.ctx}'
        self.initialize_image_embedding()

    def preprocess(self, arr):
        """Reinterpret rows as 3x32x32, move channels last, and flatten again."""
        arr = arr.reshape([-1, 3, 32, 32])
        arr = arr.transpose([0, 2, 3, 1])
        return arr.reshape([-1, 3072])