in domainbed/networks.py [0:0]
def remove_batch_norm_from_resnet(model):
    """Fold the BatchNorm layers of a ResNet-style model into its convolutions.

    Each ``convN``/``bnN`` pair is fused into a single convolution with
    ``torch.nn.utils.fusion.fuse_conv_bn_eval`` and the BatchNorm slot is
    replaced by ``Identity()``, so the attribute layout of the model is kept.

    The model is modified in place. It must expose:

    * ``conv1`` and ``bn1`` attributes (the stem);
    * stages as direct children whose names are ``"layer"`` plus exactly one
      more character (e.g. ``layer1``), each iterable over its blocks;
    * in every block, a ``bn<k>`` attribute for each direct child named
      ``conv...<k>`` (``k`` is the last character of the conv's name), and a
      ``downsample`` attribute. When ``downsample`` is a
      ``torch.nn.Sequential`` its items 0 and 1 are fused as a conv/BN pair;
      any other value is left untouched.

    Args:
        model: The network to rewrite.

    Returns:
        The same ``model`` object. It is always left in training mode,
        regardless of the mode it was in when passed in.
    """
    # NOTE(review): `Identity` is expected to be defined elsewhere in this
    # module; it is not visible from this function.
    fuse = torch.nn.utils.fusion.fuse_conv_bn_eval

    # Fusion uses the BatchNorm running statistics, so the helper requires
    # the modules to be in eval mode.
    model.eval()

    model.conv1 = fuse(model.conv1, model.bn1)
    model.bn1 = Identity()

    # Only direct children can match the 6-character "layerX" pattern, so a
    # recursive walk is unnecessary. Snapshot the children because they are
    # replaced while looping.
    for stage_name, stage in list(model.named_children()):
        if not (stage_name.startswith("layer") and len(stage_name) == 6):
            continue
        for block in stage:
            for child_name, child in list(block.named_children()):
                if not child_name.startswith("conv"):
                    continue
                # conv<k> is paired with bn<k> through its last character.
                bn_name = "bn" + child_name[-1]
                setattr(block, child_name,
                        fuse(child, getattr(block, bn_name)))
                setattr(block, bn_name, Identity())

            shortcut = block.downsample
            if isinstance(shortcut, torch.nn.Sequential):
                shortcut[0] = fuse(shortcut[0], shortcut[1])
                shortcut[1] = Identity()

    model.train()
    return model