in segmentation/util/util.py [0:0]
def init_weights(model, conv='kaiming', batchnorm='normal', linear='kaiming', lstm='kaiming'):
    """Initialise the parameters of the supported layers of ``model`` in place.

    Every module yielded by ``model.modules()`` (which includes nested
    sub-modules and ``model`` itself) is handled according to its type:

    * convolution layers (``nn.modules.conv._ConvNd``): weight drawn with the
      ``conv`` scheme, bias (if present) set to 0;
    * batch-norm layers (``nn.modules.batchnorm._BatchNorm``): weight drawn
      from N(1.0, 0.02) for ``'normal'`` or filled with 1.0 for
      ``'constant'``, bias set to 0. Layers built with ``affine=False`` have
      no weight/bias and are left untouched;
    * ``nn.Linear``: weight drawn with the ``linear`` scheme, bias (if
      present) set to 0;
    * ``nn.LSTM``: each parameter whose name contains ``'weight'`` is drawn
      with the ``lstm`` scheme; each one containing ``'bias'`` is set to 0.

    All other module types are not modified.

    :param model: Pytorch Model which is nn.Module
    :param conv: 'kaiming' or 'xavier' (normal-distribution variants)
    :param batchnorm: 'normal' or 'constant'
    :param linear: 'kaiming' or 'xavier' (normal-distribution variants)
    :param lstm: 'kaiming' or 'xavier' (normal-distribution variants)
    :raises ValueError: if a scheme name is not recognised and ``model``
        contains a layer of the corresponding type.
    """
    for m in model.modules():
        if isinstance(m, nn.modules.conv._ConvNd):
            _init_matrix(m.weight, conv, "init type of conv error.\n")
            if m.bias is not None:
                initer.constant_(m.bias, 0)
        elif isinstance(m, nn.modules.batchnorm._BatchNorm):
            # Validate the scheme before touching parameters so an invalid
            # name is reported even for affine=False layers.
            if batchnorm not in ('normal', 'constant'):
                raise ValueError("init type of batchnorm error.\n")
            # affine=False batch-norm layers have weight/bias set to None;
            # passing None to nn.init would crash, so skip them.
            if m.weight is not None:
                if batchnorm == 'normal':
                    initer.normal_(m.weight, 1.0, 0.02)
                else:
                    initer.constant_(m.weight, 1.0)
            if m.bias is not None:
                initer.constant_(m.bias, 0.0)
        elif isinstance(m, nn.Linear):
            _init_matrix(m.weight, linear, "init type of linear error.\n")
            if m.bias is not None:
                initer.constant_(m.bias, 0)
        elif isinstance(m, nn.LSTM):
            for name, param in m.named_parameters():
                if 'weight' in name:
                    _init_matrix(param, lstm, "init type of lstm error.\n")
                elif 'bias' in name:
                    initer.constant_(param, 0)


def _init_matrix(tensor, scheme, error_msg):
    """Fill ``tensor`` in place with Kaiming-normal or Xavier-normal values.

    :param tensor: parameter to initialise in place.
    :param scheme: 'kaiming' or 'xavier'.
    :param error_msg: message of the ValueError raised for any other scheme.
    :raises ValueError: if ``scheme`` is not 'kaiming' or 'xavier'.
    """
    if scheme == 'kaiming':
        initer.kaiming_normal_(tensor)
    elif scheme == 'xavier':
        initer.xavier_normal_(tensor)
    else:
        raise ValueError(error_msg)