def __init__()

in datasets.py [0:0]


    def __init__(self, H, logprint):
        """Configure the ImageNet 32x32 dataset and load its arrays.

        Sets the dataset metadata attributes, validates the configured
        context length, loads the train and validation ``.npy`` files, and
        finally calls ``self.initialize_image_embedding()``.

        Args:
            H: Hyperparameter object. ``H.n_batch`` and ``H.n_ctx`` are read
                here; the object itself is kept as ``self.H``.
            logprint: Stored as ``self.logprint`` (presumably a logging
                callable; it is not invoked in this method).

        Raises:
            AssertionError: If ``H.n_ctx`` differs from the flattened image
                length (3072). This is checked before any file is loaded so
                that a misconfigured run fails without paying the load cost.
        """
        self.logprint = logprint
        self.H = H

        # The training file is expected to hold 1281167 examples; everything
        # past the first (1281167 - 10000) rows is used as the dev split.
        n_total_train = 1281167
        n_dev = 10000
        n_train = n_total_train - n_dev

        height, width, channels = 32, 32, 3

        self.full_dataset_train = True
        self.full_dataset_valid = True
        self.n_batch = H.n_batch

        # No labels are provided for this dataset.
        self.trY = None
        self.vaY = None
        self.teY = None
        self.n_classes = None

        # Image geometry, all derived from the same three dimensions.
        self.embedding_sizes = [height, width, channels]
        self.num_embeddings = len(self.embedding_sizes)
        self.orig_shape = [-1, *self.embedding_sizes]
        self.orig_pixels = height * width * channels
        self.n_vocab = 256

        # mpisize is a module-level name defined outside this method.
        self.iters_per_epoch = n_train // (mpisize * self.n_batch)

        # Each example is one flattened image.
        self.ctx = self.orig_pixels
        self.shape = [-1, self.ctx]
        # Validate the configuration before touching the (large) data files.
        assert self.ctx == H.n_ctx, f'n_ctx should be {self.ctx}'

        def load_flat_channel_last(path):
            # ImageNet32 was saved channel-first (3x32x32), unlike ImageNet64
            # which was saved transposed; convert to 32x32x3 and flatten each
            # image to a row of length self.ctx.
            channel_first = np.load(path).reshape([-1, channels, height, width])
            return channel_first.transpose([0, 2, 3, 1]).reshape([-1, self.ctx])

        train_all = load_flat_channel_last('/root/data/imagenet32-train.npy')
        self.trX = train_all[:n_train]
        self.vaX = train_all[n_train:]
        # The validation file serves as the test split.
        self.teX = load_flat_channel_last('/root/data/imagenet32-valid.npy')

        self.initialize_image_embedding()