fn layer_norm()

in crates/ratchet-core/src/cpu/norm.rs [91:147]


/// CPU reference implementation of layer normalization over the last dimension of `input`.
///
/// For every slice along the final axis (length `n = shape[rank - 1]`) this computes
///
/// ```text
/// y = (x - mean(x)) / sqrt(var(x) + eps) * scale (+ bias)
/// ```
///
/// where `var` is the biased (population) variance of the slice, i.e. `mean((x - mean(x))^2)`.
/// `scale` and, when present, `bias` are broadcast from shape `[n]` to the shape of `input`.
/// The result is written into `dst` via `cpu_store_result`.
///
/// # Errors
///
/// Returns an error if `input`, `scale` or `bias` cannot be read back as a `Vec<T>`.
///
/// # Panics
///
/// NOTE(review): `rank - 1` underflows for a rank-0 `input`; presumably callers never pass
/// scalars here — confirm.
fn layer_norm<T>(
    input: &Tensor,
    scale: &Tensor,
    bias: &Option<Tensor>,
    eps: f32,
    dst: &Tensor,
) -> Result<(), OperationError>
where
    T: TensorDType + Float + NumOps + for<'a> Sum<&'a T>,
{
    let src_shape = input.shape();
    let rank = input.rank();
    // The normalized axis is always the last one.
    let dim = rank - 1;
    let n = src_shape[dim];
    let norm_shape = shape!(n);

    let mut x = input.to_vec::<T>()?;
    let scale = scale.to_vec::<T>()?;
    let bias = bias.as_ref().map(|b| b.to_vec::<T>()).transpose()?;

    // Center the data: x <- x - mean(x). `broadcast_vector` expands the per-row means back to
    // one value per element of `x` (buffer sized to x.len()).
    let mu = mean(&x, src_shape, dim);
    let mut mu_b = vec![T::zero(); x.len()];
    broadcast_vector(&mu, &mut mu_b);
    sub(&mut x, &mu_b);

    // Variance from the *centered* values: var = mean((x - mu)^2).
    // The algebraically equivalent E[x^2] - E[x]^2 form suffers catastrophic cancellation when
    // |mean| >> std, can round to a negative value (NaN after rsqrt), and squares the raw
    // inputs, which overflows far sooner in reduced-precision element types.
    let mut sq = x.clone();
    square(&mut sq);
    let mut inv_std = mean(&sq, src_shape, dim);

    // inv_std <- 1 / sqrt(var + eps)
    let eps_t = T::from(eps).expect("eps (f32) must be representable in the tensor dtype");
    let eps_vec = vec![eps_t; inv_std.len()];
    add(&mut inv_std, &eps_vec);
    rsqrt(&mut inv_std);

    let mut inv_std_b = vec![T::zero(); x.len()];
    broadcast_vector(&inv_std, &mut inv_std_b);
    mul(&mut x, &inv_std_b);

    // Per-feature affine transform: scale (and optional bias) have shape [n].
    let scale_b = broadcast(&scale, &norm_shape, src_shape);
    mul(&mut x, &scale_b);

    if let Some(bias) = bias {
        let bias_b = broadcast(&bias, &norm_shape, src_shape);
        add(&mut x, &bias_b);
    }

    cpu_store_result(dst, &x);

    Ok(())
}