def variance_scaling()

in training/flax/distil_whisper/layers.py [0:0]


import jax
import jax.numpy as jnp
from jax import random


def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=jnp.float_):
    """Inlined JAX `nn.initializer.variance_scaling`."""

    def init(key, shape, dtype=dtype):
        dtype = jax.dtypes.canonicalize_dtype(dtype)
        shape = jax.core.as_named_shape(shape)
        # Effective input/output sizes; `_compute_fans` is defined elsewhere in layers.py.
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
        # Pick the fan that serves as the variance denominator.
        if mode == "fan_in":
            denominator = fan_in
        elif mode == "fan_out":
            denominator = fan_out
        elif mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        else:
            raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
        variance = jnp.array(scale / denominator, dtype=dtype)

        if distribution == "truncated_normal":
            # constant is stddev of standard normal truncated to (-2, 2)
            stddev = jnp.sqrt(variance) / jnp.array(0.87962566103423978, dtype)
            return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
        elif distribution == "normal":
            return random.normal(key, shape, dtype) * jnp.sqrt(variance)
        elif distribution == "uniform":
            # uniform on [-1, 1) has variance 1/3, hence the sqrt(3 * variance) factor
            return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance)
        else:
            raise ValueError("invalid distribution for variance scaling initializer: {}".format(distribution))

    return init
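
A minimal usage sketch (the scale of 1.0 and the `(512, 1024)` shape are illustrative assumptions, not values from the source): the returned `init` follows the standard JAX initializer signature `init(key, shape, dtype)`, so it can be passed anywhere a Flax `kernel_init` is expected. Note it relies on `jax.core.as_named_shape`, which was removed in newer JAX releases, so this runs only with the JAX version the repo targets.

# Build a fan-in, truncated-normal initializer (illustrative parameters).
init_fn = variance_scaling(1.0, "fan_in", "truncated_normal")

key = jax.random.PRNGKey(0)
kernel = init_fn(key, (512, 1024), jnp.float32)  # fan_in=512, fan_out=1024

# The empirical standard deviation should be close to sqrt(scale / fan_in).
print(kernel.std(), (1.0 / 512) ** 0.5)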