in svhn_based_exp/utils.py [0:0]
def test_epoch(criterion, testloader, net, logger, regression, device):
    """Evaluate ``net`` on every batch of ``testloader`` and log the result.

    The network is put in eval mode and run under ``torch.no_grad()``.
    Besides the aggregate metric, a per-sample loss vector is returned:

    * regression: squared error ``(outputs[:, 0] - targets) ** 2``;
    * classification: zero-one loss ``1 - (argmax(outputs, 1) == targets)``.

    Args:
        criterion: Loss callable ``criterion(prediction, targets)``. Its
            per-batch average is only reported in the classification log.
        testloader: Iterable of ``(inputs, targets)`` batches, e.g. a
            ``torch.utils.data.DataLoader`` (``batch_sampler``/``drop_last``
            loaders are supported).
        net: Model mapping ``inputs`` to ``outputs``; for regression the
            first output column (``outputs[:, 0]``) is the prediction.
        logger: Object with an ``info(str)`` method.
        regression: If true, targets are cast to float and scored with
            squared error; otherwise they are class indices.
        device: Device that inputs and targets are moved to.

    Returns:
        ``(mean_loss, loss_detail)``: ``loss_detail`` is a 1-D CPU tensor in
        the default float dtype with one entry per evaluated sample, in the
        order the loader produced them; ``mean_loss`` is its mean as a 0-d
        tensor (NaN when the loader yields no samples).
    """
    net.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    # Collect per-batch vectors and concatenate at the end instead of writing
    # into a preallocated buffer at ``batch_idx * batch_size``: that offset
    # breaks when ``batch_size`` is None (batch_sampler loaders), and a buffer
    # sized from the dataset keeps zero-padded slots under ``drop_last=True``,
    # which dilutes the mean.
    per_batch = []
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            if regression:
                targets = targets.float()
                # Same slice for the criterion and the per-sample error;
                # ``squeeze()`` would collapse a size-1 batch to a 0-d tensor.
                prediction = outputs[:, 0]
                loss = criterion(prediction, targets)
                batch_detail = (prediction - targets) ** 2
            else:
                loss = criterion(outputs, targets)
                _, predicted = outputs.max(1)
                hits = predicted.eq(targets)
                correct += hits.sum().item()
                batch_detail = 1 - hits.float()
            test_loss += loss.item()
            total += targets.size(0)
            num_batches += 1
            # Keep the detail vector on CPU in the default float dtype, as the
            # original ``torch.zeros`` buffer was.
            per_batch.append(
                batch_detail.detach().to(device='cpu', dtype=torch.get_default_dtype())
            )
    if per_batch:
        loss_detail = torch.cat(per_batch)
    else:
        loss_detail = torch.zeros(0)
    mean_loss = loss_detail.mean()
    if regression:
        logger.info('==>>> test loss: {:.6f}'.format(mean_loss))
    else:
        # Guard the empty-loader case instead of dividing by zero.
        avg_loss = test_loss / num_batches if num_batches else float('nan')
        accuracy = 100. * correct / total if total else float('nan')
        logger.info(
            '==>>> test loss: {:.6f}, accuracy: {:.4f}, test zero-one loss: {:.4f}'.format(
                avg_loss, accuracy, mean_loss
            )
        )
    return mean_loss, loss_detail


# The ``test_`` prefix would make pytest collect this helper (and fail on the
# missing ``criterion``/``testloader`` fixtures) if it is imported into a test
# module; mark it as not-a-test.
test_epoch.__test__ = False