in sample_info/modules/data_utils.py [0:0]
def _parse_mnist(self, args, build_loaders=True):
    """Build the binary (4-vs-9) MNIST datasets — or loaders — from a config dict.

    The incoming ``args`` dict is deep-copied, so the caller's copy is never
    mutated.  Recognized keys (all read via ``get``/``pop``):

      * ``num_train_examples`` — if set, randomly subsample both the train and
        the validation split down to at most that many examples.
      * ``seed`` (default 42) — seeds both the subsampling and the label noise.
      * ``error_prob`` (default 0.0) — if positive, inject uniform label noise
        into the train split; ``info`` then becomes the train corruption mask.
      * ``clean_validation`` (default True) — when False, the validation split
        is corrupted with the same noise parameters as the train split.

    Returns ``(train, val, test, info)`` — datasets when ``build_loaders`` is
    False, data loaders otherwise.
    """
    args = copy.deepcopy(args)
    max_examples = args.pop('num_train_examples', None)
    seed = args.get('seed', 42)

    # NOTE(review): args is intentionally fed to both the builder constructor
    # and build_datasets — presumably each picks out the keys it understands.
    from nnlib.nnlib.data_utils.mnist import MNIST
    builder = MNIST(**args)
    train_data, val_data, test_data, info = builder.build_datasets(**args)

    # Restrict every split to the two-class 4-vs-9 problem.
    train_data = BinaryDatasetWrapper(train_data, which_labels=(4, 9))
    val_data = BinaryDatasetWrapper(val_data, which_labels=(4, 9))
    test_data = BinaryDatasetWrapper(test_data, which_labels=(4, 9))

    # Optionally cap the train and validation splits at max_examples each.
    if max_examples is not None:
        np.random.seed(seed)
        if len(train_data) > max_examples:
            keep = np.random.choice(len(train_data), size=max_examples, replace=False)
            train_data = SubsetDataWrapper(train_data, include_indices=keep)
        if len(val_data) > max_examples:
            keep = np.random.choice(len(val_data), size=max_examples, replace=False)
            val_data = SubsetDataWrapper(val_data, include_indices=keep)

    # Optionally corrupt labels; the corruption mask replaces the builder info.
    error_prob = args.get('error_prob', 0.0)
    if error_prob > 0.0:
        train_data = UniformNoiseWrapper(train_data, error_prob=error_prob,
                                         num_classes=2, seed=seed)
        info = train_data.is_corrupted
        if not args.get('clean_validation', True):
            val_data = UniformNoiseWrapper(val_data, error_prob=error_prob,
                                           num_classes=2, seed=seed)

    if not build_loaders:
        return train_data, val_data, test_data, info

    train_loader, val_loader, test_loader = nnlib.nnlib.data_utils.base.get_loaders_from_datasets(
        train_data=train_data, val_data=val_data, test_data=test_data, **args)
    return train_loader, val_loader, test_loader, info