trainers/random_average_weights_global.py [34:78]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        # Blend the parallel modules' parameters into ms[0] as a convex
        # combination with the random weights Z.
        if isinstance(ms[0], (nn.Conv2d, nn.BatchNorm2d)):
            ms[0].weight.data = Z[0] * ms[0].weight.data
            for i in range(1, args.num_models):
                ms[0].weight.data += Z[i] * ms[i].weight.data
        if isinstance(ms[0], nn.BatchNorm2d):
            ms[0].bias.data = Z[0] * ms[0].bias.data
            for i in range(1, args.num_models):
                ms[0].bias.data += Z[i] * ms[i].bias.data
    model = models[0]
    # Recompute BatchNorm running statistics for the merged weights with one
    # forward pass over the training data.
    utils.update_bn(data_loader.train_loader, model, device=args.device)
    model.eval()

    with torch.no_grad():

        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)
            output = model(data)
            test_loss += criterion(output, target).item()

            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)

            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(val_loader)  # mean loss per batch
    test_acc = float(correct) / len(val_loader.dataset)

    print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {test_acc:.4f}\n")

    if args.save:
        writer.add_scalar("test/loss", test_loss, epoch)
        writer.add_scalar("test/acc", test_acc, epoch)

    metrics = {}

    return test_acc, metrics
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
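
Both excerpts begin mid-loop: `ms` is a tuple of corresponding modules taken from the parallel models, and `Z` holds the random averaging weights, both defined in surrounding code that is not shown. Below is a minimal sketch of how the global variant plausibly wires this together, assuming `Z` is a single convex combination sampled once and reused for every layer; `sample_convex_weights` and `average_models_global` are hypothetical names, not functions from the repo:

import torch
import torch.nn as nn

def sample_convex_weights(num_models):
    # Hypothetical helper: one random convex combination (non-negative,
    # sums to 1), shared by every layer in the "global" variant.
    z = torch.rand(num_models)
    return z / z.sum()

def average_models_global(models, num_models):
    Z = sample_convex_weights(num_models)  # sampled once for the whole network
    # Walk the models' modules in lockstep, mirroring the excerpt's `ms`.
    for ms in zip(*(m.modules() for m in models)):
        if isinstance(ms[0], (nn.Conv2d, nn.BatchNorm2d)):
            ms[0].weight.data = sum(Z[i] * ms[i].weight.data
                                    for i in range(num_models))
        if isinstance(ms[0], nn.BatchNorm2d):
            ms[0].bias.data = sum(Z[i] * ms[i].bias.data
                                  for i in range(num_models))
    return models[0]

After the merge, the trainer still needs the BatchNorm recalibration and evaluation pass shown in the excerpt, since averaged weights invalidate the stored running statistics.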



trainers/random_average_weights_layerwise.py [33:77]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        # Blend the parallel modules' parameters into ms[0] as a convex
        # combination with the random weights Z.
        if isinstance(ms[0], (nn.Conv2d, nn.BatchNorm2d)):
            ms[0].weight.data = Z[0] * ms[0].weight.data
            for i in range(1, args.num_models):
                ms[0].weight.data += Z[i] * ms[i].weight.data
        if isinstance(ms[0], nn.BatchNorm2d):
            ms[0].bias.data = Z[0] * ms[0].bias.data
            for i in range(1, args.num_models):
                ms[0].bias.data += Z[i] * ms[i].bias.data
    model = models[0]
    # Recompute BatchNorm running statistics for the merged weights with one
    # forward pass over the training data.
    utils.update_bn(data_loader.train_loader, model, device=args.device)
    model.eval()

    with torch.no_grad():

        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)
            output = model(data)
            test_loss += criterion(output, target).item()

            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)

            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(val_loader)  # mean loss per batch
    test_acc = float(correct) / len(val_loader.dataset)

    print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {test_acc:.4f}\n")

    if args.save:
        writer.add_scalar("test/loss", test_loss, epoch)
        writer.add_scalar("test/acc", test_acc, epoch)

    metrics = {}

    return test_acc, metrics
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
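
The two trainers duplicate this averaging block verbatim; judging by the filenames, the difference is presumably only where `Z` is sampled. A sketch of the layerwise variant under the same assumptions as above, with the shared block factored into a helper (`average_module_params` and `average_models_layerwise` are hypothetical names) so both files could reuse it:

import torch
import torch.nn as nn

def average_module_params(ms, Z, num_models):
    # The block both trainers duplicate: convex-combine the weights (and,
    # for BatchNorm, the biases) of the parallel modules into ms[0].
    if isinstance(ms[0], (nn.Conv2d, nn.BatchNorm2d)):
        ms[0].weight.data = sum(Z[i] * ms[i].weight.data
                                for i in range(num_models))
    if isinstance(ms[0], nn.BatchNorm2d):
        ms[0].bias.data = sum(Z[i] * ms[i].bias.data
                              for i in range(num_models))

def average_models_layerwise(models, num_models):
    # Layerwise variant: resample a fresh convex combination per module
    # instead of one global Z (the assumed distinction between the files).
    for ms in zip(*(m.modules() for m in models)):
        z = torch.rand(num_models)
        average_module_params(ms, z / z.sum(), num_models)
    return models[0]

Factoring the helper out would also collapse the duplicated span this report flags, leaving each trainer responsible only for its own sampling schedule.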



