fn test_fused_batch_matmul()

in candle-cublaslt/src/lib.rs [912:939]
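This unit test checks fused_batch_matmul against a reference result built from standard candle ops: a batched matmul of b with a transposed, plus c, plus a broadcast bias.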


    fn test_fused_batch_matmul() -> Result<()> {
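        // Requires a CUDA device (ordinal 0); the cuBLASLt path only runs on CUDA.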
        let device = Device::new_cuda(0)?;

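        // a: (3, 8, 4) and b: (3, 2, 4), so b @ a^T is (3, 2, 8); c matches that
        // output shape and the bias covers the last dimension of size 8.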
        let a = Tensor::randn(0., 1., (3, 8, 4), &device)?.to_dtype(DType::F32)?;
        let b = Tensor::randn(0., 1., (3, 2, 4), &device)?.to_dtype(DType::F32)?;
        let c = Tensor::randn(0., 1., (3, 2, 8), &device)?.to_dtype(DType::F32)?;
        let bias = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;

        let cublaslt = CublasLt::new(&device)?;

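        // Fused cuBLASLt path; per the reference below it should produce b @ a^T + c + bias.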
        let res = fused_batch_matmul(
            &a,
            &b,
            Some(&c),
            None,
            Some(1.0),
            Some(&bias),
            None,
            cublaslt,
        )?;
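        // Reference result computed with plain candle ops on the same inputs.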
        let expected = (b.matmul(&a.t()?)?.add(&c)? + bias.broadcast_left((3, 2))?)?;

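        // Round both sides to 4 decimal places so small floating-point
        // differences between the two paths do not fail the comparison.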
        assert_eq!(
            to_vec3_round(res.to_dtype(DType::F32)?, 4)?,
            to_vec3_round(expected.to_dtype(DType::F32)?, 4)?
        );
        Ok(())
    }