in training/flax/distil_whisper/layers.py [0:0]
def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=jnp.float_, *, zero_init=True):
    """Inlined JAX `nn.initializer.variance_scaling`.

    Builds an initializer function ``init(key, shape, dtype)``.

    By default (``zero_init=True``) the returned initializer ignores ``scale``,
    ``mode``, ``distribution``, ``in_axis`` and ``out_axis`` and simply returns
    an all-zero array of the requested shape and dtype; neither ``mode`` nor
    ``distribution`` is validated on that path. This is the behaviour existing
    callers get.
    NOTE(review): presumably zeros are used because the parameters are expected
    to be overwritten by loaded weights -- confirm with the callers.

    With ``zero_init=False`` the variance-scaling sampling is performed:
    the variance is ``scale / denominator`` where the denominator is chosen
    from the fans computed by ``_compute_fans`` according to ``mode``.

    Args:
        scale: Numerator of the target variance.
        mode: One of ``"fan_in"``, ``"fan_out"`` or ``"fan_avg"``.
        distribution: One of ``"truncated_normal"``, ``"normal"`` or
            ``"uniform"``.
        in_axis: Forwarded to ``_compute_fans`` as the input axis.
        out_axis: Forwarded to ``_compute_fans`` as the output axis.
        dtype: Default dtype of the arrays produced by the initializer.
        zero_init: Keyword-only. When true (default) return zeros instead of
            sampling.

    Returns:
        A callable ``init(key, shape, dtype=dtype)`` producing the initial
        array.

    Raises:
        ValueError: Raised by the returned initializer, only when
            ``zero_init=False``, if ``mode`` or ``distribution`` is not one of
            the supported values.
    """

    def init(key, shape, dtype=dtype):
        if zero_init:
            # Short-circuit: no random sampling, `key` is not used.
            return jnp.zeros(shape, dtype=dtype)

        dtype = jax.dtypes.canonicalize_dtype(dtype)
        # NOTE(review): `jax.core.as_named_shape` may be unavailable in newer
        # JAX releases -- confirm against the JAX version this project pins.
        shape = jax.core.as_named_shape(shape)
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)

        if mode == "fan_in":
            denominator = fan_in
        elif mode == "fan_out":
            denominator = fan_out
        elif mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        else:
            raise ValueError(f"invalid mode for variance scaling initializer: {mode}")

        variance = jnp.array(scale / denominator, dtype=dtype)

        if distribution == "truncated_normal":
            # constant is stddev of standard normal truncated to (-2, 2)
            stddev = jnp.sqrt(variance) / jnp.array(0.87962566103423978, dtype)
            return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
        if distribution == "normal":
            return random.normal(key, shape, dtype) * jnp.sqrt(variance)
        if distribution == "uniform":
            # Fourth positional argument is the lower bound (-1); the upper
            # bound is left at the `random.uniform` default.
            return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance)
        raise ValueError(f"invalid distribution for variance scaling initializer: {distribution}")

    return init