fn cuda_fwd()

in candle-cublaslt/src/lib.rs [799:816]
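This is the `CustomOp3::cuda_fwd` entry point for the fused batched matmul-with-bias op: it inspects the dtype of `a` and forwards all operand storages and layouts to the matching cuBLASLt kernel, passing the bias along; any other dtype is rejected with an error.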


    fn cuda_fwd(
        &self,
        a: &candle::CudaStorage,
        a_l: &Layout,
        b: &candle::CudaStorage,
        b_l: &Layout,
        bias: &candle::CudaStorage,
        bias_l: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        // Dispatch to the dtype-specific cuBLASLt path; only f16, bf16 and f32
        // kernels are provided.
        match a.dtype() {
            candle::DType::F16 => self.fwd_f16(a, a_l, b, b_l, Some(bias), Some(bias_l)),
            candle::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, Some(bias), Some(bias_l)),
            candle::DType::F32 => self.fwd_f32(a, a_l, b, b_l, Some(bias), Some(bias_l)),
            dt => candle::bail!(
                "cublaslt-batch-matmul-add is only supported for f16/bf16/f32 ({dt:?})"
            ),
        }
    }
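
For reference, a minimal sketch of how an op implementing this `cuda_fwd` is typically driven from the tensor API. The helper `run_op3` is illustrative and not part of this crate; it forwards to candle's generic `Tensor::apply_op3`, which routes to `cuda_fwd` when the operands live on a CUDA device. Constructing the concrete cuBLASLt op value is assumed to happen elsewhere in the crate.

    use candle::{CustomOp3, Result, Tensor};

    // Illustrative helper (not part of candle-cublaslt): run any three-operand
    // custom op, such as the batched matmul-with-bias above, through candle's
    // dispatch machinery. On a CUDA device this ends up calling `cuda_fwd`.
    fn run_op3<C>(a: &Tensor, b: &Tensor, bias: &Tensor, op: C) -> Result<Tensor>
    where
        C: CustomOp3 + Send + Sync + 'static,
    {
        a.apply_op3(b, bias, op)
    }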