def init_weights()

in pycls/utils/net.py [0:0]


def init_weights(m):
    """Performs ResNet style weight initialization on a single module, in place.

    Presumably applied to every submodule via ``model.apply(init_weights)`` —
    confirm with callers. Modules whose type matches none of the branches
    below are left untouched.

    - ``nn.Conv2d`` / ``SymConv2d``: weights drawn from
      N(0, sqrt(2 / fan_out)) with fan_out = kh * kw * out_channels.
      Any conv bias is not modified (convs are expected to be bias-free
      because a BN layer follows).
    - ``TalkConv2d``: same as above, but fan_out is additionally multiplied
      by ``m.params_scale``.
    - ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: gamma set to 1.0, or to 0.0
      when the module has a truthy ``final_bn`` attribute and
      ``cfg.BN.ZERO_INIT_FINAL_GAMMA`` is enabled; beta set to 0. Non-affine
      BN layers (``weight``/``bias`` are ``None``) are skipped safely.
    - ``nn.Linear`` / ``TalkLinear`` / ``SymLinear``: weights drawn from
      N(0, 0.01); bias zeroed when present.

    Args:
        m: The module to initialize. Its parameters are modified in place.
    """
    # NOTE(review): branch order matters if TalkConv2d subclasses nn.Conv2d —
    # it would then take the first branch. Kept identical to the original order.
    if isinstance(m, (nn.Conv2d, SymConv2d)):
        # Note that there is no bias due to BN
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
    elif isinstance(m, TalkConv2d):
        # Note that there is no bias due to BN.
        # "Uniform" init: fan_out is scaled by params_scale rather than using
        # a node-specific fan_out.
        fan_out = (
            m.kernel_size[0] * m.kernel_size[1] * m.out_channels * m.params_scale
        )
        m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
    elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
        # Optionally zero-init gamma of the last BN in a block so that the
        # block starts close to identity.
        zero_init_gamma = (
            getattr(m, "final_bn", False) and cfg.BN.ZERO_INIT_FINAL_GAMMA
        )
        # BN layers created with affine=False have no weight/bias parameters.
        if m.weight is not None:
            m.weight.data.fill_(0.0 if zero_init_gamma else 1.0)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, (nn.Linear, TalkLinear, SymLinear)):
        m.weight.data.normal_(mean=0.0, std=0.01)
        if m.bias is not None:
            m.bias.data.zero_()