in datasets/datasets.py [0:0]
def download_mnist(self, root="/tmp/"):
    """Load the MNIST train/valid/test splits, downloading into ``root`` if needed.

    The download is skipped when a cached copy is already available under
    ``root``, so repeated calls reuse the files on disk.

    Args:
        root: Directory torchvision uses to store/cache the MNIST files.
            Defaults to ``"/tmp/"``, the previously hard-coded location.

    Returns:
        Tuple ``(train_data, valid_data, test_data)``. Each element is the
        value returned by ``self.sample_single_digit`` for that split, sized
        by ``self.train_set_proportion``, ``self.valid_set_proportion`` and
        ``self.test_set_proportion`` respectively.
    """
    # One preprocessing pipeline shared by both splits (ToTensor, then
    # Normalize(mean, std)); 0.1307 / 0.3081 are presumably the MNIST
    # training-set pixel mean/std -- confirm if they are ever changed.
    preprocess = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    train_set = torchvision.datasets.MNIST(
        root, train=True, download=True, transform=preprocess
    )
    test_set = torchvision.datasets.MNIST(
        root, train=False, download=True, transform=preprocess
    )

    # Carve a validation split out of the official training set.
    (
        train_data,
        train_targets,
        valid_data,
        valid_targets,
    ) = ProjectiveMNIST.split_train_valid(train_set)

    # Stratified sub-sampling of each split.
    train_data = self.sample_single_digit(
        train_data, train_targets, self.train_set_proportion
    )
    valid_data = self.sample_single_digit(
        valid_data, valid_targets, self.valid_set_proportion
    )
    # NOTE(review): ``test_set.data`` is the raw stored tensor; torchvision
    # applies ``transform`` only in ``__getitem__``, so ``preprocess`` is NOT
    # applied to the test data read here. Verify that downstream code
    # normalizes it consistently with the train/valid splits.
    test_data = self.sample_single_digit(
        test_set.data, test_set.targets, self.test_set_proportion
    )
    return train_data, valid_data, test_data