in src/main.rs [545:570]
fn compute<T: VariablePrecisionFloat>(
handle: &cublas::safe::CudaBlas,
a: &CudaSlice<T>,
b: &CudaSlice<T>,
out: &mut CudaSlice<T>,
) -> anyhow::Result<()>
where
CudaBlas: Gemm<T>,
{
let cfg = GemmConfig {
transa: cublasOperation_t::CUBLAS_OP_N,
transb: cublasOperation_t::CUBLAS_OP_N,
m: SIZE as i32,
n: SIZE as i32,
k: SIZE as i32,
alpha: T::from_f32(1.0),
lda: SIZE as i32,
ldb: SIZE as i32,
beta: T::from_f32(0.0),
ldc: SIZE as i32,
};
unsafe {
handle.gemm(cfg, a, b, out)?;
}
Ok(())
}