in pycls/utils/net.py [0:0]
def init_weights(m):
    """Performs ResNet style weight initialization on a single module ``m``.

    Parameters are modified in place; nothing is returned. Modules whose type
    matches none of the branches below are left unchanged.

    - ``nn.Conv2d`` / ``SymConv2d``: weights ~ N(0, 2 / fan_out) with
      fan_out = kernel_h * kernel_w * out_channels (Kaiming-normal, fan_out
      mode). Biases are not touched.
    - ``TalkConv2d``: same, but fan_out is additionally multiplied by
      ``m.params_scale``.
    - ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: gamma = 1 (or 0 when the module
      has a truthy ``final_bn`` attribute and ``cfg.BN.ZERO_INIT_FINAL_GAMMA``
      is set), beta = 0. Non-affine BN (weight/bias are ``None``) is skipped.
    - ``nn.Linear`` / ``TalkLinear`` / ``SymLinear``: weights ~ N(0, 0.01^2),
      bias zeroed when present.

    Args:
        m: the module to initialize.
    """
    # NOTE(review): branches are tested in order, so if TalkConv2d or SymConv2d
    # subclassed nn.Conv2d they would take the first branch. Confirm the class
    # hierarchy if the TalkConv2d branch is expected to fire.
    if isinstance(m, (nn.Conv2d, SymConv2d)):
        # Conv bias is left alone: these convs are expected to be followed by
        # BN, which makes a bias redundant.
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
    elif isinstance(m, TalkConv2d):
        # Uniform init across the layer: fan_out is scaled by params_scale, an
        # attribute defined on TalkConv2d (its exact meaning lives there).
        fan_out = (
            m.kernel_size[0] * m.kernel_size[1] * m.out_channels * m.params_scale
        )
        m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
    elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
        # Zero-init gamma only for modules flagged as the final BN of a block,
        # and only when enabled in the config. cfg is read only if flagged.
        zero_init_gamma = (
            getattr(m, 'final_bn', False) and cfg.BN.ZERO_INIT_FINAL_GAMMA
        )
        # With affine=False BatchNorm has weight/bias set to None, so guard
        # them the same way the Linear branch guards its bias.
        if m.weight is not None:
            m.weight.data.fill_(0.0 if zero_init_gamma else 1.0)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, (nn.Linear, TalkLinear, SymLinear)):
        m.weight.data.normal_(mean=0.0, std=0.01)
        if m.bias is not None:
            m.bias.data.zero_()