in utils.py [0:0]
def get_bn_accuracy_metrics(model: nn.Module, mean_dict: Dict, var_dict: Dict):
"""Determine how accurate the running_mean and running_var of the BatchNorms
are.
Note that, the test set should be shuffled during training, or these
statistics won't be valid.
Arguments:
model: the network.
mean_dict: A dictionary that looks like:
{'layer_name': [batch1_mean, batch2_mean, ...]}
var_dict: Similar to mean_dict, but with variances.
"""
mean_results = collections.defaultdict(list)
var_results = collections.defaultdict(list)
name_to_module = {name: module for name, module in model.named_modules()}
for name in mean_dict.keys():
module = name_to_module[name]
running_mean = module.running_mean.detach().cpu()
assert isinstance(mean_dict[name], list)
for batch_result in mean_dict[name]:
num_channels = batch_result.shape[0]
mean_abs_diff = (
(batch_result - running_mean[:num_channels]).abs().mean().item()
)
mean_results[name].append(mean_abs_diff)
running_var = module.running_var.detach().cpu()
assert isinstance(var_dict[name], list)
for batch_result in var_dict[name]:
num_channels = batch_result.shape[0]
mean_abs_diff = (
(batch_result - running_var[:num_channels]).abs().mean().item()
)
var_results[name].append(mean_abs_diff)
# For each layer, record the mean and std of the average deviations, for
# both running_mean and running_var.
ret = {}
for name, stats in mean_results.items():
ret[f"{name}_running_mean_MAD_mean"] = torch.tensor(stats).mean().item()
ret[f"{name}_running_mean_MAD_std"] = torch.tensor(stats).std().item()
for name, stats in var_results.items():
ret[f"{name}_running_var_MAD_mean"] = torch.tensor(stats).mean().item()
ret[f"{name}_running_var_MAD_std"] = torch.tensor(stats).std().item()
return {"bn_metrics": ret}