in src/lib.rs [798:815]
fn cuda_fwd(
    &self,
    a: &candle::CudaStorage,
    a_l: &Layout,
    b: &candle::CudaStorage,
    b_l: &Layout,
    bias: &candle::CudaStorage,
    bias_l: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
    // Dispatch on the dtype of `a`; the bias is passed as `Some(..)` so the
    // dtype-specific kernels fuse the bias addition into the cuBLASLt matmul.
    match a.dtype() {
        candle::DType::F16 => self.fwd_f16(a, a_l, b, b_l, Some(bias), Some(bias_l)),
        candle::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, Some(bias), Some(bias_l)),
        candle::DType::F32 => self.fwd_f32(a, a_l, b, b_l, Some(bias), Some(bias_l)),
        // Any other dtype is not supported by this cuBLASLt path.
        dt => candle::bail!(
            "cublaslt-batch-matmul-add is only supported for f16/bf16/f32 ({dt:?})"
        ),
    }
}