in datasets.py [0:0]
def __init__(self, H, logprint):
    """Set up the 64x64 ImageNet dataset from memory-mapped ``.npy`` files.

    The train file is flattened to rows of 12288 values; its first
    ``n_train`` rows become the training split and the remaining rows the
    validation split. The separate "valid" file is used as the test split.
    No labels are set (all ``*Y`` attributes and ``n_classes`` are ``None``).

    Args:
        H: Hyperparameter object; ``H.n_batch`` and ``H.n_ctx`` are read here.
        logprint: Stored as-is on the instance (presumably a logging
            callable -- confirm against callers).

    Raises:
        AssertionError: If ``H.n_ctx`` is not equal to the flattened image
            size (12288).
    """
    self.H = H
    self.logprint = logprint
    self.n_batch = H.n_batch

    # Whether the full dataset is loaded on each rank, or just its own partition
    self.full_dataset_train = self.full_dataset_valid = True

    # Image geometry: every sample is a 64x64x3 image stored flattened.
    height, width, channels = 64, 64, 3
    flat_pixels = height * width * channels  # 12288
    self.orig_shape = [-1, height, width, channels]
    self.orig_pixels = flat_pixels
    self.num_embeddings = 3
    self.embedding_sizes = [height, width, channels]
    self.n_vocab = 256

    # Number of leading rows of the train file used for training; the rest
    # of that file is held out as the validation split below.
    n_train = 1231149
    # `mpisize` comes from outside this method (presumably the number of
    # MPI ranks -- confirm where it is defined).
    self.iters_per_epoch = n_train // (mpisize * self.n_batch)

    # mmap_mode='r' maps the file read-only instead of reading it into memory.
    train_file = np.load('/root/data/imagenet64-train.npy', mmap_mode='r')
    flat_train = train_file.reshape([-1, flat_pixels])
    self.trX = flat_train[:n_train]
    self.vaX = flat_train[n_train:]

    # No labels or class count are set for any split.
    self.trY = self.vaY = self.teY = None
    self.n_classes = None

    valid_file = np.load('/root/data/imagenet64-valid.npy', mmap_mode='r')
    self.teX = valid_file.reshape([-1, flat_pixels])

    self.ctx = flat_pixels
    self.shape = [-1, self.ctx]
    assert self.ctx == H.n_ctx, f'n_ctx should be {self.ctx}'
    self.initialize_image_embedding()