def test()

in trainers/average_weights.py [0:0]


def test(models, writer, criterion, data_loader, epoch):
    """Evaluate the uniform weight-average of ``models`` on the validation set.

    The first ``args.num_models`` models are averaged parameter-wise into
    ``models[0]``. Then ``utils.update_bn`` is run over the training loader
    (presumably to refresh BatchNorm running statistics for the averaged
    weights). Finally the averaged model is evaluated on
    ``data_loader.val_loader``.

    NOTE: ``models[0]`` is overwritten in place with the averaged weights.
    Callers that still need the original first model must copy it beforehand.

    Args:
        models: Sequence of models with identical module structure. Only the
            first ``args.num_models`` entries are averaged.
        writer: Summary writer. Scalars are logged only when ``args.save``.
        criterion: Loss called as ``criterion(output, target)``.
        data_loader: Object exposing ``train_loader`` (passed to
            ``utils.update_bn``) and ``val_loader`` (used for evaluation).
        epoch: Step value used when logging scalars.

    Returns:
        ``(test_acc, metrics)``. ``test_acc`` is top-1 accuracy over the
        validation dataset. ``metrics`` contains ``"ece"``, the expected
        calibration error over ``M = 20`` equal-width confidence bins, and
        ``"test_loss"``, the criterion value averaged over validation batches.
    """
    for m in models:
        m.eval()
    model = models[0]
    val_loader = data_loader.val_loader

    _average_into_first(models, args.num_models)

    utils.update_bn(data_loader.train_loader, model, device=args.device)
    # Re-assert eval mode in case the BN-update helper changed the mode;
    # this is a no-op when the helper already restored it.
    model.eval()

    M = 20  # number of equal-width confidence bins for ECE
    acc_bm = torch.zeros(M)   # per-bin count of correct predictions
    conf_bm = torch.zeros(M)  # per-bin sum of predicted-class confidence
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)
            output = model(data)
            test_loss += criterion(output, target).item()

            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct_vec = pred.eq(target.view_as(pred))
            correct += correct_vec.sum().item()

            # Softmax probability of the predicted class, one value per sample.
            conf = output.softmax(dim=1).gather(1, pred).squeeze(1)
            _accumulate_calibration(acc_bm, conf_bm, conf, correct_vec)

    test_loss /= len(val_loader)
    test_acc = float(correct) / len(val_loader.dataset)

    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_acc:.4f})\n"
    )

    if args.save:
        writer.add_scalar("test/loss", test_loss, epoch)
        writer.add_scalar("test/acc", test_acc, epoch)

    # Sum_b |sum_correct_b - sum_conf_b| / N == Sum_b (n_b / N) * |acc_b - conf_b|.
    # Summing Python floats keeps the same accumulation as a per-bin loop.
    ece = sum((acc_bm - conf_bm).abs().tolist())
    ece /= len(val_loader.dataset)
    print("ece is", ece)

    metrics = {"ece": ece, "test_loss": test_loss}

    return test_acc, metrics


def _average_into_first(models, num_models):
    """Overwrite ``models[0]``'s parameters with the mean over ``models[:num_models]``.

    Averages ``weight`` and ``bias`` of every Conv2d, Linear and BatchNorm2d
    module. Parameters that are ``None`` (e.g. ``bias=False`` or
    ``affine=False``) are skipped. Modules are matched positionally by
    zipping ``modules()`` of each model.
    """
    scale = 1.0 / num_models
    averaged_types = (nn.Conv2d, nn.Linear, nn.BatchNorm2d)
    for ms in zip(*[models[i].modules() for i in range(num_models)]):
        if not isinstance(ms[0], averaged_types):
            continue
        for name in ("weight", "bias"):
            first = getattr(ms[0], name)
            if first is None:
                continue
            # Rebind ``.data`` (rather than mutating in place) so any tensor
            # aliasing the old storage is left untouched.
            first.data = scale * first.data
            for other in ms[1:]:
                first.data += scale * getattr(other, name).data


def _accumulate_calibration(acc_bm, conf_bm, conf, correct_vec):
    """Add one batch's confidences and correctness flags to the ECE bins.

    The bin of each sample is ``min(floor(conf * M), M - 1)``, where ``M`` is
    the number of bins (``acc_bm.numel()``). ``acc_bm`` receives the count of
    correct predictions per bin, and ``conf_bm`` receives the summed
    confidence per bin. Both accumulators are updated in place.
    """
    num_bins = acc_bm.numel()
    # Computed in conf's own dtype, so boundary rounding matches a scalar
    # ``(conf * M).int()``. Confidences are non-negative, so int() is floor().
    bin_idx = (conf * num_bins).int().clamp(max=num_bins - 1).long().cpu()
    acc_bm.index_add_(0, bin_idx, correct_vec.view(-1).to(acc_bm.dtype).cpu())
    conf_bm.index_add_(0, bin_idx, conf.to(conf_bm.dtype).cpu())