def get_dataset()

in svhn_based_exp/utils.py [0:0]


def _select_classes(dataset, classes, balance=False, relabel=False):
    """Restrict ``dataset`` in place to samples whose label is in ``classes``.

    Kept samples are grouped class by class, in the order given by ``classes``.

    Args:
        dataset: object with ``labels`` and ``data`` attributes that support
            NumPy comparison and fancy indexing (e.g. a torchvision SVHN set).
        classes: label values to keep.
        balance: if True, down-sample every class without replacement to the
            size of the smallest one, drawing from the global NumPy RNG.
        relabel: if True, the i-th entry of ``classes`` is relabelled as ``i``.

    Returns:
        The same ``dataset`` object, mutated.
    """
    per_class = [np.nonzero(dataset.labels == c)[0] for c in classes]
    if relabel:
        # Index sets were computed from the original labels above, so a new
        # label written for one class cannot be picked up by another.
        for new_label, idx in enumerate(per_class):
            dataset.labels[idx] = new_label
    if balance:
        class_size = min(len(idx) for idx in per_class)
        per_class = [
            idx[np.random.choice(len(idx), class_size, replace=False)]
            for idx in per_class
        ]
    keep = np.concatenate(per_class)
    dataset.labels = dataset.labels[keep]
    dataset.data = dataset.data[keep]
    return dataset


def _remap_labels(dataset, mapping):
    """Rewrite label values of ``dataset`` in place, one ``old -> new`` rule at a time.

    Rules are applied sequentially in ``mapping`` order, so each rule sees the
    output of the earlier ones; order the dict accordingly.
    """
    for old, new in mapping.items():
        dataset.labels[dataset.labels == old] = new
    return dataset


def _binarize_labels(dataset, group_0, group_1):
    """Relabel classes in ``group_0`` as 0 and classes in ``group_1`` as 1, in place."""
    # Both masks come from the original labels, so the two assignments
    # cannot interfere with each other.
    mask_0 = np.isin(dataset.labels, group_0)
    mask_1 = np.isin(dataset.labels, group_1)
    dataset.labels[mask_0] = 0
    dataset.labels[mask_1] = 1
    return dataset


def _flip_half_of_pair(labels, a, b):
    """Swap ``a`` <-> ``b`` on a random half of the entries of ``labels`` equal to ``a`` or ``b``.

    Mutates ``labels`` in place; the half is drawn without replacement from
    the global NumPy RNG.
    """
    candidates = np.concatenate([np.nonzero(labels == a)[0], np.nonzero(labels == b)[0]])
    chosen = candidates[np.random.choice(len(candidates), int(len(candidates) * 0.5), replace=False)]
    # (a + b) - a == b and (a + b) - b == a.
    labels[chosen] = (a + b) - labels[chosen]


def _subsample_with_label_noise(dataset, fraction, clean_fraction=0.1):
    """Keep a random ``fraction`` of ``dataset``, injecting label noise beyond a clean prefix.

    After a global-RNG permutation, the first ``clean_fraction`` of samples keep
    their labels; the remaining kept samples (up to ``fraction`` of the data)
    have half of their 1/3 labels swapped and half of their 7/9 labels swapped.
    Kept samples are ordered clean first, then noisy.
    """
    rand_perm = np.random.permutation(len(dataset))
    n_clean = int(clean_fraction * len(rand_perm))
    clean_idx = rand_perm[:n_clean]
    dirty_idx = rand_perm[n_clean:int(fraction * len(rand_perm))]
    keep = np.concatenate([clean_idx, dirty_idx])
    dataset.data = dataset.data[keep]
    clean_label = dataset.labels[clean_idx]
    dirty_label = dataset.labels[dirty_idx]
    _flip_half_of_pair(dirty_label, 1, 3)
    _flip_half_of_pair(dirty_label, 7, 9)
    dataset.labels = np.concatenate([clean_label, dirty_label])
    return dataset


def get_dataset(task, para_to_vary_model, seed=0, data_root="~/data"):
    """Build the SVHN train/test sets for one experimental task.

    Both splits are loaded with ``torchvision.datasets.SVHN`` (``download=True``)
    and then filtered/relabelled according to ``task``:

    * ``"dd-mis-up"``: digits {1, 0, 7, 4, 9} -> 0, {6, 2, 3, 5, 8} -> 1.
    * ``"uu-ent-down"`` / ``"uu-ent-up1"`` / ``"uu-ent-up2"``: keep a class-balanced
      subset of digits 1, 7, 6; for ``"uu-ent-down"`` map 6 -> 0 and 1, 7 -> 1.
    * ``"dd-mis-down"``: keep digits 1, 7 relabelled as 0, 1.
    * ``"lf-mis-up"``: keep a class-balanced subset of digits 1, 3, 7, 9.
    * ``"lf-mis-down"``: same subset, relabelled as 0..3.
    * any other task (e.g. ``"hs-mis-up"``): labels left untouched here.

    Args:
        task: task identifier (see above).
        para_to_vary_model: if not None, also varies the training set:
            ``uu-ent-up*`` flip labels of the top-ranked samples and binarize;
            ``dd-mis-up`` / ``hs-mis-up`` keep a random ``para_to_vary_model``
            fraction of the training set; ``lf-mis-up`` keeps that fraction
            with label noise outside a 10% clean prefix.
        seed: seed for the global NumPy RNG used by the training-set variation.
        data_root: root directory passed to torchvision.

    Returns:
        ``(trainset, testset)``.

    Side effects:
        Reseeds the global NumPy RNG (``np.random.seed``) in several branches.
    """
    trainset = torchvision.datasets.SVHN(root=data_root, split='train', download=True, transform=svhn_transform_train)
    testset = torchvision.datasets.SVHN(root=data_root, split='test', download=True, transform=svhn_transform_test)
    splits = (trainset, testset)

    # Task-specific class selection / relabelling; train is always processed
    # before test so the RNG draws match across runs.
    if task == "dd-mis-up":
        for dataset in splits:
            _binarize_labels(dataset, [1, 0, 7, 4, 9], [6, 2, 3, 5, 8])
    elif task in ("uu-ent-down", "uu-ent-up1", "uu-ent-up2"):
        np.random.seed(0)
        for dataset in splits:
            _select_classes(dataset, [1, 7, 6], balance=True)
        if task == "uu-ent-down":
            for dataset in splits:
                _remap_labels(dataset, {1: 1, 6: 0, 7: 1})
        # NOTE(review): for "uu-ent-up1"/"uu-ent-up2" the labels are only
        # binarized below when para_to_vary_model is not None; otherwise they
        # stay as raw digits 1/6/7 — confirm this is intended.
    elif task == "dd-mis-down":
        for dataset in splits:
            _select_classes(dataset, [1, 7], relabel=True)
    elif task == "lf-mis-up":
        np.random.seed(0)
        for dataset in splits:
            _select_classes(dataset, [1, 3, 7, 9], balance=True)
    elif task == "lf-mis-down":
        np.random.seed(0)
        for dataset in splits:
            _select_classes(dataset, [1, 3, 7, 9], balance=True, relabel=True)

    # Vary the training set.
    if para_to_vary_model is not None:
        if task in ("uu-ent-up1", "uu-ent-up2"):
            fe_rank = torch.load("aul_for_uu-ent.pth")
            flipped = fe_rank[:int(0.5 * para_to_vary_model * len(fe_rank))]
            # 8 - label swaps digits 1 <-> 7. The result is written back into
            # the existing integer label array, so no dtype cast is needed
            # (the former np.int alias no longer exists in NumPy >= 1.24).
            # NOTE(review): 8 - 6 == 2, which no mapping below remaps;
            # presumably fe_rank only indexes class-1/7 samples — confirm.
            trainset.labels[flipped] = 8 - trainset.labels[flipped]
            if task == "uu-ent-up1":
                mapping = {1: 1, 6: 0, 7: 0}
            else:
                mapping = {1: 0, 6: 0, 7: 1}
            for dataset in splits:
                _remap_labels(dataset, mapping)
        elif task in ("dd-mis-up", "hs-mis-up"):
            np.random.seed(seed)
            rand_perm = np.random.permutation(len(trainset))
            keep = rand_perm[:int(para_to_vary_model * len(trainset))]
            trainset.labels = trainset.labels[keep]
            trainset.data = trainset.data[keep]
        elif task == "lf-mis-up":
            np.random.seed(seed)
            trainset = _subsample_with_label_noise(trainset, para_to_vary_model)
    return trainset, testset