def test()

in trainers/swa_endpoint_ensembles.py [0:0]


def test(models, writer, criterion, data_loader, epoch):
    j = args.j
    n = len(models)
    print(args.t)
    print((1 - args.t) / (n - 1))
    print("--")
    for ms in zip(*[model.modules() for model in models]):
        if isinstance(ms[0], nn.Conv2d):
            if j == 0:
                ms[0].weight.data = ms[0].weight.data * args.t
            else:
                ms[0].weight.data = ms[0].weight.data * (1 - args.t) / (n - 1)

            for i in range(1, n):
                if i == j:
                    ms[0].weight.data += ms[i].weight.data * args.t
                else:
                    ms[0].weight.data += (
                        ms[i].weight.data * (1 - args.t) / (n - 1)
                    )
            print("conv", ms[0].weight[0, 0, 0, 0])
        elif isinstance(ms[0], nn.BatchNorm2d):
            if j == 0:
                ms[0].weight.data = ms[0].weight.data * args.t
            else:
                ms[0].weight.data = ms[0].weight.data * (1 - args.t) / (n - 1)

            for i in range(1, n):
                if i == j:
                    ms[0].weight.data += ms[i].weight.data * args.t
                else:
                    ms[0].weight.data += (
                        ms[i].weight.data * (1 - args.t) / (n - 1)
                    )

            if j == 0:
                ms[0].bias.data = ms[0].bias.data * args.t
            else:
                ms[0].bias.data = ms[0].bias.data * (1 - args.t) / (n - 1)

            for i in range(1, n):
                if i == j:
                    ms[0].bias.data += ms[i].bias.data * args.t
                else:
                    ms[0].bias.data += ms[i].bias.data * (1 - args.t) / (n - 1)

            print("bn", ms[0].weight[0])
            print("bn", ms[0].bias[0])

    utils.update_bn(data_loader.train_loader, models[0], device=args.device)

    # here was save the model in args.tmp_dir/model_{j}.pt
    torch.save(
        {
            "epoch": 0,
            "iter": 0,
            "arch": args.model,
            "state_dicts": [models[0].state_dict()],
            "optimizers": None,
            "best_acc1": 0,
            "curr_acc1": 0,
        },
        os.path.join(args.tmp_dir, f"model_{j}.pt"),
    )

    test_acc = 0
    metrics = {}

    return test_acc, metrics