def test()

in trainers/simplex_ensembles.py [0:0]
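
Evaluates an ensemble of simplex models on the held-out validation set. Each model receives a freshly sampled point on the probability simplex (drawn per layer or per model, depending on args.layerwise), its BatchNorm statistics are recomputed for that sample, and the routine then reports the ensemble's average loss, accuracy, per-model accuracies, and expected calibration error (ECE). The function relies on module-level names defined elsewhere in the file (args, np, torch, nn, utils).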


def test(models, writer, criterion, data_loader, epoch):

    for model in models:
        model.zero_grad()
        model.eval()

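        # Sample a random point on the (n-1)-simplex: i.i.d. Exponential(1) draws
        # normalized by their sum are a Dirichlet(1, ..., 1) sample. Coordinates
        # Z[1:] are attached to every Conv2d/BatchNorm2d module as attributes
        # t1..t{n-1} (the first coordinate is presumably recovered inside the model
        # from the sum-to-one constraint), with fresh draws per layer when
        # args.layerwise is set, and a single draw per model otherwise.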
        if args.layerwise:
            for m in model.modules():
                if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                    Z = np.random.exponential(scale=1.0, size=args.n)
                    Z = Z / Z.sum()
                    for i in range(1, args.n):
                        setattr(m, f"t{i}", Z[i])
        else:
            Z = np.random.exponential(scale=1.0, size=args.n)
            Z = Z / Z.sum()
            for m in model.modules():
                if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                    for i in range(1, args.n):
                        setattr(m, f"t{i}", Z[i])

    test_loss = 0
    correct = 0
    corrects = [0 for _ in range(len(models))]  # one correct-count per ensemble member

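    # Calibration accumulators: M equal-width confidence bins collecting the summed
    # correctness (acc_bm), summed confidence (conf_bm), and sample count (count_bm);
    # acc_bm and conf_bm are used to compute the expected calibration error below.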
    M = 20
    acc_bm = torch.zeros(M)
    conf_bm = torch.zeros(M)
    count_bm = torch.zeros(M)

    val_loader = data_loader.val_loader

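    # Recompute BatchNorm running statistics for the freshly sampled simplex weights
    # by streaming the training data through each model (utils.update_bn presumably
    # mirrors torch.optim.swa_utils.update_bn, which has the same signature).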
    for m in models:
        utils.update_bn(data_loader.train_loader, m, device=args.device)

    with torch.no_grad():

        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)

            model_output = models[0](data)
            model_pred = model_output.argmax(dim=1, keepdim=True)
            corrects[0] += (
                model_pred.eq(target.view_as(model_pred)).sum().item()
            )
            mean_output = model_output

            for i, m in enumerate(models[1:]):
                model_output = m(data)
                model_pred = model_output.argmax(dim=1, keepdim=True)
                corrects[i + 1] += (
                    model_pred.eq(target.view_as(model_pred)).sum().item()
                )
                mean_output += model_output

            mean_output /= len(models)
            # get the index of the max log-probability
            pred = mean_output.argmax(dim=1, keepdim=True)
            test_loss += criterion(mean_output, target).item()
            correct_vec = pred.eq(target.view_as(pred))
            correct += correct_vec.sum().item()
            soft_output = mean_output.softmax(dim=1)

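            # Bin every sample by the confidence of the ensemble's predicted class,
            # accumulating correctness and confidence per bin for the ECE estimate.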
            for i in range(data.size(0)):
                conf = soft_output[i][pred[i]]
                bin_idx = min((conf * M).int().item(), M - 1)
                acc_bm[bin_idx] += correct_vec[i].float().item()
                conf_bm[bin_idx] += conf.item()
                count_bm[bin_idx] += 1.0

    test_loss /= len(val_loader)
    test_acc = float(correct) / len(val_loader.dataset)

    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_acc:.4f})\n"
    )

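    # ECE = (1/N) * sum_b |acc_bm[b] - conf_bm[b]|; since both accumulators hold
    # per-bin sums, this equals the usual count-weighted |accuracy - confidence| gap.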
    ece = 0.0
    for i in range(M):
        ece += (acc_bm[i] - conf_bm[i]).abs().item()
    ece /= len(val_loader.dataset)
    print(f"ECE: {ece:.4f}")

    if args.save:
        writer.add_scalar("test/loss", test_loss, epoch)
        writer.add_scalar("test/acc", test_acc, epoch)

    corrects_scaled = [
        float(corrects[i]) / len(val_loader.dataset)
        for i in range(len(corrects))
    ]
    metrics = {
        f"model_{i}_acc": corrects_scaled[i]
        for i in range(len(corrects_scaled))
    }

    corrects_scaled = np.array(corrects_scaled)

    metrics["avg_model_acc"] = np.mean(corrects_scaled[corrects_scaled > 0])
    metrics["avg_model_std"] = np.std(corrects_scaled[corrects_scaled > 0])
    metrics["ece"] = ece
    metrics["test_loss"] = test_loss

    return test_acc, metrics
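
A minimal usage sketch. The helper names below (build_models, get_data_loader) and the args fields they consume are hypothetical stand-ins for what the surrounding training script provides; only the test() signature is taken from this file.

import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

# Hypothetical setup -- the actual script constructs these elsewhere.
models = build_models(args)                      # list of simplex models on args.device
writer = SummaryWriter(log_dir="runs/simplex")   # consumed only when args.save is set
criterion = nn.CrossEntropyLoss()
data_loader = get_data_loader(args)              # must expose .train_loader and .val_loader

for epoch in range(args.epochs):
    # ... one training epoch over data_loader.train_loader ...
    test_acc, metrics = test(models, writer, criterion, data_loader, epoch)
    print(f"epoch {epoch}: ensemble acc {test_acc:.4f}, ECE {metrics['ece']:.4f}")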