in datasets.py [0:0]
def __init__(self, H, logprint):
    """Load the 32x32 ImageNet arrays and set up the dataset metadata.

    Reads the train and valid ``.npy`` files, converts every image from the
    stored 3x32x32 layout to 32x32x3, flattens it to a row of 3072 values,
    and holds out the last 10k training examples as the dev split.

    Args:
        H: hyperparameter object; ``H.n_batch`` and ``H.n_ctx`` are read here.
        logprint: logging callable, stored on the instance for later use.

    Raises:
        AssertionError: if ``H.n_ctx`` is not 3072.
    """
    self.logprint = logprint
    self.H = H

    # Length of one flattened image (32 * 32 * 3 values).
    self.ctx = 3072
    # Check the configuration before reading the .npy files so that a
    # misconfigured run fails immediately rather than after the load.
    assert self.ctx == H.n_ctx, f'n_ctx should be {self.ctx}'
    self.shape = [-1, self.ctx]

    # 1281167 << dataset has this many examples
    # We will use 10k examples for dev
    n_train = 1281167 - 10000
    self.full_dataset_train = True
    self.full_dataset_valid = True
    self.n_batch = H.n_batch
    self.orig_shape = [-1, 32, 32, 3]
    # No labels are loaded for this dataset.
    self.trY = None
    self.vaY = None
    self.teY = None
    self.n_classes = None
    self.orig_pixels = 32 * 32 * 3
    self.num_embeddings = 3
    self.n_vocab = 256
    self.embedding_sizes = [32, 32, 3]
    # `mpisize` is a module-level name defined outside this method.
    self.iters_per_epoch = n_train // (mpisize * self.n_batch)

    def load_flat_hwc(path):
        # we are dumb and saved imagenet32 in 3x32x32, unlike ImageNet64,
        # which we saved in transposed format, sorry about the inconsistency.
        # Move channels last, then flatten each image to one row.
        chw = np.load(path).reshape([-1, 3, 32, 32])
        return chw.transpose([0, 2, 3, 1]).reshape([-1, 3072])

    tr = load_flat_hwc('/root/data/imagenet32-train.npy')
    self.trX = tr[:n_train]
    # Everything past the first n_train rows becomes the dev split.
    self.vaX = tr[n_train:]
    # The official validation file is used as the test split.
    self.teX = load_flat_hwc('/root/data/imagenet32-valid.npy')

    self.initialize_image_embedding()