fn compute()

in src/main.rs [545:570]


/// Issues a single GEMM call on `handle` with `a` and `b` as inputs and `out`
/// as the destination buffer.
///
/// The call is configured with both operands untransposed (`CUBLAS_OP_N`),
/// `m`, `n`, `k` and all three leading dimensions set to `SIZE`, `alpha` set
/// to one and `beta` set to zero.
///
/// # Errors
///
/// Returns whatever error `handle.gemm` reports, converted into
/// `anyhow::Error`.
fn compute<T: VariablePrecisionFloat>(
    handle: &cublas::safe::CudaBlas,
    a: &CudaSlice<T>,
    b: &CudaSlice<T>,
    out: &mut CudaSlice<T>,
) -> anyhow::Result<()>
where
    CudaBlas: Gemm<T>,
{
    // Every extent and leading dimension uses the same value, so cast once.
    // NOTE(review): `as` truncates silently if `SIZE` does not fit in `i32`;
    // confirm against the definition of `SIZE`.
    let dim = SIZE as i32;

    let cfg = GemmConfig {
        transa: cublasOperation_t::CUBLAS_OP_N,
        transb: cublasOperation_t::CUBLAS_OP_N,
        m: dim,
        n: dim,
        k: dim,
        lda: dim,
        ldb: dim,
        ldc: dim,
        alpha: T::from_f32(1.0),
        beta: T::from_f32(0.0),
    };

    // SAFETY: NOTE(review) — nothing in this function validates the buffers.
    // Presumably `a`, `b` and `out` must each hold at least `SIZE * SIZE`
    // elements for this call to be sound; confirm that callers guarantee it.
    let status = unsafe { handle.gemm(cfg, a, b, out) };
    status?;
    Ok(())
}