fn test_fused_matmul()

in src/lib.rs [888:905]


    /// Compares `fused_matmul` (called with a bias) against an unfused reference.
    ///
    /// The reference is `b.matmul(a^T) + bias`, with the bias broadcast over two
    /// leading rows. Both results are passed through `to_vec2_round(_, 4)` and
    /// must be equal.
    ///
    /// Opens a CUDA device via `Device::new_cuda(0)`; any error from the device
    /// or from a tensor operation is returned through `?`.
    fn test_fused_matmul() -> Result<()> {
        let device = Device::new_cuda(0)?;

        // Inputs are sampled with `randn(0., 1., ..)` and then converted to F32.
        let mat_a = Tensor::randn(0., 1., (8, 4), &device)?
            .to_dtype(DType::F32)?;
        let mat_b = Tensor::randn(0., 1., (2, 4), &device)?
            .to_dtype(DType::F32)?;
        let bias_vec = Tensor::randn(0., 1., 8, &device)?
            .to_dtype(DType::F32)?;

        let cublaslt = CublasLt::new(&device)?;

        // Only the bias is supplied; every other optional argument stays `None`.
        let fused = fused_matmul(
            &mat_a,
            &mat_b,
            None,
            None,
            None,
            Some(&bias_vec),
            None,
            cublaslt,
        )?;

        // Unfused reference: (2, 4) x (4, 8) product plus the bias repeated on 2 rows.
        let a_transposed = mat_a.t()?;
        let product = mat_b.matmul(&a_transposed)?;
        let bias_rows = bias_vec.broadcast_left(2)?;
        let reference = (product + bias_rows)?;

        // NOTE(review): `to_vec2_round(_, 4)` presumably rounds to 4 decimal
        // places before the exact comparison — confirm against the helper.
        let got = to_vec2_round(fused.to_dtype(DType::F32)?, 4)?;
        let want = to_vec2_round(reference.to_dtype(DType::F32)?, 4)?;
        assert_eq!(got, want);

        Ok(())
    }