in candle-cublaslt/src/lib.rs [371:386]
fn cuda_fwd(
    &self,
    a: &candle::CudaStorage,
    a_l: &Layout,
    b: &candle::CudaStorage,
    b_l: &Layout,
    bias: &candle::CudaStorage,
    bias_l: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
    // Dispatch on the dtype of `a`; each helper runs the same fused
    // cuBLASLt matmul, with the bias forwarded as `Some(..)` so the
    // kernel applies it in the same launch.
    match a.dtype() {
        candle::DType::F16 => self.fwd_f16(a, a_l, b, b_l, Some(bias), Some(bias_l)),
        candle::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, Some(bias), Some(bias_l)),
        candle::DType::F32 => self.fwd_f32(a, a_l, b, b_l, Some(bias), Some(bias_l)),
        // Only these three dtypes have cuBLASLt kernels wired up.
        dt => candle::bail!("cublaslt-matmul is only supported for f16/bf16/f32 ({dt:?})"),
    }
}
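This `cuda_fwd` is the CUDA branch of candle's `CustomOp3` trait, which handles ops with three tensor inputs (here: `a`, `b`, and the bias). As a minimal sketch of how such an op is driven from the tensor API: `Tensor::apply_op3` is candle's entry point for three-input custom ops and routes to `cuda_fwd` when the tensors live on a CUDA device. The wrapper function below is illustrative, not part of the crate:

    use candle::{CustomOp3, Result, Tensor};

    // Hypothetical helper: route three tensors through a custom op.
    // `apply_op3` checks the tensors' device, then invokes the op's
    // `cuda_fwd` (or `cpu_fwd`) with each tensor's storage and layout.
    fn matmul_add_bias<O>(op: O, a: &Tensor, b: &Tensor, bias: &Tensor) -> Result<Tensor>
    where
        O: CustomOp3 + Send + Sync + 'static,
    {
        a.apply_op3(b, bias, op)
    }

Passing a non-float tensor (e.g. `DType::U32`) through this path lands in the `bail!` arm above rather than silently falling back to an unfused matmul.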