in candle-cublaslt/src/lib.rs [764:779]
fn cuda_fwd(
    &self,
    a: &candle::CudaStorage,
    a_l: &Layout,
    b: &candle::CudaStorage,
    b_l: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
    // Dispatch on the lhs dtype: the cuBLASLt batch matmul is only wired up
    // for f16, bf16, and f32; any other dtype is rejected with an error.
    match a.dtype() {
        candle::DType::F16 => self.fwd_f16(a, a_l, b, b_l, None, None),
        candle::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, None, None),
        candle::DType::F32 => self.fwd_f32(a, a_l, b, b_l, None, None),
        dt => {
            candle::bail!("cublaslt-batch-matmul is only supported for f16/bf16/f32 ({dt:?})")
        }
    }
}
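
For context, a hedged sketch of how a CustomOp2 implementation like this is normally driven from candle's tensor API: callers do not invoke cuda_fwd directly but go through Tensor::apply_op2, which routes to this method when the inputs live on a CUDA device. The helper name run_batch_matmul and the generic op parameter below are illustrative assumptions, not names from candle-cublaslt itself.

// Minimal usage sketch (assumed, not taken from candle-cublaslt):
// `Tensor::apply_op2` dispatches to the op's `cuda_fwd` for CUDA storage,
// so unsupported dtypes end up hitting the `bail!` in the match above.
use candle::{CustomOp2, Result, Tensor};

fn run_batch_matmul(
    a: &Tensor,
    b: &Tensor,
    op: impl CustomOp2 + Send + Sync + 'static,
) -> Result<Tensor> {
    // Inputs should be f16/bf16/f32 CUDA tensors; anything else returns
    // the error produced by the dtype dispatch in `cuda_fwd`.
    a.apply_op2(b, op)
}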