in crates/ratchet-core/src/cpu/norm.rs [91:147]
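/// Computes LayerNorm over the last dimension of `input`:
/// `y = (x - mu) * rsqrt(var + eps) * scale (+ bias)`, where `mu` and `var`
/// are the mean and variance of each row along the final axis, and the
/// variance is computed as `E[x^2] - E[x]^2`.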
fn layer_norm<T>(
    input: &Tensor,
    scale: &Tensor,
    bias: &Option<Tensor>,
    eps: f32,
    dst: &Tensor,
) -> Result<(), OperationError>
where
    T: TensorDType + Float + NumOps + for<'a> Sum<&'a T>,
{
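    // Normalization runs along the last dimension: N elements per row.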
    let src_shape = input.shape();
    let rank = input.rank();
    let N = src_shape[rank - 1];
    let norm_shape = shape!(N);
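
    // Pull the tensor contents into host vectors.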
    let input = input.to_vec::<T>()?;
    let scale = scale.to_vec::<T>()?;
    let bias = match bias {
        Some(b) => Some(b.to_vec::<T>()?),
        None => None,
    };
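
    // mu = E[x], one mean per row along the final axis.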
    let mut x = input.clone();
    let mu = mean(&x, src_shape, rank - 1);
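
    // var = E[x^2] - E[x]^2, also one value per row.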
    let mut mu2 = mu.clone();
    square(&mut mu2);
    let mut x2 = input.clone();
    square(&mut x2);
    let mut x2 = mean(&x2, src_shape, rank - 1);
    sub(&mut x2, &mu2);
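
    // Center each row: x - mu, with the per-row mean broadcast across the row.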
    let mut mu_b = vec![T::zero(); x.len()];
    broadcast_vector(&mu, &mut mu_b);
    sub(&mut x, &mu_b);
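
    // Turn the per-row variance into 1 / sqrt(var + eps).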
    let eps_vec = vec![T::from(eps).unwrap(); x2.len()];
    add(&mut x2, &eps_vec);
    rsqrt(&mut x2);
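
    // Normalize: (x - mu) * rsqrt(var + eps), broadcast per row.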
    let mut v = vec![T::zero(); x.len()];
    broadcast_vector(&x2, &mut v);
    mul(&mut x, &v);
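
    // Apply the learned scale (gamma) and optional bias (beta),
    // each broadcast from [N] up to the full input shape.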
    let scale_b = broadcast(&scale, &norm_shape, src_shape);
    mul(&mut x, &scale_b);
    if let Some(bias) = bias {
        let bias_b = broadcast(&bias, &norm_shape, src_shape);
        add(&mut x, &bias_b);
    }
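
    // Write the normalized values into the destination tensor.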
    cpu_store_result(dst, &x);
    Ok(())
}