in training/flax/distil_whisper/layers.py [0:0]
def __call__(self, x):
    """Apply layer normalization over the last axis of the input.

    The statistics are computed in float32 regardless of the input dtype,
    and the result is cast to ``self.dtype`` before being returned.

    Args:
      x: the inputs; normalization is performed over the trailing axis.

    Returns:
      Normalized inputs (the same shape as inputs), cast to ``self.dtype``.
    """
    # Upcast so the mean/variance reductions are done in float32.
    x = jnp.asarray(x, jnp.float32)
    features = x.shape[-1]

    mean = jnp.mean(x, axis=-1, keepdims=True)
    mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
    # E[x^2] - E[x]^2 can come out slightly negative through floating-point
    # cancellation (near-constant rows with large magnitude). Clamp at zero
    # so that rsqrt never sees a negative argument and returns NaN.
    var = jnp.maximum(0.0, mean2 - lax.square(mean))
    mul = lax.rsqrt(var + self.epsilon)

    if self.use_scale:
        # Learned per-feature gain, folded into the multiplier.
        scale = param_with_axes(
            "scale",
            self.scale_init,
            (features,),
            self.params_dtype,
            axes=("embed",),
        )
        mul = mul * jnp.asarray(scale, self.dtype)

    y = (x - mean) * mul

    if self.use_bias:
        # Learned per-feature offset, added after normalization.
        bias = param_with_axes("bias", self.bias_init, (features,), self.params_dtype, axes=("embed",))
        y = y + jnp.asarray(bias, self.dtype)

    return jnp.asarray(y, self.dtype)