fn test_rms_norm_add()

in src/lib.rs [494:508]


    /// Checks `fused_add_rms_norm` against the unfused reference path on CUDA device 0.
    ///
    /// The fused call yields a pair of tensors. The first is compared with
    /// `layer_norm_truth` evaluated on the element-wise sum of the two inputs,
    /// the second with that sum itself. Both comparisons happen after rounding
    /// to 3 decimal places.
    fn test_rms_norm_add() -> Result<()> {
        let device = Device::new_cuda(0)?;

        // Shapes used below: two (rows, hidden) tensors and two length-`hidden` vectors.
        let rows: usize = 4;
        let hidden: usize = 8;

        // All inputs are drawn from randn(0, 1) and converted to F32.
        let input = Tensor::randn(0., 1., (rows, hidden), &device)?.to_dtype(DType::F32)?;
        let residual = Tensor::randn(0., 1., (rows, hidden), &device)?.to_dtype(DType::F32)?;
        let gamma = Tensor::randn(0., 1., hidden, &device)?.to_dtype(DType::F32)?;
        let beta = Tensor::randn(0., 1., hidden, &device)?.to_dtype(DType::F32)?;

        // Fused path under test; it only borrows its tensor arguments.
        let (fused_norm, fused_sum) =
            fused_add_rms_norm(&input, &residual, &gamma, Some(&beta), 1e-12)?;

        // Reference path: explicit addition (consumes both inputs), then the
        // reference normalization with its final flag set to `true`
        // (presumably selecting RMS mode — confirm against `layer_norm_truth`).
        let expected_sum = (input + residual)?;
        let expected_norm = layer_norm_truth(&expected_sum, &gamma, Some(&beta), 1e-12, true)?;

        assert_eq!(
            to_vec2_round(fused_sum, 3)?,
            to_vec2_round(expected_sum, 3)?
        );
        assert_eq!(
            to_vec2_round(fused_norm, 3)?,
            to_vec2_round(expected_norm, 3)?
        );
        Ok(())
    }