# trainers/train_simplexes.py
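# Assumed context (not shown in this excerpt): `args`, `utils.update_bn`, and
# `get_stats` are module-level names defined elsewhere in this file.
import torch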
def test(models, writer, criterion, data_loader, epoch):
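    """Evaluate models[0] at the simplex center (uniform weight average) on the val set."""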
    model = models[0]
    model.eval()
    test_loss = 0
    correct0 = 0  # note: never incremented below, so m0_acc is always 0
    wa_correct = 0
    val_loader = data_loader.val_loader
    # Set each interpolation coefficient t_1..t_{n-1} to 1/n, i.e. move the
    # model to the center of the simplex (the uniform weight average).
    for i in range(1, args.n):
        model.apply(lambda m: setattr(m, f"t{i}", 1.0 / args.n))
    # Recompute BatchNorm running statistics under the averaged weights.
    utils.update_bn(data_loader.train_loader, model, args.device)
    model.eval()
    cossim, l2 = get_stats(model)
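    # Calibration bookkeeping: M equal-width confidence bins for ECE.
    # count_bm tracks per-bin sample counts but is not used downstream.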
    M = 20
    acc_bm = torch.zeros(M)
    conf_bm = torch.zeros(M)
    count_bm = torch.zeros(M)
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)
            wa_output = model(data)
            soft_output = wa_output.softmax(dim=1)
            wa_pred = wa_output.argmax(dim=1, keepdim=True)
            correct_vec = wa_pred.eq(target.view_as(wa_pred))
            wa_correct += correct_vec.sum().item()
            test_loss += criterion(wa_output, target).item()
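            # Bin each sample by the confidence (top-1 softmax prob) of its prediction.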
            for i in range(data.size(0)):
                conf = soft_output[i][wa_pred[i]]
                bin_idx = min((conf * M).int().item(), M - 1)
                acc_bm[bin_idx] += correct_vec[i].float().item()
                conf_bm[bin_idx] += conf.item()
                count_bm[bin_idx] += 1.0
    wa_acc = float(wa_correct) / len(val_loader.dataset)
    m0_acc = float(correct0) / len(val_loader.dataset)
    test_acc = wa_acc
    test_loss /= len(val_loader)
    print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {test_acc:.4f}\n")
    if args.save:
        writer.add_scalar("test/loss", test_loss, epoch)
        writer.add_scalar("test/norm", l2, epoch)
        writer.add_scalar("test/cossim", cossim, epoch)
        writer.add_scalar("test/wa_acc", wa_acc, epoch)
        writer.add_scalar("test/m0_acc", m0_acc, epoch)
    ece = 0.0
    for i in range(M):
        ece += (acc_bm[i] - conf_bm[i]).abs().item()
    ece /= len(val_loader.dataset)
    print("ece is", ece)
    metrics = {
        "ece": ece,
        "wa_acc": wa_acc,
        "m0_acc": m0_acc,
        "l2": l2,
        "cossim": cossim,
        "test_loss": test_loss,
    }
    return test_acc, metrics
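# For reference, a minimal self-contained sketch of the same binned-ECE
# computation on precomputed probabilities. `expected_calibration_error` is a
# hypothetical helper, not part of the trainer above.
def expected_calibration_error(probs, targets, n_bins=20):
    """probs: (N, C) softmax outputs; targets: (N,) integer labels."""
    conf, pred = probs.max(dim=1)  # top-1 confidence and predicted class
    correct = pred.eq(targets).float()
    ece = 0.0
    for b in range(n_bins):
        lo = b / n_bins
        hi = (b + 1) / n_bins
        # The last bin is closed on the right so that conf == 1.0 is counted.
        mask = (conf >= lo) & (conf < hi) if b < n_bins - 1 else (conf >= lo)
        if mask.any():
            ece += (correct[mask].sum() - conf[mask].sum()).abs().item()
    return ece / probs.size(0)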