in cvnets/misc/init_utils.py [0:0]
def _init_nn_layers(module, init_method: Optional[str] = 'kaiming_normal', std_val: Optional[float] = None):
    """Initialize the weight and bias of ``module`` in place.

    The method name is matched case-insensitively. For every supported
    method the weight (when it is not ``None``) is filled by the matching
    ``torch.nn.init`` routine, and the bias (when it is not ``None``) is
    set to zeros.

    Args:
        module: Object exposing ``weight`` and ``bias`` attributes; either
            attribute may be ``None``, in which case it is skipped.
        init_method: One of ``kaiming_normal``, ``kaiming_uniform``,
            ``xavier_normal``, ``xavier_uniform``, ``normal`` or
            ``trunc_normal`` (any letter case).
        std_val: Standard deviation used by ``normal`` and
            ``trunc_normal``. When ``None``, ``1.0 / weight.size(1)`` is
            used instead. Ignored by the other methods.

    Returns:
        None. For an unrecognized method nothing is initialized; the names
        in ``supported_conv_inits`` and the received method name are
        reported through ``logger.error``.
    """
    method_name = init_method.lower()

    def _std_for(weight):
        # Evaluated lazily so ``weight.size(1)`` is only touched by the
        # normal/trunc_normal paths and only when no explicit std is given.
        return std_val if std_val is not None else 1.0 / weight.size(1)

    # Every entry is a deferred call, so nothing on ``nn.init`` is looked up
    # until the selected initializer actually runs.
    weight_initializers = {
        'kaiming_normal': lambda w: nn.init.kaiming_normal_(w, mode='fan_out'),
        'kaiming_uniform': lambda w: nn.init.kaiming_uniform_(w, mode='fan_out'),
        'xavier_normal': lambda w: nn.init.xavier_normal_(w),
        'xavier_uniform': lambda w: nn.init.xavier_uniform_(w),
        'normal': lambda w: nn.init.normal_(w, mean=0.0, std=_std_for(w)),
        'trunc_normal': lambda w: nn.init.trunc_normal_(w, mean=0.0, std=_std_for(w)),
    }

    fill_weight = weight_initializers.get(method_name)
    if fill_weight is None:
        # Unknown method: list what is supported and report what was received.
        listing = ''.join(
            '\n \t {}) {}'.format(idx, name)
            for idx, name in enumerate(supported_conv_inits)
        )
        logger.error('{} \n Got: {}'.format('Supported initialization methods are:' + listing, method_name))
        return

    if module.weight is not None:
        fill_weight(module.weight)
    if module.bias is not None:
        nn.init.zeros_(module.bias)