def __call__()

in training/flax/distil_whisper/layers.py [0:0]


    def __call__(self, x):
        """Applies layer normalization over the last axis of the input.

        The input is cast to float32 before the statistics are computed, and
        the result is cast to ``self.dtype`` on return. An optional learned
        scale and an optional learned bias, each of shape ``(features,)``,
        are applied when ``self.use_scale`` / ``self.use_bias`` are set.

        Args:
          x: the inputs; normalized along the trailing axis.
        Returns:
          Normalized inputs (the same shape as inputs), with dtype
          ``self.dtype``.
        """
        # Compute the statistics in float32 regardless of the incoming dtype.
        x = jnp.asarray(x, jnp.float32)
        features = x.shape[-1]
        mean = jnp.mean(x, axis=-1, keepdims=True)
        mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
        # E[x^2] - E[x]^2 is subject to floating-point cancellation and can
        # come out slightly negative when the mean dominates the spread.
        # Clamp it at zero so that rsqrt(var + epsilon) never sees a negative
        # argument and cannot produce NaNs from round-off alone.
        var = jnp.maximum(mean2 - lax.square(mean), 0.0)
        mul = lax.rsqrt(var + self.epsilon)
        if self.use_scale:
            scale = param_with_axes(
                "scale",
                self.scale_init,
                (features,),
                self.params_dtype,
                axes=("embed",),
            )
            # Fold the learned scale into the multiplier so that it is applied
            # in the same multiplication as the normalization.
            mul = mul * jnp.asarray(scale, self.dtype)
        y = (x - mean) * mul
        if self.use_bias:
            bias = param_with_axes("bias", self.bias_init, (features,), self.params_dtype, axes=("embed",))
            y = y + jnp.asarray(bias, self.dtype)
        return jnp.asarray(y, self.dtype)