in pycls/models/vit.py
import math

import torch
from torch.nn import init


def init_weights_vit(model):
"""Performs ViT weight init."""
for k, m in model.named_modules():
if isinstance(m, torch.nn.Conv2d):
if "patchify" in k:
# ViT patchify stem init
fan_in = m.in_channels * m.kernel_size[0] * m.kernel_size[1]
init.trunc_normal_(m.weight, std=math.sqrt(1.0 / fan_in))
init.zeros_(m.bias)
elif "cstem_last" in k:
# The last 1x1 conv of the conv stem
init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / m.out_channels))
init.zeros_(m.bias)
elif "cstem" in k:
# Use default pytorch init for other conv layers in the C-stem
pass
else:
raise NotImplementedError
if isinstance(m, torch.nn.Linear):
if "self_attention" in k:
# Use default pytorch init for multi-head attention module
pass
elif "mlp_block" in k:
# MLP block init
init.xavier_uniform_(m.weight)
init.normal_(m.bias, std=1e-6)
elif "head_fc" in k:
# Head (classifier) init
init.zeros_(m.weight)
init.zeros_(m.bias)
else:
raise NotImplementedError
if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.LayerNorm):
# Use default pytorch init for norm layers
pass
# Pos-embedding init
init.normal_(model.pos_embedding, mean=0.0, std=0.02)
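
A minimal usage sketch (not part of the file): the hypothetical TinyViT module below only exposes submodule names matching the rules keyed on above (patchify, mlp_block, head_fc) plus a pos_embedding parameter, so init_weights_vit can be exercised end to end; the real pycls ViT model defines its own modules. For a 16x16 RGB patchify conv, fan_in = 3 * 16 * 16 = 768, so its weights are drawn with std = sqrt(1/768) ≈ 0.036.

class TinyViT(torch.nn.Module):
    """Hypothetical stand-in whose submodule names match the init rules above."""

    def __init__(self):
        super().__init__()
        self.patchify = torch.nn.Conv2d(3, 192, kernel_size=16, stride=16)
        self.mlp_block_fc1 = torch.nn.Linear(192, 768)
        self.head_fc = torch.nn.Linear(192, 1000)
        self.norm = torch.nn.LayerNorm(192)
        self.pos_embedding = torch.nn.Parameter(torch.zeros(1, 197, 192))


model = TinyViT()
init_weights_vit(model)
assert model.head_fc.weight.abs().sum() == 0  # classifier head is zero-initialized
assert model.pos_embedding.std() > 0          # pos-embedding drawn from N(0, 0.02^2)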