in models/modules/dynamic_ops.py [0:0]
def forward(self, x):
    # Number of channels active in this forward pass; the BN parameters and
    # buffers are allocated at the maximum width and sliced down to it.
    feature_dim = x.size(1)
    if not self.training:
        raise ValueError('DynamicBN only supports training')

    bn = self.bn
    if not self.need_sync:
        # Plain per-GPU batch norm over the first feature_dim channels.
        return F.batch_norm(
            x, bn.running_mean[:feature_dim], bn.running_var[:feature_dim], bn.weight[:feature_dim],
            bn.bias[:feature_dim], bn.training or not bn.track_running_stats,
            bn.momentum, bn.eps,
        )
    else:
        # Synchronized batch norm computed manually: average the per-GPU mean
        # and mean-of-squares across workers, then apply the affine transform.
        assert dist.get_world_size() > 1, 'SyncBatchNorm requires >1 world size'
        B, C = x.shape[0], x.shape[1]
        mean = torch.mean(x, dim=[0, 2, 3])
        meansqr = torch.mean(x * x, dim=[0, 2, 3])
        assert B > 0, 'does not support zero batch size'
        # Pack both statistics into one tensor so a single all-reduce suffices,
        # then divide by the world size to turn the sum into an average.
        vec = torch.cat([mean, meansqr], dim=0)
        vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())
        mean, meansqr = torch.split(vec, C)
        var = meansqr - mean * mean
        invstd = torch.rsqrt(var + bn.eps)
        scale = bn.weight[:feature_dim] * invstd
        bias = bn.bias[:feature_dim] - mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)
        return x * scale + bias
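
# --- Illustrative sketch (not from the repository) --------------------------
# The forward() above relies on self.bn, self.need_sync and an AllReduce
# helper that are defined elsewhere in dynamic_ops.py. The sketch below is a
# minimal, hedged guess at how those pieces are typically wired up: the class
# name DynamicBatchNorm2dSketch, the max_feature_dim argument and this
# AllReduce implementation are assumptions for illustration, not the file's
# actual definitions.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.autograd import Function


class AllReduce(Function):
    """Differentiable cross-worker sum, a common building block for manual SyncBN."""

    @staticmethod
    def forward(ctx, tensor):
        tensor = tensor.clone()
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(grad_output, op=dist.ReduceOp.SUM)
        return grad_output


class DynamicBatchNorm2dSketch(nn.Module):
    """Hypothetical owner of the forward() above: BN allocated at max width."""

    def __init__(self, max_feature_dim, need_sync=False):
        super().__init__()
        # BN is built for the largest possible channel count; forward() slices
        # its parameters and buffers to whatever width the input actually has.
        self.bn = nn.BatchNorm2d(max_feature_dim)
        self.need_sync = need_sync

    def forward(self, x):
        # Single-GPU path only; the full method (including the synced branch)
        # is the excerpt above.
        d = x.size(1)
        bn = self.bn
        return F.batch_norm(
            x, bn.running_mean[:d], bn.running_var[:d],
            bn.weight[:d], bn.bias[:d],
            bn.training or not bn.track_running_stats,
            bn.momentum, bn.eps,
        )


# Usage: a layer sized for up to 64 channels normalizing a 48-channel input.
layer = DynamicBatchNorm2dSketch(max_feature_dim=64)
layer.train()
out = layer(torch.randn(8, 48, 14, 14))   # -> shape (8, 48, 14, 14)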