def get_bn_accuracy_metrics()

in utils.py [0:0]


def _mean_abs_deviations(batch_stats: list, running_stat: torch.Tensor) -> list:
    """Return, per batch statistic, its mean absolute difference from the
    matching leading channels of ``running_stat`` (as Python floats)."""
    running_stat = running_stat.detach().cpu()
    deviations = []
    for batch_stat in batch_stats:
        # Bring the batch statistic to CPU as well, so statistics recorded on
        # an accelerator can be compared with the CPU copy of the buffer.
        batch_stat = batch_stat.detach().cpu()
        # A batch may report fewer channels than the layer holds (presumably a
        # width-reduced sub-network — confirm with callers); compare only the
        # leading channels.
        num_channels = batch_stat.shape[0]
        diff = batch_stat - running_stat[:num_channels]
        deviations.append(diff.abs().mean().item())
    return deviations


def _mean_and_std(values: list) -> tuple:
    """Return ``(mean, std)`` of a non-empty list of floats.

    The std uses Bessel's correction like ``torch.Tensor.std``, except that a
    single value yields 0.0 instead of NaN.
    """
    stats = torch.tensor(values)
    mean = stats.mean().item()
    # torch.std divides by (n - 1), which is NaN (plus a warning) when n == 1.
    std = stats.std().item() if stats.numel() > 1 else 0.0
    return mean, std


def get_bn_accuracy_metrics(model: nn.Module, mean_dict: Dict, var_dict: Dict):
    """Determine how accurate the running_mean and running_var of the
    BatchNorm layers are, compared with per-batch statistics.

    Note that the test set should be shuffled when the per-batch statistics
    are collected, or these statistics won't be valid.

    Arguments:
        model: the network. Every key of ``mean_dict`` must name one of its
            modules (as given by ``model.named_modules()``), and that module
            must expose ``running_mean`` and ``running_var`` tensors.
        mean_dict: A dictionary that looks like:
            {'layer_name': [batch1_mean, batch2_mean, ...]}
            Each batch entry is compared channel-wise against the first
            ``entry.shape[0]`` channels of the layer's running_mean.
        var_dict: Similar to mean_dict, but with variances. It must contain
            every key of ``mean_dict``; keys present only here are ignored.

    Returns:
        ``{"bn_metrics": {...}}``, where for each layer with at least one
        batch the inner dict holds ``<name>_running_mean_MAD_mean``,
        ``<name>_running_mean_MAD_std``, ``<name>_running_var_MAD_mean`` and
        ``<name>_running_var_MAD_std``: the mean and std, across batches, of
        the mean absolute deviation (MAD) from the running statistic. The std
        is 0.0 when only one batch is available.

    Raises:
        KeyError: a ``mean_dict`` key is not a module of ``model`` or is
            missing from ``var_dict``.
        AssertionError: a per-layer value is not a list.
    """
    name_to_module = dict(model.named_modules())

    mean_results = {}
    var_results = {}
    for name in mean_dict:
        module = name_to_module[name]

        running_mean = module.running_mean
        assert isinstance(mean_dict[name], list)
        mean_devs = _mean_abs_deviations(mean_dict[name], running_mean)
        # Layers without any recorded batch produce no metrics.
        if mean_devs:
            mean_results[name] = mean_devs

        running_var = module.running_var
        assert isinstance(var_dict[name], list)
        var_devs = _mean_abs_deviations(var_dict[name], running_var)
        if var_devs:
            var_results[name] = var_devs

    # For each layer, record the mean and std of the average deviations, for
    # both running_mean and running_var (all mean keys first, then var keys).
    ret = {}
    for name, deviations in mean_results.items():
        mad_mean, mad_std = _mean_and_std(deviations)
        ret[f"{name}_running_mean_MAD_mean"] = mad_mean
        ret[f"{name}_running_mean_MAD_std"] = mad_std
    for name, deviations in var_results.items():
        mad_mean, mad_std = _mean_and_std(deviations)
        ret[f"{name}_running_var_MAD_mean"] = mad_mean
        ret[f"{name}_running_var_MAD_std"] = mad_std
    return {"bn_metrics": ret}