in pycls/utils/net.py [0:0]
def compute_precise_bn_stats(model, loader):
    """Computes precise BN stats on training data.

    Replaces the running mean/variance of every ``BatchNorm2d`` layer in
    ``model`` with statistics aggregated over up to
    ``cfg.BN.NUM_SAMPLES_PRECISE`` samples drawn from ``loader``.

    Args:
        model: module whose ``BatchNorm2d`` layers are updated in place.
        loader: iterable yielding ``(inputs, labels)`` pairs; must expose
            ``batch_size`` and support ``len()``.

    Returns:
        None. The BN layers of ``model`` are modified in place.

    NOTE(review): this function does not change the model's train/eval mode;
    BN layers only refresh their running stats during forward passes while in
    training mode, so the caller is presumably expected to have called
    ``model.train()`` beforehand -- confirm against the call sites.
    """
    # Compute the number of minibatches to use
    num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader))
    # Retrieve the BN layers that actually keep running stats (layers created
    # with track_running_stats=False have running_mean/running_var set to None)
    bns = [
        m
        for m in model.modules()
        if isinstance(m, torch.nn.BatchNorm2d) and m.running_mean is not None
    ]
    # Nothing to do: without this guard a zero-iteration run would overwrite
    # every layer's running stats with the all-zero accumulators below
    if num_iter <= 0 or not bns:
        return
    # Initialize stats storage: running sums of E[x] and E[x^2] per layer
    mus = [torch.zeros_like(bn.running_mean) for bn in bns]
    sqs = [torch.zeros_like(bn.running_var) for bn in bns]
    # Remember momentum values
    moms = [bn.momentum for bn in bns]
    # Disable momentum so that, after each forward pass, the running stats
    # reflect the current minibatch only
    for bn in bns:
        bn.momentum = 1.0
    try:
        # Gradients are never used here; skip building the autograd graph
        with torch.no_grad():
            # Accumulate the stats across the data samples
            for inputs, _labels in itertools.islice(loader, num_iter):
                model(inputs.cuda())
                # Accumulate the stats for each BN layer
                for mu, sq, bn in zip(mus, sqs, bns):
                    m, v = bn.running_mean, bn.running_var
                    # E[x^2] of the batch is var + mean^2
                    sq.add_((v + m * m) / num_iter)
                    mu.add_(m / num_iter)
    finally:
        # Restore momentum values even if a forward pass or the loader fails
        for bn, mom in zip(bns, moms):
            bn.momentum = mom
    # Set the aggregated stats: Var[x] = E[x^2] - E[x]^2
    for mu, sq, bn in zip(mus, sqs, bns):
        bn.running_var = sq - mu * mu
        bn.running_mean = mu