def test_epoch()

in svhn_based_exp/utils.py [0:0]


def test_epoch(criterion, testloader, net, logger, regression, device):
    net.eval()
    batch_size = testloader.batch_size
    test_loss = 0
    correct = 0
    total = 0
    if isinstance(testloader.sampler, data_utils.SubsetRandomSampler):
        test_size = len(testloader.sampler.indices)
    else:
        test_size = len(testloader.dataset)

    loss_detail = torch.zeros([test_size])
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            if regression:
                targets = targets.float()
                loss = criterion(outputs.squeeze(), targets)
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                loss_detail[batch_idx * batch_size : (batch_idx + 1) * batch_size] = torch.abs(outputs[:, 0] - targets) ** 2
            else:
                loss = criterion(outputs, targets)
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                loss_detail[batch_idx * batch_size : (batch_idx + 1) * batch_size] = 1 - predicted.eq(targets).float()
                correct += predicted.eq(targets).sum().item()

        if regression:
            logger.info('==>>> test loss: {:.6f}'.format(loss_detail.mean()))
            return loss_detail.mean(), loss_detail
        else:
            logger.info('==>>> test loss: {:.6f}, accuracy: {:.4f}, test zero-one loss: {:.4f}'.format(test_loss/(batch_idx+1), 100.*correct/total, loss_detail.mean()))
            return loss_detail.mean(), loss_detail