in utils.py [0:0]
def replace_batchnorm(net):
    """Recursively rewrite the children of *net* in place for BN-free inference.

    Each direct child is handled by the first rule that matches:

    * a child exposing a ``fuse`` attribute is replaced by ``child.fuse()``
      (the fused result is not traversed further);
    * a ``torch.nn.Conv2d`` gets its ``bias`` set to a fresh all-zero
      ``Parameter`` with one entry per output channel;
    * a ``torch.nn.BatchNorm2d`` is replaced by ``torch.nn.Identity()``;
    * any other child is descended into recursively.

    Args:
        net: The ``torch.nn.Module`` whose sub-modules are rewritten.

    Returns:
        None. ``net`` is modified in place.
    """
    # Snapshot the children so that swapping entries on ``net`` below never
    # interacts with the live iterator over its sub-module registry.
    for child_name, child in list(net.named_children()):
        if hasattr(child, 'fuse'):
            setattr(net, child_name, child.fuse())
        elif isinstance(child, torch.nn.Conv2d):
            weight = child.weight
            # Build the bias with the weight's dtype and device; a default
            # float32 CPU tensor breaks the forward pass for models that
            # were already moved to an accelerator or cast to half/double.
            # NOTE(review): this overwrites any bias the conv already has,
            # exactly as before -- confirm that resetting it is intended.
            child.bias = torch.nn.Parameter(
                torch.zeros(
                    weight.size(0),
                    dtype=weight.dtype,
                    device=weight.device,
                )
            )
        elif isinstance(child, torch.nn.BatchNorm2d):
            setattr(net, child_name, torch.nn.Identity())
        else:
            replace_batchnorm(child)