in src/lib.rs [888:905]
/// Compares `fused_matmul` (with a bias supplied) against a reference assembled
/// from a plain `matmul` followed by a broadcast bias addition.
///
/// Runs on CUDA device 0. Both sides are passed through `to_vec2_round(_, 4)`
/// before being compared for equality (presumably rounding to 4 decimals —
/// confirm against the helper's definition).
fn test_fused_matmul() -> Result<()> {
    let device = Device::new_cuda(0)?;

    // Inputs are sampled in the order a, b, bias and converted to F32.
    let mat_a = Tensor::randn(0., 1., (8, 4), &device)?.to_dtype(DType::F32)?;
    let mat_b = Tensor::randn(0., 1., (2, 4), &device)?.to_dtype(DType::F32)?;
    let bias_vec = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;

    let handle = CublasLt::new(&device)?;

    // Fused path under test: only the bias is supplied among the optional arguments.
    let fused = fused_matmul(
        &mat_a,
        &mat_b,
        None,
        None,
        None,
        Some(&bias_vec),
        None,
        handle,
    )?;

    // Reference path: b x a^T, then add the bias broadcast over the 2 leading rows.
    let a_transposed = mat_a.t()?;
    let product = mat_b.matmul(&a_transposed)?;
    let bias_rows = bias_vec.broadcast_left(2)?;
    let reference = (product + bias_rows)?;

    let got = to_vec2_round(fused.to_dtype(DType::F32)?, 4)?;
    let want = to_vec2_round(reference.to_dtype(DType::F32)?, 4)?;
    assert_eq!(got, want);

    Ok(())
}