def weight_init()

in mtrl/agent/utils.py [0:0]


def weight_init(m: ModelType):
    """Initialize the weights of one module, chosen by its layer type.

    The module is routed to the first matching initializer:

    * ``nn.Linear`` -> ``weight_init_linear``
    * ``nn.Conv2d`` / ``nn.ConvTranspose2d`` -> ``weight_init_conv``
    * ``moe_layer.Linear`` -> ``weight_init_moe_layer``

    A module of any other type is left untouched.

    Args:
        m: the module whose weights should be initialized.
    """
    # The checks are ordered: a module is handled by the first branch it
    # matches and never by a later one.
    if isinstance(m, nn.Linear):
        weight_init_linear(m)
        return

    # Both regular and transposed 2D convolutions share one initializer.
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        weight_init_conv(m)
        return

    if isinstance(m, moe_layer.Linear):
        weight_init_moe_layer(m)