def test_by_nn()

in svhn_based_exp/utils.py [0:0]


def test_by_nn(testloader, net, model, logger, device):
    """Evaluate ``net``'s features with a 1-nearest-neighbour classifier.

    A deep copy of ``net`` has its classification head replaced by
    ``nn.Identity`` (``fc1`` when ``model`` starts with ``"convnet"``, ``fc``
    otherwise), so its forward pass returns features. Features and labels are
    collected for every sample yielded by ``testloader``. The samples are then
    split in half with a fixed seed (0) into a reference ("val") half and a
    query ("test") half. Each test sample receives the label of the val sample
    at the smallest value in its row of ``get_dist(test_feat, val_feat)``
    (NOTE(review): assumes ``get_dist`` returns an (n_test, n_val) matrix where
    smaller means closer — confirm against its definition). The resulting
    zero-one error is logged and returned.

    Args:
        testloader: DataLoader yielding ``(inputs, targets)`` batches.
        net: model to evaluate. It is switched to eval mode (and left that
            way); only a deep copy is structurally modified.
        model: architecture name; selects which head attribute is stripped.
        logger: logger receiving the final error via ``info``.
        device: device the inputs are moved to for the forward pass.

    Returns:
        float: 1-NN zero-one error on the test half, in ``[0, 1]``.

    Raises:
        ValueError: if the loader yields fewer than two samples, so no
            val/test split is possible.
    """
    net.eval()
    knn_net = deepcopy(net)
    # Strip the classification head so the forward pass returns features.
    if model.startswith("convnet"):
        knn_net.fc1 = nn.Identity()
    else:
        knn_net.fc = nn.Identity()
    knn_net.eval()

    testset_size = len(testloader.dataset)
    feat_detail = None
    target_detail = torch.zeros([testset_size])
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            outputs = knn_net(inputs.to(device))
            if feat_detail is None:
                feat_detail = torch.zeros([testset_size, outputs.size(1)])
            # Advance by the real batch length instead of batch_idx * batch_size:
            # the latter breaks when batch_size is None (custom batch_sampler).
            n_batch = outputs.size(0)
            feat_detail[n_seen:n_seen + n_batch] = outputs
            target_detail[n_seen:n_seen + n_batch] = targets
            n_seen += n_batch
        del knn_net

        if n_seen < 2:
            raise ValueError(
                "test_by_nn needs at least 2 samples to form a val/test split, "
                "got {}".format(n_seen)
            )
        # With drop_last=True fewer than len(dataset) samples are produced;
        # drop the never-filled rows so zero features / label 0 don't leak in.
        feat_detail = feat_detail[:n_seen]
        target_detail = target_detail[:n_seen]

        # A local RandomState(0) yields the same permutation as
        # np.random.seed(0); np.random.permutation(n), but leaves the caller's
        # global NumPy RNG state untouched (the previous code reseeded it from
        # the wall clock, destroying reproducibility of seeded runs).
        randperm = np.random.RandomState(0).permutation(n_seen)
        half = int(0.5 * n_seen)
        val_split, test_split = randperm[:half], randperm[half:]

        val_feat = feat_detail[val_split].detach().cpu().numpy()
        val_target = target_detail[val_split]
        test_feat = feat_detail[test_split].detach().cpu().numpy()
        test_target = target_detail[test_split]

        test_to_val_dist = get_dist(test_feat, val_feat)
        # 1-NN: index of the smallest entry per test row. argmin is O(n) per
        # row; a full argsort only to take column 0 is wasted work.
        nearest = np.asarray(test_to_val_dist).argmin(axis=1)
        test_preds = val_target[nearest]
        correct = test_preds.eq(test_target).sum().item()
        total = len(test_split)
        error = 1 - float(correct) / total
        logger.info('==>>>knn zero-one loss: {:.4f}'.format(error))
        return error