in src/lib.rs [908:935]
/// Checks `fused_batch_matmul` against a reference built from plain tensor ops.
///
/// Random F32 inputs are created on CUDA device 0, the fused call is run with an
/// extra `c` operand and a bias, and the result is compared (rounded to 4 decimal
/// places) with `b x a^T + c + bias` computed step by step.
///
/// # Errors
///
/// Propagates any error from device creation, tensor construction, the fused
/// call, or the reference computation.
fn test_fused_batch_matmul() -> Result<()> {
    // Problem dimensions: `batch` independent products of an (m, k) by a (k, n) matrix.
    let (batch, m, n, k): (usize, usize, usize, usize) = (3, 2, 8, 4);

    let device = Device::new_cuda(0)?;

    // Inputs, sampled with mean 0 and std 1 and then cast to F32.
    let a = Tensor::randn(0., 1., (batch, n, k), &device)?.to_dtype(DType::F32)?;
    let b = Tensor::randn(0., 1., (batch, m, k), &device)?.to_dtype(DType::F32)?;
    let c = Tensor::randn(0., 1., (batch, m, n), &device)?.to_dtype(DType::F32)?;
    let bias = Tensor::randn(0., 1., n, &device)?.to_dtype(DType::F32)?;

    let cublaslt = CublasLt::new(&device)?;

    // The positional `None` / `Some(1.0)` / `None` arguments are options whose meaning
    // is defined by `fused_batch_matmul`'s signature, which is not visible here
    // (presumably scaling factors and an activation -- TODO confirm).
    let res = fused_batch_matmul(
        &a,
        &b,
        Some(&c),
        None,
        Some(1.0),
        Some(&bias),
        None,
        cublaslt,
    )?;

    // Reference result: b x a^T, plus c, plus the bias broadcast over (batch, m).
    let product = b.matmul(&a.t()?)?;
    let with_c = product.add(&c)?;
    let expected = (with_c + bias.broadcast_left((batch, m))?)?;

    // Round both sides to 4 decimals before comparing.
    let got = to_vec3_round(res.to_dtype(DType::F32)?, 4)?;
    let want = to_vec3_round(expected.to_dtype(DType::F32)?, 4)?;
    assert_eq!(got, want);

    Ok(())
}