in svhn_based_exp/utils.py [0:0]
def _select_classes(dataset, classes, balance=False, relabel=False):
    """Restrict ``dataset`` to samples whose label is in ``classes``.

    Args:
        dataset: Dataset exposing numpy ``.labels`` and ``.data`` arrays
            (e.g. torchvision SVHN); it is modified in place.
        classes: Original label values to keep. Kept samples are ordered
            class by class, in the order given here.
        balance: If True, every class is randomly subsampled (without
            replacement, using NumPy's global RNG) down to the size of the
            smallest selected class.
        relabel: If True, samples of the i-th entry of ``classes`` get label i.

    Returns:
        The same dataset object with ``.labels`` and ``.data`` replaced.
    """
    per_class = [np.nonzero(dataset.labels == c)[0] for c in classes]
    if relabel:
        # All index sets were computed before any write, so relabelling one
        # class cannot make its samples match a later class in `classes`.
        for new_label, idx in enumerate(per_class):
            dataset.labels[idx] = new_label
    if balance:
        class_size = min(len(idx) for idx in per_class)
        # One np.random.choice call per class, in `classes` order.
        per_class = [
            idx[np.random.choice(len(idx), class_size, replace=False)]
            for idx in per_class
        ]
    keep = np.concatenate(per_class)
    dataset.labels = dataset.labels[keep]
    dataset.data = dataset.data[keep]
    return dataset


def _remap_labels(dataset, mapping):
    """Rewrite ``dataset.labels`` in place using ``(old, new)`` pairs.

    Pairs are applied one after another, so a later pair sees the result of
    earlier ones (same semantics as a chain of ``labels[labels == old] = new``).
    """
    for old, new in mapping:
        dataset.labels[dataset.labels == old] = new


def _flip_half(labels, a, b):
    """Swap labels ``a`` <-> ``b`` on a random half of the entries equal to ``a`` or ``b``.

    Modifies ``labels`` in place; draws from NumPy's global RNG.
    """
    idx = np.concatenate([np.nonzero(labels == a)[0], np.nonzero(labels == b)[0]])
    idx = idx[np.random.choice(len(idx), int(len(idx) * 0.5), replace=False)]
    # (a + b) - a == b and (a + b) - b == a.
    labels[idx] = (a + b) - labels[idx]


def _subsample_with_label_noise(dataset, fraction):
    """Keep a random ``fraction`` of ``dataset``, injecting label noise past the first 10%.

    A random permutation is drawn; its first 10% is kept with clean labels,
    and the positions from 10% up to ``fraction`` are kept with half of their
    1/3 labels swapped and half of their 7/9 labels swapped. All randomness
    comes from NumPy's global RNG.

    Returns:
        The same dataset object with ``.data``/``.labels`` replaced; clean
        samples come first, followed by the noisy ones.
    """
    perm = np.random.permutation(len(dataset))
    n_clean = int(0.1 * len(perm))
    clean_idx = perm[:n_clean]
    dirty_idx = perm[n_clean:int(fraction * len(perm))]
    dataset.data = dataset.data[np.concatenate([clean_idx, dirty_idx])]
    clean_label = dataset.labels[clean_idx]
    dirty_label = dataset.labels[dirty_idx]  # fancy indexing -> a copy
    _flip_half(dirty_label, 1, 3)
    _flip_half(dirty_label, 7, 9)
    dataset.labels = np.concatenate([clean_label, dirty_label])
    return dataset


def get_dataset(task, para_to_vary_model, seed=0, data_root="~/data"):
    """Build SVHN train/test sets relabelled for one of the experiment tasks.

    Args:
        task: Experiment identifier. Handled values:
            "dd-mis-up"   - all digits merged into two classes:
                            {1, 0, 7, 4, 9} -> 0, {6, 2, 3, 5, 8} -> 1.
            "dd-mis-down" - only digits 1 and 7, relabelled 0 and 1.
            "uu-ent-down" - class-balanced digits 1, 7, 6; 1 -> 1, 6 -> 0, 7 -> 1.
            "uu-ent-up1" / "uu-ent-up2" - class-balanced digits 1, 7, 6;
                            binarized only when ``para_to_vary_model`` is given
                            (up1: 1 -> 1, 6/7 -> 0; up2: 1/6 -> 0, 7 -> 1).
            "lf-mis-up"   - class-balanced digits 1, 3, 7, 9 (labels kept).
            "lf-mis-down" - class-balanced digits 1, 3, 7, 9 relabelled 0..3.
            Any other value leaves the labels untouched ("hs-mis-up" is only
            affected by the training-set subsampling below).
        para_to_vary_model: Optional knob that modifies the training set
            (None disables it):
            uu-ent-up*    - labels of the first ``0.5 * para_to_vary_model``
                            fraction of the ranking in "aul_for_uu-ent.pth"
                            are flipped with ``8 - label``.
            dd-mis-up / hs-mis-up - fraction of the training set kept.
            lf-mis-up     - fraction of the training set kept, with label
                            noise on everything past the first 10%.
        seed: Seed used for the ``para_to_vary_model`` subsampling.
        data_root: ``root`` passed to ``torchvision.datasets.SVHN``.

    Returns:
        ``(trainset, testset)``: torchvision SVHN datasets whose ``.data`` and
        ``.labels`` attributes have been replaced/modified.

    Side effects:
        Downloads SVHN if needed and reseeds NumPy's *global* RNG (seed 0 for
        class balancing, ``seed`` for subsampling).
    """
    trainset = torchvision.datasets.SVHN(root=data_root, split='train', download=True, transform=svhn_transform_train)
    testset = torchvision.datasets.SVHN(root=data_root, split='test', download=True, transform=svhn_transform_test)

    # Modify labels / select classes for the requested task.
    if task == "dd-mis-up":
        for ds in (trainset, testset):
            idx_0 = np.concatenate([np.nonzero(ds.labels == i)[0] for i in [1, 0, 7, 4, 9]])
            idx_1 = np.concatenate([np.nonzero(ds.labels == i)[0] for i in [6, 2, 3, 5, 8]])
            ds.labels[idx_0] = 0
            ds.labels[idx_1] = 1
    elif task in ("uu-ent-down", "uu-ent-up1", "uu-ent-up2"):
        np.random.seed(0)
        trainset = _select_classes(trainset, [1, 7, 6], balance=True)
        testset = _select_classes(testset, [1, 7, 6], balance=True)
        if task == "uu-ent-down":
            for ds in (trainset, testset):
                _remap_labels(ds, ((1, 1), (6, 0), (7, 1)))
        # NOTE(review): "uu-ent-up1"/"uu-ent-up2" are only binarized inside the
        # para_to_vary_model branch below; with None they keep raw labels 1/6/7.
        # Confirm this is intended.
    elif task == "dd-mis-down":
        trainset = _select_classes(trainset, [1, 7], relabel=True)
        testset = _select_classes(testset, [1, 7], relabel=True)
    elif task == "lf-mis-up":
        np.random.seed(0)
        trainset = _select_classes(trainset, [1, 3, 7, 9], balance=True)
        testset = _select_classes(testset, [1, 3, 7, 9], balance=True)
    elif task == "lf-mis-down":
        np.random.seed(0)
        trainset = _select_classes(trainset, [1, 3, 7, 9], balance=True, relabel=True)
        testset = _select_classes(testset, [1, 3, 7, 9], balance=True, relabel=True)

    # Vary the training set.
    if para_to_vary_model is not None:
        if task in ("uu-ent-up1", "uu-ent-up2"):
            FE_rank = torch.load("aul_for_uu-ent.pth")
            selected_idx = FE_rank[:int(0.5 * para_to_vary_model * len(FE_rank))]
            # 8 - label swaps 1 <-> 7. NOTE(review): a selected sample labelled 6
            # would become 2, which the maps below leave untouched; presumably
            # the ranking only covers 1/7 samples. Confirm.
            # `int` replaces the removed `np.int` alias (same dtype).
            trainset.labels[selected_idx] = (8 - trainset.labels[selected_idx]).astype(int)
            if task == "uu-ent-up1":
                mapping = ((1, 1), (6, 0), (7, 0))
            else:
                mapping = ((1, 0), (6, 0), (7, 1))
            for ds in (trainset, testset):
                _remap_labels(ds, mapping)
        elif task in ("dd-mis-up", "hs-mis-up"):
            np.random.seed(seed)
            rand_perm = np.random.permutation(len(trainset))
            sampled_idx = rand_perm[:int(para_to_vary_model * len(trainset))]
            trainset.labels = trainset.labels[sampled_idx]
            trainset.data = trainset.data[sampled_idx]
        elif task == "lf-mis-up":
            np.random.seed(seed)
            trainset = _subsample_with_label_noise(trainset, para_to_vary_model)
    return trainset, testset