def __init__()

in examples/vae/utils/mnist_cached.py [0:0]


    def __init__(self, mode, sup_num, use_cuda=True, *args, **kwargs):
        """Load one MNIST split, transforming it and caching the training splits.

        :param mode: which split this instance exposes; one of ``"sup"``,
            ``"unsup"``, ``"valid"`` (all carved out of the MNIST training set)
            or ``"test"``.
        :param sup_num: number of supervised examples, forwarded to
            ``split_sup_unsup_valid``. ``None`` is only accepted for
            ``mode == "unsup"`` while the supervised cache is still empty; the
            whole training set is then treated as unsupervised.
        :param use_cuda: forwarded to ``fn_x_mnist`` / ``fn_y_mnist``.
        :param args, kwargs: forwarded to the parent dataset constructor.

        The first time a training split is requested (i.e. while
        ``MNISTCached.train_data_sup`` is ``None``), the transformed training
        data is stored on the class and reused by later instances. Test data is
        transformed per instance and never cached.
        """
        # Validate before the parent constructor loads anything, so an invalid
        # mode fails fast instead of after reading the test split from disk.
        assert mode in ["sup", "unsup", "test", "valid"], "invalid train/test option values"
        super().__init__(train=mode in ["sup", "unsup", "valid"], *args, **kwargs)

        self.mode = mode

        if mode == "test":
            # test data is not cached: normalize images and one-hot the labels
            # for this instance only
            self.data = fn_x_mnist(self.data.float(), use_cuda)
            self.targets = fn_y_mnist(self.targets, use_cuda)
            return

        if MNISTCached.train_data_sup is None:
            # Cache is empty: transform the training set once and fill it. On a
            # cache hit this is skipped, because every training mode below
            # replaces self.data / self.targets with cached tensors anyway.
            data = fn_x_mnist(self.data.float(), use_cuda)
            targets = fn_y_mnist(self.targets, use_cuda)
            if sup_num is None:
                assert mode == "unsup"
                MNISTCached.train_data_unsup, MNISTCached.train_labels_unsup = data, targets
            else:
                (MNISTCached.train_data_sup, MNISTCached.train_labels_sup,
                 MNISTCached.train_data_unsup, MNISTCached.train_labels_unsup,
                 MNISTCached.data_valid, MNISTCached.labels_valid) = \
                    split_sup_unsup_valid(data, targets, sup_num)

        if mode == "sup":
            self.data, self.targets = MNISTCached.train_data_sup, MNISTCached.train_labels_sup
        elif mode == "unsup":
            self.data = MNISTCached.train_data_unsup
            # Make sure the unsupervised labels are not available to inference:
            # expose one NaN per example, shape (N, 1).
            self.targets = torch.full((MNISTCached.train_labels_unsup.shape[0], 1), float("nan"))
        else:
            self.data, self.targets = MNISTCached.data_valid, MNISTCached.labels_valid