def __init__()

in datasets.py [0:0]


    def __init__(self, H, logprint):
        """Set up the 64x64 ImageNet splits and their shape metadata.

        Both ``.npy`` files are opened with ``mmap_mode='r'`` and reshaped to
        one flattened image per row. The first ``n_train`` rows of the
        training file become ``trX``; the remaining rows of that same file
        become ``vaX``; the separate "valid" file becomes ``teX``. No labels
        are set (``trY``/``vaY``/``teY`` and ``n_classes`` are ``None``).

        Args:
            H: Hyperparameter object. ``H.n_batch`` and ``H.n_ctx`` are read
                here, and the object is stored as ``self.H``.
            logprint: Stored as ``self.logprint``; not called in this method.

        Raises:
            AssertionError: If ``H.n_ctx`` differs from the flattened image
                length (64 * 64 * 3 = 12288).
        """
        self.logprint = logprint
        self.H = H
        # Whether the full dataset is loaded on each rank, or just its own partition
        self.full_dataset_train = True
        self.full_dataset_valid = True

        # Number of leading rows of the training file used for training.
        n_train = 1231149
        self.n_batch = H.n_batch

        # Image geometry: height x width x channels, and its flattened size.
        self.orig_shape = [-1, 64, 64, 3]
        self.orig_pixels = 64 * 64 * 3
        self.num_embeddings = 3
        self.n_vocab = 256
        self.embedding_sizes = [64, 64, 3]

        # Each row of the loaded arrays is one flattened image, so the row
        # length is derived from orig_pixels rather than repeated as a
        # literal 12288.
        self.ctx = self.orig_pixels
        self.shape = [-1, self.ctx]

        # `mpisize` is a module-level name defined outside this method
        # (presumably the number of MPI ranks -- confirm at its definition).
        self.iters_per_epoch = n_train // (mpisize * self.n_batch)

        # Unlabeled dataset: no targets for any split.
        self.trY = None
        self.vaY = None
        self.teY = None
        self.n_classes = None

        train_file = np.load('/root/data/imagenet64-train.npy', mmap_mode='r')
        tr = train_file.reshape([-1, self.ctx])
        self.trX = tr[:n_train]
        # Validation data is the tail of the training file past n_train.
        self.vaX = tr[n_train:]

        # The file named "valid" is used as the test split.
        valid_file = np.load('/root/data/imagenet64-valid.npy', mmap_mode='r')
        self.teX = valid_file.reshape([-1, self.ctx])

        assert self.ctx == H.n_ctx, f'n_ctx should be {self.ctx}'
        self.initialize_image_embedding()