def _parse_mnist()

in sample_info/modules/data_utils.py [0:0]


    def _parse_mnist(self, args, build_loaders=True):
        """Build a binary MNIST task (digits 4 vs 9), optionally with label noise.

        Args:
            args: dict of dataset options. It is deep-copied, so the caller's
                dict is never modified. Keys read here:
                  * 'num_train_examples' (popped; default None): if set, the
                    train and validation sets are each randomly subsampled,
                    without replacement, to at most this many examples.
                  * 'seed' (default 42): seeds the subsampling and the label
                    noise.
                  * 'error_prob' (default 0.0): if > 0, training labels are
                    corrupted with UniformNoiseWrapper.
                  * 'clean_validation' (default True): if False and
                    error_prob > 0, validation labels are corrupted too.
                Every key except 'num_train_examples' (including 'seed',
                'error_prob' and 'clean_validation') is forwarded to
                MNIST(...), build_datasets(...) and, when loaders are built,
                get_loaders_from_datasets(...).
            build_loaders: if False, return the datasets instead of loaders.

        Returns:
            Tuple (train, val, test, info). The first three are datasets when
            build_loaders is False, otherwise the loaders returned by
            get_loaders_from_datasets. info is the value returned by
            build_datasets, replaced by train_data.is_corrupted when label
            noise is applied.
        """
        args = copy.deepcopy(args)
        num_train_examples = args.pop('num_train_examples', None)
        # args is not mutated after this point, so these can be read up front.
        seed = args.get('seed', 42)
        error_prob = args.get('error_prob', 0.0)

        from nnlib.nnlib.data_utils.mnist import MNIST
        data_builder = MNIST(**args)

        train_data, val_data, test_data, info = data_builder.build_datasets(**args)
        # Keep only digits 4 and 9. num_classes=2 below presumes the wrapper
        # maps them onto two classes -- NOTE(review): confirm in BinaryDatasetWrapper.
        which_labels = (4, 9)
        train_data = BinaryDatasetWrapper(train_data, which_labels=which_labels)
        val_data = BinaryDatasetWrapper(val_data, which_labels=which_labels)
        test_data = BinaryDatasetWrapper(test_data, which_labels=which_labels)

        # Trim down the validation and training sets to num_train_examples.
        if num_train_examples is not None:
            # Use a local RNG rather than np.random.seed(): it draws exactly the
            # same indices but does not clobber the process-wide NumPy RNG state.
            rng = np.random.RandomState(seed)

            def _subsample(dataset):
                # Return a random subset of at most num_train_examples items.
                if len(dataset) <= num_train_examples:
                    return dataset
                indices = rng.choice(len(dataset), size=num_train_examples, replace=False)
                return SubsetDataWrapper(dataset, include_indices=indices)

            # Order matters: train is drawn before val from the same RNG stream.
            train_data = _subsample(train_data)
            val_data = _subsample(val_data)

        # Add label noise.
        if error_prob > 0.0:
            train_data = UniformNoiseWrapper(train_data, error_prob=error_prob, num_classes=2, seed=seed)
            info = train_data.is_corrupted
            clean_validation = args.get('clean_validation', True)
            if not clean_validation:
                val_data = UniformNoiseWrapper(val_data, error_prob=error_prob, num_classes=2, seed=seed)

        if not build_loaders:
            return train_data, val_data, test_data, info

        # NOTE(review): relies on a module-level `nnlib` import not visible here.
        train_loader, val_loader, test_loader = nnlib.nnlib.data_utils.base.get_loaders_from_datasets(
            train_data=train_data, val_data=val_data, test_data=test_data, **args)

        return train_loader, val_loader, test_loader, info