in utils/init.py [0:0]
def init_weights(module, init_linear='normal', std=0.01):
    """Re-initialize the Linear and normalization layers of ``module`` in place.

    Walks ``module.modules()`` (the module itself and all of its descendants):

    * ``nn.Linear``:
      - calls ``normal_init(m, std=std)`` when ``init_linear == 'normal'``;
      - calls ``c2_msra_fill(m)`` when ``init_linear == 'kaiming'``.
      Both helpers are defined outside this function. Presumably they are a
      normal-distribution init and a Caffe2-style MSRA/Kaiming fill; confirm
      against their definitions.
    * ``nn.BatchNorm1d`` / ``nn.BatchNorm2d`` / ``nn.GroupNorm`` /
      ``nn.SyncBatchNorm``: sets ``weight`` to 1 and ``bias`` to 0. Either one
      is skipped when it is ``None`` (e.g. layers built with ``affine=False``).

    All other layer types are left untouched.

    Args:
        module (nn.Module): Root module whose sub-modules are initialized.
        init_linear (str): Scheme for ``nn.Linear`` layers, either
            ``'normal'`` or ``'kaiming'``. Defaults to ``'normal'``.
        std (float): Standard deviation passed to ``normal_init`` for
            ``nn.Linear`` layers when ``init_linear == 'normal'``; unused for
            ``'kaiming'``. Defaults to ``0.01``.

    Raises:
        ValueError: If ``init_linear`` is not ``'normal'`` or ``'kaiming'``.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let an unknown scheme silently fall through
    # to the kaiming (``c2_msra_fill``) branch below.
    if init_linear not in ('normal', 'kaiming'):
        raise ValueError("Undefined init_linear: {}".format(init_linear))

    # Normalization layers reset to the identity affine transform.
    # NOTE(review): BatchNorm3d, InstanceNorm* and LayerNorm are not matched
    # here; add them only if callers expect those to be reset too.
    norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)

    for m in module.modules():
        if isinstance(m, nn.Linear):
            if init_linear == 'normal':
                normal_init(m, std=std)
            else:
                c2_msra_fill(m)
        elif isinstance(m, norm_types):
            # weight/bias are None when the layer was built with affine=False.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)