in src/lib.rs [376:407]
/// Unfused layer normalisation over dim 1 of a 2-D tensor, written with plain tensor ops.
///
/// The name suggests this is a reference ("truth") implementation that other
/// implementations are compared against. NOTE(review): confirm against the callers.
///
/// Steps, per row of `x`:
/// 1. If `rms` is `false`, subtract the row mean (classic LayerNorm). If `rms` is
///    `true`, skip centring (RMSNorm).
/// 2. Divide by `sqrt(mean(x^2) + epsilon)`.
/// 3. Cast back to the input dtype.
/// 4. Multiply by `gamma` and, if `beta` is `Some`, add `beta` (both via broadcasting).
///
/// `F16`/`BF16` inputs are normalised in `F32`; every other dtype is normalised in
/// its own dtype. The `gamma`/`beta` step runs after the cast back, i.e. in the
/// input dtype.
///
/// # Errors
///
/// Returns an error if `x` is not rank 2, or if any tensor operation fails (for
/// example, if `gamma` or `beta` cannot be broadcast against the normalised tensor).
fn layer_norm_truth(
    x: &Tensor,
    gamma: &Tensor,
    beta: Option<&Tensor>,
    epsilon: f64,
    rms: bool,
) -> Result<Tensor> {
    let in_dtype = x.dtype();
    // Reductions over half-precision data are carried out in F32.
    let compute_dtype = if matches!(in_dtype, DType::F16 | DType::BF16) {
        DType::F32
    } else {
        in_dtype
    };
    let (_rows, width) = x.shape().dims2()?;
    let n = width as f64;

    let upcast = x.to_dtype(compute_dtype)?;
    let centred = if rms {
        // RMSNorm: no mean subtraction.
        upcast
    } else {
        let row_mean = (upcast.sum_keepdim(1)? / n)?;
        upcast.broadcast_sub(&row_mean)?
    };

    let mean_sq = (centred.sqr()?.sum_keepdim(1)? / n)?;
    let denom = (mean_sq + epsilon)?.sqrt()?;
    let normed = centred.broadcast_div(&denom)?;

    // The affine step runs in the caller's dtype, after down-casting.
    let scaled = normed.to_dtype(in_dtype)?.broadcast_mul(gamma)?;
    let out = match beta {
        Some(shift) => scaled.broadcast_add(shift)?,
        None => scaled,
    };
    Ok(out)
}