in pycls/core/net.py [0:0]
def compute_precise_bn_stats(model, loader):
    """Computes precise BN stats on training data.

    Runs batches from ``loader`` through ``model`` and replaces the running
    statistics of every ``torch.nn.BatchNorm2d`` layer with the average of the
    per-batch statistics. The number of batches is derived from
    ``cfg.BN.NUM_SAMPLES_PRECISE``, ``loader.batch_size`` and ``cfg.NUM_GPUS``,
    capped at ``len(loader)``. The mean estimate is mean(mean(batch)); the
    variance estimate is mean(var(batch)), which leaves out the spread of the
    batch means across batches.

    The model is temporarily put in training mode, because BN layers only
    update their running buffers while training. BN momentums are temporarily
    set to 1.0. Both are restored afterwards, even if the forward pass raises.
    Note that each forward pass still increments the BN layers'
    ``num_batches_tracked`` counters.

    Args:
        model: the ``torch.nn.Module`` whose BN stats are recomputed.
        loader: iterable of ``(inputs, labels)`` batches exposing
            ``batch_size`` and ``len()``.
    """
    # Compute the number of minibatches to use. Use at least one batch: the
    # int() truncation yields 0 when NUM_SAMPLES_PRECISE is smaller than one
    # global batch, and averaging over zero batches would overwrite every BN
    # layer with all-zero stats.
    num_iter = int(cfg.BN.NUM_SAMPLES_PRECISE / loader.batch_size / cfg.NUM_GPUS)
    num_iter = min(max(num_iter, 1), len(loader))
    # Retrieve the BN layers
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    # Nothing to compute from (empty loader) or nothing to update (no BN layers).
    # NOTE(review): presumably both conditions are identical on every process
    # (same config, model and per-process loader length), so no process is
    # left waiting in the all-reduce below -- confirm for custom samplers.
    if num_iter == 0 or not bns:
        return
    # Initialize BN stats storage for computing mean(mean(batch)) and mean(var(batch))
    running_means = [torch.zeros_like(bn.running_mean) for bn in bns]
    running_vars = [torch.zeros_like(bn.running_var) for bn in bns]
    # Remember momentum values and the model's train/eval mode
    momentums = [bn.momentum for bn in bns]
    was_training = model.training
    try:
        # BN layers only update their running buffers in training mode
        model.train()
        # Set momentum to 1.0 so the running stats reflect only the current batch
        for bn in bns:
            bn.momentum = 1.0
        # Average the BN stats for each BN layer over the batches. No autograd
        # graph is needed because the model outputs are discarded.
        with torch.no_grad():
            for inputs, _labels in itertools.islice(loader, num_iter):
                model(inputs.cuda())
                for mean, var, bn in zip(running_means, running_vars, bns):
                    # In-place += updates the tensors held in the lists
                    mean += bn.running_mean / num_iter
                    var += bn.running_var / num_iter
    finally:
        # Restore original momentum values and train/eval mode
        for bn, momentum in zip(bns, momentums):
            bn.momentum = momentum
        model.train(was_training)
    # Sync BN stats across GPUs (no reduction if 1 GPU used)
    running_means = dist.scaled_all_reduce(running_means)
    running_vars = dist.scaled_all_reduce(running_vars)
    # Install the averaged stats into the BN layers
    for bn, mean, var in zip(bns, running_means, running_vars):
        bn.running_mean = mean
        bn.running_var = var