in pytorchvideo/models/weight_init.py [0:0]
def _init_vit_weights(model: nn.Module, trunc_normal_std: float = 0.02) -> None:
    """
    Weight initialization for vision transformers.

    Iterates over ``model.modules()`` and re-initializes each module in place
    according to its type:

    * ``nn.Linear``: the weight is drawn from a truncated normal distribution
      with standard deviation ``trunc_normal_std``; the bias, when present, is
      set to 0.
    * ``nn.LayerNorm``: the weight is set to 1 and the bias to 0. A parameter
      that is absent (``None``, e.g. with ``elementwise_affine=False``) is
      skipped.
    * ``SpatioTemporalClsPositionalEncoding``: every parameter is drawn from a
      truncated normal distribution with standard deviation
      ``trunc_normal_std``.

    Modules of any other type are left untouched.

    Args:
        model (nn.Module): Model to be initialized.
        trunc_normal_std (float): the expected standard deviation for fully-connected
            layer and ClsPositionalEncoding.
    """
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=trunc_normal_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # LayerNorm can be built without affine parameters, in which case
            # weight and/or bias are None and must not be passed to nn.init.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, SpatioTemporalClsPositionalEncoding):
            for weights in m.parameters():
                nn.init.trunc_normal_(weights, std=trunc_normal_std)