def _init_nn_layers()

in cvnets/misc/init_utils.py


from typing import Optional

import torch.nn as nn

# `supported_conv_inits` (the list of accepted method names) and `logger` are expected
# to be defined at module level alongside this function in cvnets/misc/init_utils.py.


def _init_nn_layers(module, init_method: Optional[str] = 'kaiming_normal', std_val: Optional[float] = None):
    """Initialize ``module.weight`` with the requested scheme and zero ``module.bias`` if present."""
    init_method = init_method.lower()
    if init_method == 'kaiming_normal':
        if module.weight is not None:
            nn.init.kaiming_normal_(module.weight, mode='fan_out')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif init_method == 'kaiming_uniform':
        if module.weight is not None:
            nn.init.kaiming_uniform_(module.weight, mode='fan_out')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif init_method == 'xavier_normal':
        if module.weight is not None:
            nn.init.xavier_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif init_method == 'xavier_uniform':
        if module.weight is not None:
            nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif init_method == 'normal':
        if module.weight is not None:
            # Default std scales inversely with the layer's input dimension (weight.size(1))
            std = 1.0 / module.weight.size(1) if std_val is None else std_val
            nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif init_method == 'trunc_normal':
        if module.weight is not None:
            # Same default std as the 'normal' branch; trunc_normal_ re-draws values outside [-2, 2]
            std = 1.0 / module.weight.size(1) if std_val is None else std_val
            nn.init.trunc_normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    else:
        # Unknown init method: list the supported options and report the error via the logger.
        supported_conv_message = 'Supported initialization methods are:'
        for i, supported_init in enumerate(supported_conv_inits):
            supported_conv_message += '\n \t {}) {}'.format(i, supported_init)
        logger.error('{} \n Got: {}'.format(supported_conv_message, init_method))
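
For orientation, a minimal usage sketch is shown below. It assumes only standard torch.nn layers and uses Module.apply to route convolution and linear layers through _init_nn_layers; the tiny model and the init_weights helper are hypothetical illustrations, not part of cvnets.

import torch.nn as nn

# Hypothetical model, used only to illustrate how the initializer is applied.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 10),
)

def init_weights(m: nn.Module) -> None:
    # _init_nn_layers expects modules that expose `weight`/`bias`, e.g. conv and linear layers.
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        _init_nn_layers(m, init_method='kaiming_normal')

model.apply(init_weights)  # recursively visits every sub-module

For the 'normal' and 'trunc_normal' schemes, passing std_val overrides the 1 / weight.size(1) default computed inside the function.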