in timm/models/sequencer.py [0:0]
# Imports this function relies on (present at the top of sequencer.py).
import math

import torch.nn as nn

from timm.layers import lecun_normal_


def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False):
    if isinstance(module, nn.Linear):
        if name.startswith('head'):
            # Classifier head: zero weights, bias set to head_bias.
            nn.init.zeros_(module.weight)
            nn.init.constant_(module.bias, head_bias)
        else:
            if flax:
                # Flax defaults
                lecun_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            else:
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    if 'mlp' in name:
                        nn.init.normal_(module.bias, std=1e-6)
                    else:
                        nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
        # Normalization layers: unit scale, zero shift.
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.RNN, nn.GRU, nn.LSTM)):
        # Recurrent layers: uniform init scaled by hidden size (PyTorch's RNN default).
        stdv = 1.0 / math.sqrt(module.hidden_size)
        for weight in module.parameters():
            nn.init.uniform_(weight, -stdv, stdv)
    elif hasattr(module, 'init_weights'):
        # Defer to the module's own initializer if it defines one.
        module.init_weights()
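
A minimal usage sketch, not code from sequencer.py: a per-module initializer like this is typically applied by walking a model's named submodules and calling it with each module and its qualified name. The helper name `init_model_weights` and the `model`/`head_bias` arguments below are hypothetical, shown only to illustrate how `_init_weights` can be driven.

from functools import partial

import torch.nn as nn


def init_model_weights(model: nn.Module, head_bias: float = 0., flax: bool = False):
    # Bind the constant arguments once, then apply to every named submodule.
    init_fn = partial(_init_weights, head_bias=head_bias, flax=flax)
    for name, module in model.named_modules():
        if not name:
            # Skip the root model itself to avoid re-entering its own init path.
            continue
        init_fn(module, name)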