def __init__()

in datasets.py [0:0]


    def __init__(self, H, logprint):
        """Configure the CIFAR-10 dataset wrapper and load its splits.

        Args:
            H: hyperparameter object. This method reads ``H.datapoints``,
                ``H.n_batch``, ``H.test_size`` and ``H.n_ctx``.
            logprint: logging callable; called with a message string when
                ``H.datapoints`` truncates the training set.

        Raises:
            AssertionError: if ``H.n_ctx`` is not 32 * 32 * 3 (3072).
        """
        self.logprint = logprint
        self.H = H
        self.full_dataset_train = True
        self.full_dataset_valid = True

        # Number of values in one flattened 32x32x3 image; shared by
        # orig_pixels and ctx so the two cannot drift apart.
        n_pixels = 32 * 32 * 3

        # Validate the context length before the (expensive) dataset load.
        # Explicit raise instead of `assert` so the check survives `python -O`;
        # AssertionError is kept for compatibility with the original behaviour.
        self.ctx = n_pixels
        if self.ctx != H.n_ctx:
            raise AssertionError(f'n_ctx should be {self.ctx}')

        # 5k examples for valid
        # (presumably CIFAR-10's 50k training images split 45k/5k -- confirm
        # against cifar10()).
        n_train = 45000
        if H.datapoints:
            n_train = H.datapoints
        self.n_batch = H.n_batch
        # `mpisize` is a module-level name defined outside this view;
        # presumably the number of MPI workers -- TODO confirm.
        self.iters_per_epoch = n_train // (mpisize * self.n_batch)

        self.orig_shape = [-1, 32, 32, 3]
        self.n_classes = 10
        self.orig_pixels = n_pixels
        self.num_embeddings = 3
        self.n_vocab = 256
        self.embedding_sizes = [32, 32, 3]

        (self.trX, self.trY), (self.vaX, self.vaY), (self.teX, self.teY) = cifar10('/root/data/cifar10/', one_hot=False, test_size=H.test_size)
        if H.datapoints:
            logprint(f'Only using {H.datapoints} examples')
            # NOTE(review): if H.datapoints exceeds the loaded training set,
            # the slices below keep every example, but iters_per_epoch above
            # is still derived from H.datapoints.
            self.trX = self.trX[:n_train]
            self.trY = self.trY[:n_train]
        self.shape = [-1, 3072]
        self.initialize_image_embedding()