def init_weights_vit()

in pycls/models/vit.py [0:0]


# Module-level imports this function relies on (defined at the top of vit.py)
import math

import torch
from torch.nn import init


def init_weights_vit(model):
    """Performs ViT weight init."""
    for k, m in model.named_modules():
        if isinstance(m, torch.nn.Conv2d):
            if "patchify" in k:
                # ViT patchify stem init
                fan_in = m.in_channels * m.kernel_size[0] * m.kernel_size[1]
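                # e.g., with a 3-channel input and a 16x16 patch (typical ViT values):
                # fan_in = 3 * 16 * 16 = 768, so std = sqrt(1/768) ≈ 0.036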
                init.trunc_normal_(m.weight, std=math.sqrt(1.0 / fan_in))
                init.zeros_(m.bias)
            elif "cstem_last" in k:
                # The last 1x1 conv of the conv stem
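                # (for a 1x1 conv, fan_out = out_channels, so this is
                # He/MSRA-style fan-out init)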
                init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / m.out_channels))
                init.zeros_(m.bias)
            elif "cstem" in k:
                # Use default pytorch init for other conv layers in the C-stem
                pass
            else:
                raise NotImplementedError
        if isinstance(m, torch.nn.Linear):
            if "self_attention" in k:
                # Use default pytorch init for multi-head attention module
                pass
            elif "mlp_block" in k:
                # MLP block init
                init.xavier_uniform_(m.weight)
                init.normal_(m.bias, std=1e-6)
            elif "head_fc" in k:
                # Head (classifier) init
                init.zeros_(m.weight)
                init.zeros_(m.bias)
            else:
                raise NotImplementedError
        if isinstance(m, (torch.nn.BatchNorm2d, torch.nn.LayerNorm)):
            # Use default pytorch init for norm layers
            pass
    # Pos-embedding init
    init.normal_(model.pos_embedding, mean=0.0, std=0.02)
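

Below is a minimal usage sketch (not part of pycls). The TinyViT module and its
attribute names are hypothetical; they are chosen only so that they match the
substrings the initializer checks for ("patchify", "mlp_block", "head_fc") and
provide the pos_embedding parameter it expects on the top-level model.

import torch


class TinyViT(torch.nn.Module):
    def __init__(self, dim=192, num_patches=196, num_classes=1000):
        super().__init__()
        # "patchify" in the module name -> trunc-normal stem init
        self.patchify = torch.nn.Conv2d(3, dim, kernel_size=16, stride=16)
        # "mlp_block" in the module name -> xavier-uniform weight, tiny bias
        self.mlp_block_fc = torch.nn.Linear(dim, 4 * dim)
        # "head_fc" in the module name -> zero-initialized classifier
        self.head_fc = torch.nn.Linear(dim, num_classes)
        # position-embedding parameter read by init_weights_vit
        self.pos_embedding = torch.nn.Parameter(torch.zeros(1, num_patches, dim))


model = TinyViT()
init_weights_vit(model)
print(float(model.patchify.weight.std()))       # ~0.036 (= sqrt(1 / 768))
print(float(model.head_fc.weight.abs().max()))  # 0.0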