def test()

in trainers/train_simplexes.py [0:0]


def test(models, writer, criterion, data_loader, epoch):
    """Evaluate ``models[0]`` with uniform ``t{i}`` weights on the validation set.

    For every ``i`` in ``1 .. args.n - 1`` the attribute ``t{i}`` is set to
    ``1 / args.n`` on each submodule of the model. ``utils.update_bn`` is then
    called with the training loader (presumably to refresh batch-norm
    statistics -- confirm against ``utils``), and the model is scored on
    ``data_loader.val_loader``.

    Relies on the module-level names ``args`` (``n``, ``device``, ``save``),
    ``utils``, ``get_stats`` and ``torch``.

    Args:
        models: Indexable collection of models; only ``models[0]`` is used.
        writer: Object exposing ``add_scalar(tag, value, step)``. It is only
            touched when ``args.save`` is truthy.
        criterion: Callable ``criterion(output, target)`` whose result
            supports ``.item()``.
        data_loader: Object exposing ``train_loader`` and ``val_loader``.
        epoch: Step value forwarded to ``writer.add_scalar``.

    Returns:
        A ``(test_acc, metrics)`` tuple. ``test_acc`` equals
        ``metrics["wa_acc"]``. ``metrics`` holds ``ece``, ``wa_acc``,
        ``m0_acc``, ``l2``, ``cossim`` and ``test_loss``.
    """
    model = models[0]
    model.eval()
    test_loss = 0
    # NOTE(review): nothing in this function ever increments ``correct0``, so
    # ``m0_acc`` is always 0.0. The key is kept because callers and the
    # summary writer may still read it.
    correct0 = 0
    wa_correct = 0
    val_loader = data_loader.val_loader

    for i in range(1, args.n):
        # Bind the attribute name and value as lambda defaults so the callback
        # does not close over the loop variable.
        model.apply(
            lambda m, name=f"t{i}", value=1.0 / args.n: setattr(m, name, value)
        )

    utils.update_bn(data_loader.train_loader, model, args.device)
    # Switch back to eval mode in case update_bn changed the module mode.
    model.eval()
    cossim, l2 = get_stats(model)

    # Calibration histogram: M equal-width confidence bins over [0, 1].
    # Each bin accumulates the number of correct predictions and the summed
    # confidence of the samples that fall into it.
    M = 20
    acc_bm = torch.zeros(M)
    conf_bm = torch.zeros(M)

    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)

            wa_output = model(data)
            wa_pred = wa_output.argmax(dim=1, keepdim=True)
            correct_vec = wa_pred.eq(target.view_as(wa_pred))
            wa_correct += correct_vec.sum().item()
            test_loss += criterion(wa_output, target).item()

            # Confidence is the softmax probability of the predicted class.
            conf = wa_output.softmax(dim=1).gather(1, wa_pred).reshape(-1)
            # Truncate conf * M to a bin index; conf == 1.0 would give M, so
            # clamp it into the last bin.
            bin_idx = (conf * M).long().clamp(max=M - 1).cpu()

            # One host transfer per batch instead of several ``.item()``
            # synchronisations per sample.
            acc_bm.index_add_(0, bin_idx, correct_vec.reshape(-1).float().cpu())
            conf_bm.index_add_(0, bin_idx, conf.float().cpu())

    num_samples = len(val_loader.dataset)
    wa_acc = float(wa_correct) / num_samples
    m0_acc = float(correct0) / num_samples
    test_acc = wa_acc
    # Loss is averaged over batches, not over samples.
    test_loss /= len(val_loader)

    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_acc:.4f})\n"
    )

    if args.save:
        writer.add_scalar("test/loss", test_loss, epoch)
        writer.add_scalar("test/norm", l2, epoch)
        writer.add_scalar("test/cossim", cossim, epoch)

        writer.add_scalar("test/wa_acc", wa_acc, epoch)
        writer.add_scalar("test/m0_acc", m0_acc, epoch)

    # ECE: per-bin |sum(correct) - sum(confidence)|, summed over bins and
    # divided by the total number of samples.
    ece = sum((acc_bm - conf_bm).abs().tolist()) / num_samples
    print("ece is", ece)

    metrics = {
        "ece": ece,
        "wa_acc": wa_acc,
        "m0_acc": m0_acc,
        "l2": l2,
        "cossim": cossim,
        "test_loss": test_loss,
    }

    return test_acc, metrics