in svhn_based_exp/utils.py [0:0]
# imports required by this helper
from copy import deepcopy
import time

import numpy as np
import torch
import torch.nn as nn


def test_by_nn(testloader, net, model, logger, device):
    """Evaluate a trained network with a 1-nearest-neighbour classifier on its penultimate features."""
    net.eval()
    batch_size = testloader.batch_size
    testset_size = len(testloader.dataset)

    # Copy the network and strip its final classification layer so the
    # forward pass returns penultimate-layer features instead of logits.
    knn_net = deepcopy(net)
    if model.startswith("convnet"):
        knn_net.fc1 = nn.Identity()
    else:
        knn_net.fc = nn.Identity()
    knn_net.eval()

    # Extract features and labels for the whole test set.
    feat_detail = None
    target_detail = torch.zeros([testset_size])
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = knn_net(inputs)
            if feat_detail is None:
                # Allocate the feature buffer once the feature dimension is known.
                feat_detail = torch.zeros([testset_size, outputs.size()[1]])
            feat_detail[batch_idx * batch_size : (batch_idx + 1) * batch_size] = outputs
            target_detail[batch_idx * batch_size : (batch_idx + 1) * batch_size] = targets
    del knn_net

    # Deterministically split the test set in half: one half serves as the
    # labelled reference ("val") set, the other half is classified by 1-NN.
    np.random.seed(0)
    randperm = np.random.permutation(testset_size)
    val_split, test_split = randperm[:int(0.5 * testset_size)], randperm[int(0.5 * testset_size):]
    np.random.seed(int(time.time()))  # restore non-deterministic seeding for later calls

    val_feat, val_target = feat_detail[val_split].detach().cpu().numpy(), target_detail[val_split]
    test_feat, test_target = feat_detail[test_split].detach().cpu().numpy(), target_detail[test_split]

    # Predict each test point's label as that of its nearest reference point,
    # using the pairwise distance matrix produced by get_dist.
    test_to_val_dist = get_dist(test_feat, val_feat)
    test_preds = val_target[np.argsort(test_to_val_dist)[:, 0]]
    correct = test_preds.eq(test_target).sum().item()
    total = len(test_split)
    logger.info('==>>> knn zero-one loss: {:.4f}'.format(1 - float(correct) / total))
    return 1 - float(correct) / total
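
The function depends on a get_dist helper that is not part of this excerpt; from its use above it is expected to return a (n_test, n_val) matrix of pairwise distances between the two feature arrays. A minimal sketch under that assumption follows, using squared Euclidean distance; the metric actually used in utils.py is not confirmed by this excerpt.

def get_dist(feat_a, feat_b):
    # Hypothetical sketch: pairwise squared Euclidean distances between two
    # feature matrices, returned as a (len(feat_a), len(feat_b)) array.
    a_sq = np.sum(feat_a ** 2, axis=1, keepdims=True)    # (n_a, 1)
    b_sq = np.sum(feat_b ** 2, axis=1, keepdims=True).T  # (1, n_b)
    cross = feat_a @ feat_b.T                             # (n_a, n_b)
    return a_sq + b_sq - 2.0 * cross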