def mlp_init()

in shap_e/models/nn/ops.py [0:0]


def mlp_init(affines, init: Optional[str] = None, init_scale: float = 1.0):
    """Apply the selected initialization scheme to each affine layer.

    The initializers are called for their side effects; nothing is returned.

    Args:
        affines: iterable of layers to initialize. For the SIREN schemes the
            first element yielded gets the first-layer initializer and every
            later element gets the scheme's regular initializer.
        init: name of the scheme.
            ``"siren30"`` uses ``siren_init_first_layer`` for the first layer
            and ``siren_init_30`` for the rest; ``"siren"`` uses
            ``siren_init_first_layer`` for the first layer and ``siren_init``
            for the rest; ``None`` uses ``std_init`` for every layer.
        init_scale: forwarded unchanged as ``init_scale`` to every
            initializer call.

    Raises:
        NotImplementedError: if ``init`` is not one of the values above; the
            offending value is passed as the exception argument.
    """
    if init == "siren30":
        for idx, affine in enumerate(affines):
            # Keep the initializer in its own local: rebinding `init` would
            # replace the Optional[str] argument with a callable.
            layer_init = siren_init_first_layer if idx == 0 else siren_init_30
            layer_init(affine, init_scale=init_scale)
    elif init == "siren":
        for idx, affine in enumerate(affines):
            layer_init = siren_init_first_layer if idx == 0 else siren_init
            layer_init(affine, init_scale=init_scale)
    elif init is None:
        for affine in affines:
            std_init(affine, init_scale=init_scale)
    else:
        raise NotImplementedError(init)