def get_upstream_preprocessing()

in svhn_based_exp/utils.py [0:0]


def get_upstream_preprocessing(trainloader, testloader, upstream_type, save_dir, device, num_workers=2):
    """Rebuild the train/test loaders over datasets produced by upstream model(s).

    ``upstream_type`` is split on ``"_"`` and consumed in consecutive groups of
    six fields. Each group is passed to ``get_upstream_model`` to build one
    upstream network. Field 2 of each group is collected into ``models`` and
    field 4 into ``repres``. Both lists are forwarded to
    ``construct_tensor_dataset`` (presumably the model name and the
    representation to extract -- confirm against those helpers).

    Args:
        trainloader: DataLoader over the training data. Its ``batch_size`` is
            reused for both returned loaders.
        testloader: DataLoader over the test data. Its first batch is used to
            infer ``linear_base``, the flattened per-sample input size.
            Assumes each batch is indexable with the inputs at position 0 and
            the batch along dim 0 -- TODO confirm for all datasets used.
        upstream_type: ``"_"``-separated description of the upstream models.
            It is expected to contain a multiple of six fields; any trailing
            incomplete group is ignored with a warning.
        save_dir: Forwarded to ``get_upstream_model``.
        device: Forwarded to ``get_upstream_model`` and
            ``construct_tensor_dataset``.
        num_workers: Number of worker processes for each returned DataLoader.

    Returns:
        ``(trainloader, testloader)``: new DataLoaders over the datasets built
        by ``construct_tensor_dataset``. The train loader shuffles; the test
        loader does not.
    """
    fields_per_model = 6
    batch_size = trainloader.batch_size
    upstream_setting = upstream_type.split("_")
    num_upstream_model, leftover = divmod(len(upstream_setting), fields_per_model)
    if leftover:
        import warnings

        warnings.warn(
            f"upstream_type {upstream_type!r} has {len(upstream_setting)} '_'-separated fields, "
            f"not a multiple of {fields_per_model}; the last {leftover} field(s) are ignored.",
            stacklevel=2,
        )

    linear_base = None
    if num_upstream_model:
        # Only touch the test loader when an upstream model actually needs it,
        # and only once: every iter() call may spin up fresh worker processes.
        first_inputs = next(iter(testloader))[0]
        # Flatten with the batch's own leading dimension. The test loader's
        # batch size may differ from the train loader's, and its first batch
        # may be short, so reshaping with trainloader.batch_size is wrong.
        # reshape (unlike view) also tolerates non-contiguous tensors.
        linear_base = first_inputs.reshape(first_inputs.size(0), -1).size(1)

    upstream_nets, repres, models = [], [], []
    for i in range(num_upstream_model):
        setting = upstream_setting[i * fields_per_model : (i + 1) * fields_per_model]
        upstream_nets.append(get_upstream_model(setting, linear_base=linear_base, save_dir=save_dir, device=device))
        repres.append(setting[4])
        models.append(setting[2])

    new_trainset = construct_tensor_dataset(trainloader, upstream_nets, repres, models, batch_size, device=device)
    new_testset = construct_tensor_dataset(testloader, upstream_nets, repres, models, batch_size, device=device)
    trainloader = torch.utils.data.DataLoader(new_trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(new_testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return trainloader, testloader