in crates/ratchet-core/src/ops/norm/mod.rs [407:440]
/// Computes the reference result for a normalization op by running an
/// embedded PyTorch program through `run_py_prg`.
///
/// * `var` - which normalization to reproduce (LayerNorm or RMSNorm).
/// * `input` - tensor to normalize over its last dimension.
/// * `scale` - multiplicative weight applied after normalization.
/// * `bias` - optional additive bias; only meaningful for LayerNorm.
///
/// The result is requested with the same dtype as `input`.
///
/// # Errors
///
/// Returns an error if a bias is supplied for RMSNorm, or if `run_py_prg`
/// fails.
fn ground_truth(
    var: NormVariant,
    input: &Tensor,
    scale: &Tensor,
    bias: Option<&Tensor>,
) -> anyhow::Result<Tensor> {
    // NOTE: the Python sources live in raw strings, so their top-level
    // statements must start at column 0 regardless of the Rust indentation.
    //
    // `bias` defaults to `None` on the Python side so that LayerNorm without a
    // bias works when only two tensors are forwarded.
    // NOTE(review): this presumes `run_py_prg` passes the tensors as positional
    // arguments to the function defined in the program — confirm against its
    // definition.
    let ln_prg = r#"
import torch
import torch.nn.functional as F

def layer_norm(input, scale, bias=None):
    (input, scale) = (torch.from_numpy(input), torch.from_numpy(scale))
    if bias is not None:
        bias = torch.from_numpy(bias)
    return F.layer_norm(input, (input.shape[-1],), weight=scale, bias=bias).numpy()
"#;

    let rms_prg = r#"
import torch

def manual_rms_norm(input, scale):
    (input, scale) = (torch.from_numpy(input), torch.from_numpy(scale))
    variance = input.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
    input = input * torch.rsqrt(variance + 1e-5)
    return (scale * input).numpy()
"#;

    // Pick the program and its argument list together so the number of
    // forwarded tensors always matches the Python function's parameters.
    let (prg, inputs) = match (var, bias) {
        (NormVariant::LayerNorm, Some(bias)) => (ln_prg, rvec![input, scale, bias]),
        (NormVariant::LayerNorm, None) => (ln_prg, rvec![input, scale]),
        (NormVariant::RMSNorm, None) => (rms_prg, rvec![input, scale]),
        (NormVariant::RMSNorm, Some(_)) => {
            // The RMSNorm reference takes exactly two tensors; fail here with a
            // clear message instead of a Python arity error.
            anyhow::bail!("RMSNorm ground truth does not accept a bias tensor")
        }
    };

    run_py_prg(prg.to_string(), &inputs, &[], input.dt())
}