fn layer_norm_truth()

in src/lib.rs [376:407]


    /// Straightforward reference implementation of layer normalisation (or RMS
    /// normalisation when `rms` is true), used as the "truth" to compare against.
    ///
    /// * `x` – input tensor; `dims2` requires it to be rank 2, and statistics are
    ///   taken over dimension 1 (each row is normalised independently).
    /// * `gamma` – scale, broadcast-multiplied onto the normalised rows.
    /// * `beta` – optional shift, broadcast-added after scaling.
    /// * `epsilon` – added to the mean of squares before the square root.
    /// * `rms` – when true, skip mean-centering (RMSNorm); otherwise subtract the
    ///   row mean first (LayerNorm).
    ///
    /// F16/BF16 inputs are upcast to F32 for the reduction. The normalised values
    /// are cast back to `x`'s dtype *before* `gamma`/`beta` are applied, so those
    /// operations run in the input dtype (presumably `gamma`/`beta` share that
    /// dtype — confirm at call sites).
    ///
    /// Errors from any tensor operation (including a non-2-D `x`) are propagated.
    fn layer_norm_truth(
        x: &Tensor,
        gamma: &Tensor,
        beta: Option<&Tensor>,
        epsilon: f64,
        rms: bool,
    ) -> Result<Tensor> {
        let out_dtype = x.dtype();
        // Accumulate statistics in f32 for half-precision inputs.
        let compute_dtype = if matches!(out_dtype, DType::F16 | DType::BF16) {
            DType::F32
        } else {
            out_dtype
        };

        let (_rows, cols) = x.shape().dims2()?;
        let n = cols as f64;
        let upcast = x.to_dtype(compute_dtype)?;

        // RMSNorm uses the raw values; LayerNorm centres each row on its mean.
        let centered = if rms {
            upcast
        } else {
            let row_mean = (upcast.sum_keepdim(1)? / n)?;
            upcast.broadcast_sub(&row_mean)?
        };

        let mean_sq = (centered.sqr()?.sum_keepdim(1)? / n)?;
        let denom = (mean_sq + epsilon)?.sqrt()?;
        let normalized = centered.broadcast_div(&denom)?;

        let scaled = normalized.to_dtype(out_dtype)?.broadcast_mul(gamma)?;
        let out = match beta {
            Some(shift) => scaled.broadcast_add(shift)?,
            None => scaled,
        };
        Ok(out)
    }