in datasets.py [0:0]
def __init__(self, H, logprint, *, data_dir='/root/data/cifar10/'):
    """Load CIFAR-10 splits and record the dataset metadata.

    Calls ``cifar10`` to obtain (train, valid, test) arrays and sets the
    shape/vocabulary attributes read elsewhere in the project. It then calls
    ``self.initialize_image_embedding()``.

    Args:
        H: hyperparameter namespace. Reads ``H.datapoints`` (optional cap on
            the number of training examples), ``H.n_batch``, ``H.test_size``
            and ``H.n_ctx``.
        logprint: logging callable; called once when the training set is
            truncated via ``H.datapoints``.
        data_dir: directory handed to ``cifar10``. Keyword-only; defaults to
            the path that was previously hard-coded.

    Raises:
        AssertionError: if ``H.n_ctx`` does not equal 32 * 32 * 3.
    """
    self.logprint = logprint
    self.H = H
    self.full_dataset_train = True
    self.full_dataset_valid = True

    # 45k training examples; the original note says 5k examples are for valid
    # (presumably carved out by the cifar10 loader -- confirm there).
    n_train = 45000
    if H.datapoints:
        n_train = H.datapoints

    # Assigned once: the previous version computed these two attributes twice
    # with identical expressions.
    self.n_batch = H.n_batch
    # `mpisize` is a module-level name defined outside this view (presumably
    # the MPI world size). NOTE(review): if H.datapoints exceeds the number of
    # loaded training rows, this overstates the true iterations per epoch.
    self.iters_per_epoch = n_train // (mpisize * self.n_batch)

    self.orig_shape = [-1, 32, 32, 3]
    self.n_classes = 10
    self.orig_pixels = 32 * 32 * 3
    # Presumably one embedding per image axis (height, width, channel) with
    # 256 possible values per sub-pixel -- confirm against the model code.
    self.num_embeddings = 3
    self.n_vocab = 256
    self.embedding_sizes = [32, 32, 3]

    (self.trX, self.trY), (self.vaX, self.vaY), (self.teX, self.teY) = cifar10(
        data_dir, one_hot=False, test_size=H.test_size)
    if H.datapoints:
        logprint(f'Only using {H.datapoints} examples')
        self.trX = self.trX[:n_train]
        self.trY = self.trY[:n_train]

    # Flattened image length: 3072 == 32 * 32 * 3.
    self.shape = [-1, 3072]
    self.ctx = 32 * 32 * 3
    assert self.ctx == H.n_ctx, f'n_ctx should be {self.ctx}'
    self.initialize_image_embedding()