def fuse_resnet_params()

in py/examine_resnet.py [0:0]


def fuse_resnet_params(m):
    """Fuse each conv/bn pair of ResNet ``m`` in place and drop the bn modules.

    For the stem and for every block in ``layer1`` .. ``layer4`` this calls
    ``fuse_bn(conv, bn)`` and then deletes the bn attribute.  ``fuse_bn`` is
    defined elsewhere; presumably it folds the BatchNorm parameters into the
    conv weights -- confirm against its definition.

    Side effects:
        * ``resnet.Bottleneck.forward`` and ``resnet.ResNet.forward`` are
          replaced on the *classes*, so every instance in the process is
          affected, not only ``m``.
        * ``m``, each layer container and each block get ``fused = True``.

    Calling this again on a model that was already fused is a no-op
    (previously it raised ``AttributeError`` because ``m.bn1`` was gone).

    Args:
        m: model exposing ``conv1``, ``bn1`` and ``layer1`` .. ``layer4``;
            each layer is iterable and yields blocks with ``conv1/bn1``,
            ``conv2/bn2``, optionally ``conv3/bn3``, and ``downsample``.

    Returns:
        None. ``m`` is modified in place.
    """
    # Class-level patch: re-applying it is harmless, so do it unconditionally.
    resnet.Bottleneck.forward = new_forward
    resnet.ResNet.forward = new_resnet_forward

    # The bn attributes are deleted below, so a second pass over the same
    # model cannot work; treat a completed fusion as already done.
    if getattr(m, 'fused', False):
        return

    fuse_bn(m.conv1, m.bn1)
    del m.bn1
    for seq in (m.layer1, m.layer2, m.layer3, m.layer4):
        seq.fused = True
        for bb in seq:
            bb.fused = True
            fuse_bn(bb.conv1, bb.bn1)
            del bb.bn1
            fuse_bn(bb.conv2, bb.bn2)
            del bb.bn2
            # Blocks without a third conv are accepted here.
            # NOTE(review): only Bottleneck.forward is patched above; blocks
            # of another class still lose bn1/bn2 -- confirm their forward
            # does not reference the deleted attributes.
            if hasattr(bb, 'conv3'):
                fuse_bn(bb.conv3, bb.bn3)
                del bb.bn3
            # Truthiness check kept from the original: skips ``None`` and,
            # for containers defining ``__len__``, an empty container.
            if bb.downsample:
                # Index 0 is passed as the conv and index 1 as the bn.
                fuse_bn(bb.downsample[0], bb.downsample[1])
                del bb.downsample[1]

    # Set last so the flag only ever marks a fully fused model; a failure
    # part-way through leaves it unset and a retry fails loudly as before.
    m.fused = True