src/lib.rs (761 lines of code) (raw):

pub use cudarc::cublaslt::Activation; use std::ffi::c_int; use candle::backend::BackendStorage; use candle::cuda_backend::WrapErr; use candle::{CpuStorage, Device, Layout, Result, Shape, Storage, Tensor}; use half::{bf16, f16}; use std::sync::Arc; use cudarc::cublaslt::{CudaBlasLT, Matmul, MatmulConfig}; #[derive(Debug, Clone)] pub struct CublasLt(Arc<CudaBlasLT>); impl CublasLt { pub fn new(device: &Device) -> Result<Self> { let dev = match &*device { Device::Cuda(d) => d, _ => candle::bail!("`device` must be a `cuda` device"), }; let inner = CudaBlasLT::new(dev.cuda_device()).unwrap(); Ok(Self(Arc::new(inner))) } } pub struct CublasLTMatmul { pub cublaslt: Arc<CudaBlasLT>, pub act: Option<Activation>, pub c: Option<Tensor>, pub alpha: Option<f32>, pub beta: Option<f32>, } impl CublasLTMatmul { pub fn fwd_f16( &self, a: &candle::CudaStorage, a_l: &Layout, b: &candle::CudaStorage, b_l: &Layout, bias: Option<&candle::CudaStorage>, bias_l: Option<&Layout>, ) -> Result<(candle::CudaStorage, Shape)> { let dev = a.device(); // Assume TN let (m, k) = a_l.shape().dims2()?; let (n, b_1) = b_l.shape().dims2()?; if b_1 != k { candle::bail!("This layer only supports TN layout"); } let lda = k; let ldb = k; let ldc = m; let out_shape = Shape::from((n, m)); let a = a.as_cuda_slice::<f16>()?.slice(a_l.start_offset()..); let b = b.as_cuda_slice::<f16>()?.slice(b_l.start_offset()..); let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { if bias_l.shape().dims1()? != m { candle::bail!("Bias does not have the correct shape"); } Some(bias.as_cuda_slice::<f16>()?.slice(bias_l.start_offset()..)) } else { None }; let mut out = if let Some(c) = &self.c { let (c, c_l) = c.storage_and_layout(); let c = match &*c { Storage::Cuda(storage) => storage.as_cuda_slice::<f16>()?, _ => candle::bail!("`c` must be a cuda tensor"), }; match c_l.contiguous_offsets() { Some((o1, o2)) => { if o1 != 0 { candle::bail!("`c` start offset must be 0"); } if o2 != out_shape.elem_count() { candle::bail!("`c` end offset must be {}", out_shape.elem_count()) } } None => candle::bail!("`c` has to be contiguous"), }; if c_l.shape().dims2()? != (n, m) { candle::bail!("`c` does not have the correct shape"); } c.clone() } else { // Allocate out tensor unsafe { dev.alloc::<f16>(out_shape.elem_count()).w()? } }; let config = MatmulConfig { transa: true, transb: false, m: m as u64, n: n as u64, k: k as u64, alpha: self.alpha.unwrap_or(1.0), lda: lda as i64, ldb: ldb as i64, beta: self.beta.unwrap_or(0.0), ldc: ldc as i64, stride_a: None, stride_b: None, stride_c: None, stride_bias: None, batch_size: None, }; unsafe { self.cublaslt .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) .map_err(|e| candle::Error::Cuda(Box::new(e)))?; } let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); Ok((out, out_shape)) } pub fn fwd_bf16( &self, a: &candle::CudaStorage, a_l: &Layout, b: &candle::CudaStorage, b_l: &Layout, bias: Option<&candle::CudaStorage>, bias_l: Option<&Layout>, ) -> Result<(candle::CudaStorage, Shape)> { let dev = a.device(); // Assume TN let (m, k) = a_l.shape().dims2()?; let (n, b_1) = b_l.shape().dims2()?; if b_1 != k { candle::bail!("This layer only supports TN layout"); } let lda = k; let ldb = k; let ldc = m; let out_shape = Shape::from((n, m)); let a = a.as_cuda_slice::<bf16>()?.slice(a_l.start_offset()..); let b = b.as_cuda_slice::<bf16>()?.slice(b_l.start_offset()..); let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { if bias_l.shape().dims1()? != m { candle::bail!("Bias does not have the correct shape"); } Some(bias.as_cuda_slice::<bf16>()?.slice(bias_l.start_offset()..)) } else { None }; let mut out = if let Some(c) = &self.c { let (c, c_l) = c.storage_and_layout(); let c = match &*c { Storage::Cuda(storage) => storage.as_cuda_slice::<bf16>()?, _ => candle::bail!("`c` must be a cuda tensor"), }; match c_l.contiguous_offsets() { Some((o1, o2)) => { if o1 != 0 { candle::bail!("`c` start offset must be 0"); } if o2 != out_shape.elem_count() { candle::bail!("`c` end offset must be {}", out_shape.elem_count()) } } None => candle::bail!("`c` has to be contiguous"), }; if c_l.shape().dims2()? != (n, m) { candle::bail!("`c` does not have the correct shape"); } c.clone() } else { // Allocate out tensor unsafe { dev.alloc::<bf16>(out_shape.elem_count()).w()? } }; let config = MatmulConfig { transa: true, transb: false, m: m as u64, n: n as u64, k: k as u64, alpha: self.alpha.unwrap_or(1.0), lda: lda as i64, ldb: ldb as i64, beta: self.beta.unwrap_or(0.0), ldc: ldc as i64, stride_a: None, stride_b: None, stride_c: None, stride_bias: None, batch_size: None, }; unsafe { self.cublaslt .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) .map_err(|e| candle::Error::Cuda(Box::new(e)))?; } let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); Ok((out, out_shape)) } pub fn fwd_f32( &self, a: &candle::CudaStorage, a_l: &Layout, b: &candle::CudaStorage, b_l: &Layout, bias: Option<&candle::CudaStorage>, bias_l: Option<&Layout>, ) -> Result<(candle::CudaStorage, Shape)> { let dev = a.device(); // Assume TN let (m, k) = a_l.shape().dims2()?; let (n, b_1) = b_l.shape().dims2()?; if b_1 != k { candle::bail!("This layer only supports TN layout"); } let lda = k; let ldb = k; let ldc = m; let out_shape = Shape::from((n, m)); let a = a.as_cuda_slice::<f32>()?.slice(a_l.start_offset()..); let b = b.as_cuda_slice::<f32>()?.slice(b_l.start_offset()..); let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { if bias_l.shape().dims1()? != m { candle::bail!("Bias does not have the correct shape"); } Some(bias.as_cuda_slice::<f32>()?.slice(bias_l.start_offset()..)) } else { None }; let mut out = if let Some(c) = &self.c { let (c, c_l) = c.storage_and_layout(); let c = match &*c { Storage::Cuda(storage) => storage.as_cuda_slice::<f32>()?, _ => candle::bail!("`c` must be a cuda tensor"), }; match c_l.contiguous_offsets() { Some((o1, o2)) => { if o1 != 0 { candle::bail!("`c` start offset must be 0"); } if o2 != out_shape.elem_count() { candle::bail!("`c` end offset must be {}", out_shape.elem_count()) } } None => candle::bail!("`c` has to be contiguous"), }; if c_l.shape().dims2()? != (n, m) { candle::bail!("`c` does not have the correct shape"); } c.clone() } else { // Allocate out tensor unsafe { dev.alloc::<f32>(out_shape.elem_count()).w()? } }; let config = MatmulConfig { transa: true, transb: false, m: m as u64, n: n as u64, k: k as u64, alpha: self.alpha.unwrap_or(1.0), lda: lda as i64, ldb: ldb as i64, beta: self.beta.unwrap_or(0.0), ldc: ldc as i64, stride_a: None, stride_b: None, stride_c: None, stride_bias: None, batch_size: None, }; unsafe { self.cublaslt .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) .map_err(|e| candle::Error::Cuda(Box::new(e)))?; } let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); Ok((out, out_shape)) } } impl candle::CustomOp2 for CublasLTMatmul { fn name(&self) -> &'static str { "cublaslt-matmul" } fn cpu_fwd( &self, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout, ) -> Result<(CpuStorage, Shape)> { candle::bail!("no cpu support for cublaslt-matmul") } fn cuda_fwd( &self, a: &candle::CudaStorage, a_l: &Layout, b: &candle::CudaStorage, b_l: &Layout, ) -> Result<(candle::CudaStorage, Shape)> { match a.dtype() { candle::DType::F16 => self.fwd_f16(a, a_l, b, b_l, None, None), candle::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, None, None), candle::DType::F32 => self.fwd_f32(a, a_l, b, b_l, None, None), dt => candle::bail!("cublaslt-matmul is only supported for f16/bf16/f32 ({dt:?})"), } } } impl candle::CustomOp3 for CublasLTMatmul { fn name(&self) -> &'static str { "cublaslt-matmul-add" } fn cpu_fwd( &self, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout, ) -> Result<(CpuStorage, Shape)> { candle::bail!("no cpu support for cublaslt-matmul") } fn cuda_fwd( &self, a: &candle::CudaStorage, a_l: &Layout, b: &candle::CudaStorage, b_l: &Layout, bias: &candle::CudaStorage, bias_l: &Layout, ) -> Result<(candle::CudaStorage, Shape)> { match a.dtype() { candle::DType::F16 => self.fwd_f16(a, a_l, b, b_l, Some(bias), Some(bias_l)), candle::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, Some(bias), Some(bias_l)), candle::DType::F32 => self.fwd_f32(a, a_l, b, b_l, Some(bias), Some(bias_l)), dt => candle::bail!("cublaslt-matmul is only supported for f16/bf16/f32 ({dt:?})"), } } } /// Fused matmul + add + Relu/Gelu activation using CublasLt /// /// # Arguments /// /// * `a` - Input tensor of size MxK /// * `b` - Input tensor of size NxK /// * `out` - Optional Output tensor of size NxK. /// If set and beta != 0, will be added to the end result of A*B before `act` /// * `alpha` - Optional scaling factor for A*B /// * `beta` - Optional scaling factor for C /// * `bias` - Optional bias tensor of size M /// * `act` - Optional Gelu or Relu activation. If set, will be added to the end result /// * `cublaslt` - CublasLt handle /// /// The resulting tensor is of shape NxM pub fn fused_matmul( a: &Tensor, b: &Tensor, out: Option<&Tensor>, alpha: Option<f32>, beta: Option<f32>, bias: Option<&Tensor>, act: Option<Activation>, cublaslt: CublasLt, ) -> Result<Tensor> { let op = CublasLTMatmul { act, cublaslt: cublaslt.0, c: out.cloned(), alpha, beta, }; if let Some(bias) = bias { a.apply_op3(&b, &bias, op) } else { a.apply_op2(&b, op) } } pub struct CublasLTBatchMatmul { pub cublaslt: Arc<CudaBlasLT>, pub act: Option<Activation>, pub c: Option<Tensor>, pub alpha: Option<f32>, pub beta: Option<f32>, } impl CublasLTBatchMatmul { pub fn fwd_f16( &self, a: &candle::CudaStorage, a_l: &Layout, b: &candle::CudaStorage, b_l: &Layout, bias: Option<&candle::CudaStorage>, bias_l: Option<&Layout>, ) -> Result<(candle::CudaStorage, Shape)> { let dev = a.device(); // Assume TN let (batch_size, m, k) = a_l.shape().dims3()?; let (b_0, n, b_2) = b_l.shape().dims3()?; if b_2 != k { candle::bail!("This layer only supports TN layout"); } if b_0 != batch_size { candle::bail!("`b` must have the same batch size as `a`") } let lda = k; let ldb = k; let ldc = m; let out_shape = Shape::from((batch_size, n, m)); let a = a.as_cuda_slice::<f16>()?.slice(a_l.start_offset()..); let b = b.as_cuda_slice::<f16>()?.slice(b_l.start_offset()..); let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { if bias_l.shape().dims1()? != m { candle::bail!("Bias does not have the correct shape"); } Some(bias.as_cuda_slice::<f16>()?.slice(bias_l.start_offset()..)) } else { None }; let (mut out, stride_c) = if let Some(c) = &self.c { let (c, c_l) = c.storage_and_layout(); let c = match &*c { Storage::Cuda(storage) => storage.as_cuda_slice::<f16>()?, _ => candle::bail!("`c` must be a cuda tensor"), }; match c_l.contiguous_offsets() { Some((o1, o2)) => { if o1 != 0 { candle::bail!("`c` start offset must be 0"); } if o2 != out_shape.elem_count() { candle::bail!("`c` end offset must be {}", out_shape.elem_count()) } } None => candle::bail!("`c` has to be contiguous"), }; if c_l.shape().dims3()? != (batch_size, n, m) { candle::bail!("`c` does not have the correct shape"); } // Set beta to 0.0 if it is not set (c.clone(), c_l.stride()[0]) } else { // Allocate out tensor ( unsafe { dev.alloc::<f16>(out_shape.elem_count()).w()? }, (n * m), ) }; let config = MatmulConfig { transa: true, transb: false, m: m as u64, n: n as u64, k: k as u64, alpha: self.alpha.unwrap_or(1.0), lda: lda as i64, ldb: ldb as i64, beta: self.beta.unwrap_or(0.0), ldc: ldc as i64, stride_a: Some(a_l.stride()[0] as i64), stride_b: Some(b_l.stride()[0] as i64), stride_c: Some(stride_c as i64), stride_bias: None, batch_size: Some(batch_size as c_int), }; unsafe { self.cublaslt .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) .map_err(|e| candle::Error::Cuda(Box::new(e)))?; } let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); Ok((out, out_shape)) } pub fn fwd_bf16( &self, a: &candle::CudaStorage, a_l: &Layout, b: &candle::CudaStorage, b_l: &Layout, bias: Option<&candle::CudaStorage>, bias_l: Option<&Layout>, ) -> Result<(candle::CudaStorage, Shape)> { let dev = a.device(); // Assume TN let (batch_size, m, k) = a_l.shape().dims3()?; let (b_0, n, b_2) = b_l.shape().dims3()?; if b_2 != k { candle::bail!("This layer only supports TN layout"); } if b_0 != batch_size { candle::bail!("`b` must have the same batch size as `a`") } let lda = k; let ldb = k; let ldc = m; let out_shape = Shape::from((batch_size, n, m)); let a = a.as_cuda_slice::<bf16>()?.slice(a_l.start_offset()..); let b = b.as_cuda_slice::<bf16>()?.slice(b_l.start_offset()..); let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { if bias_l.shape().dims1()? != m { candle::bail!("Bias does not have the correct shape"); } Some(bias.as_cuda_slice::<bf16>()?.slice(bias_l.start_offset()..)) } else { None }; let (mut out, stride_c) = if let Some(c) = &self.c { let (c, c_l) = c.storage_and_layout(); let c = match &*c { Storage::Cuda(storage) => storage.as_cuda_slice::<bf16>()?, _ => candle::bail!("`c` must be a cuda tensor"), }; match c_l.contiguous_offsets() { Some((o1, o2)) => { if o1 != 0 { candle::bail!("`c` start offset must be 0"); } if o2 != out_shape.elem_count() { candle::bail!("`c` end offset must be {}", out_shape.elem_count()) } } None => candle::bail!("`c` has to be contiguous"), }; if c_l.shape().dims3()? != (batch_size, n, m) { candle::bail!("`c` does not have the correct shape"); } // Set beta to 0.0 if it is not set (c.clone(), c_l.stride()[0]) } else { // Allocate out tensor ( unsafe { dev.alloc::<bf16>(out_shape.elem_count()).w()? }, (n * m), ) }; let config = MatmulConfig { transa: true, transb: false, m: m as u64, n: n as u64, k: k as u64, alpha: self.alpha.unwrap_or(1.0), lda: lda as i64, ldb: ldb as i64, beta: self.beta.unwrap_or(0.0), ldc: ldc as i64, stride_a: Some(a_l.stride()[0] as i64), stride_b: Some(b_l.stride()[0] as i64), stride_c: Some(stride_c as i64), stride_bias: None, batch_size: Some(batch_size as c_int), }; unsafe { self.cublaslt .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) .map_err(|e| candle::Error::Cuda(Box::new(e)))?; } let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); Ok((out, out_shape)) } pub fn fwd_f32( &self, a: &candle::CudaStorage, a_l: &Layout, b: &candle::CudaStorage, b_l: &Layout, bias: Option<&candle::CudaStorage>, bias_l: Option<&Layout>, ) -> Result<(candle::CudaStorage, Shape)> { let dev = a.device(); // Assume TN let (batch_size, m, k) = a_l.shape().dims3()?; let (b_0, n, b_2) = b_l.shape().dims3()?; if b_2 != k { candle::bail!("This layer only supports TN layout"); } if b_0 != batch_size { candle::bail!("`b` must have the same batch size as `a`") } let lda = k; let ldb = k; let ldc = m; let out_shape = Shape::from((batch_size, n, m)); let a = a.as_cuda_slice::<f32>()?.slice(a_l.start_offset()..); let b = b.as_cuda_slice::<f32>()?.slice(b_l.start_offset()..); let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { if bias_l.shape().dims1()? != m { candle::bail!("Bias does not have the correct shape"); } Some(bias.as_cuda_slice::<f32>()?.slice(bias_l.start_offset()..)) } else { None }; let (mut out, stride_c) = if let Some(c) = &self.c { let (c, c_l) = c.storage_and_layout(); let c = match &*c { Storage::Cuda(storage) => storage.as_cuda_slice::<f32>()?, _ => candle::bail!("`c` must be a cuda tensor"), }; match c_l.contiguous_offsets() { Some((o1, o2)) => { if o1 != 0 { candle::bail!("`c` start offset must be 0"); } if o2 != out_shape.elem_count() { candle::bail!("`c` end offset must be {}", out_shape.elem_count()) } } None => candle::bail!("`c` has to be contiguous"), }; if c_l.shape().dims3()? != (batch_size, n, m) { candle::bail!("`c` does not have the correct shape"); } // Set beta to 0.0 if it is not set (c.clone(), c_l.stride()[0]) } else { // Allocate out tensor ( unsafe { dev.alloc::<f32>(out_shape.elem_count()).w()? }, (n * m), ) }; let config = MatmulConfig { transa: true, transb: false, m: m as u64, n: n as u64, k: k as u64, alpha: self.alpha.unwrap_or(1.0), lda: lda as i64, ldb: ldb as i64, beta: self.beta.unwrap_or(0.0), ldc: ldc as i64, stride_a: Some(a_l.stride()[0] as i64), stride_b: Some(b_l.stride()[0] as i64), stride_c: Some(stride_c as i64), stride_bias: None, batch_size: Some(batch_size as c_int), }; unsafe { self.cublaslt .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) .map_err(|e| candle::Error::Cuda(Box::new(e)))?; } let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); Ok((out, out_shape)) } } impl candle::CustomOp2 for CublasLTBatchMatmul { fn name(&self) -> &'static str { "cublaslt-batch-matmul" } fn cpu_fwd( &self, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout, ) -> Result<(CpuStorage, Shape)> { candle::bail!("no cpu support for cublaslt-batch-matmul") } fn cuda_fwd( &self, a: &candle::CudaStorage, a_l: &Layout, b: &candle::CudaStorage, b_l: &Layout, ) -> Result<(candle::CudaStorage, Shape)> { match a.dtype() { candle::DType::F16 => self.fwd_f16(a, a_l, b, b_l, None, None), candle::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, None, None), candle::DType::F32 => self.fwd_f32(a, a_l, b, b_l, None, None), dt => { candle::bail!("cublaslt-batch-matmul is only supported for f16/bf16/f32 ({dt:?})") } } } } impl candle::CustomOp3 for CublasLTBatchMatmul { fn name(&self) -> &'static str { "cublaslt-batch-matmul-add" } fn cpu_fwd( &self, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout, _: &CpuStorage, _: &Layout, ) -> Result<(CpuStorage, Shape)> { candle::bail!("no cpu support for cublaslt-batch-matmul-add") } fn cuda_fwd( &self, a: &candle::CudaStorage, a_l: &Layout, b: &candle::CudaStorage, b_l: &Layout, bias: &candle::CudaStorage, bias_l: &Layout, ) -> Result<(candle::CudaStorage, Shape)> { match a.dtype() { candle::DType::F16 => self.fwd_f16(a, a_l, b, b_l, Some(bias), Some(bias_l)), candle::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, Some(bias), Some(bias_l)), candle::DType::F32 => self.fwd_f32(a, a_l, b, b_l, Some(bias), Some(bias_l)), dt => candle::bail!( "cublaslt-batch-matmul-add is only supported for f16/bf16/f32 ({dt:?})" ), } } } /// Fused batch matmul + add + Relu/Gelu activation using CublasLt /// /// # Arguments /// /// * `a` - Input tensor of size BxMxK /// * `b` - Input tensor of size BxNxK /// * `out` - Optional Output tensor of size BxNxK. /// If set and beta != 0, will be added to the end result of A*B before `act` /// * `alpha` - Optional scaling factor for A*B /// * `beta` - Optional scaling factor for C /// * `bias` - Optional bias tensor of size M /// * `act` - Optional Gelu or Relu activation. If set, will be added to the end result /// * `cublaslt` - CublasLt handle /// /// The resulting tensor is of shape NxM pub fn fused_batch_matmul( a: &Tensor, b: &Tensor, out: Option<&Tensor>, alpha: Option<f32>, beta: Option<f32>, bias: Option<&Tensor>, act: Option<Activation>, cublaslt: CublasLt, ) -> Result<Tensor> { let op = CublasLTBatchMatmul { act, cublaslt: cublaslt.0, c: out.cloned(), alpha, beta, }; if let Some(bias) = bias { a.apply_op3(&b, &bias, op) } else { a.apply_op2(&b, op) } } #[cfg(test)] mod tests { use super::*; use candle::{DType, Device}; fn to_vec2_round(t: Tensor, digits: i32) -> Result<Vec<Vec<f32>>> { let b = 10f32.powi(digits); let t = t.to_vec2::<f32>()?; let t = t .iter() .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect()) .collect(); Ok(t) } fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> { let b = 10f32.powi(digits); let t = t.to_vec3::<f32>()?; let t = t .iter() .map(|t| { t.iter() .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect()) .collect() }) .collect(); Ok(t) } #[test] fn test_fused_matmul() -> Result<()> { let device = Device::new_cuda(0)?; let a = Tensor::randn(0., 1., (8, 4), &device)?.to_dtype(DType::F32)?; let b = Tensor::randn(0., 1., (2, 4), &device)?.to_dtype(DType::F32)?; let bias = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?; let cublaslt = CublasLt::new(&device)?; let res = fused_matmul(&a, &b, None, None, None, Some(&bias), None, cublaslt)?; let expected = (b.matmul(&a.t()?)? + bias.broadcast_left(2)?)?; assert_eq!( to_vec2_round(res.to_dtype(DType::F32)?, 4)?, to_vec2_round(expected.to_dtype(DType::F32)?, 4)? ); Ok(()) } #[test] fn test_fused_batch_matmul() -> Result<()> { let device = Device::new_cuda(0)?; let a = Tensor::randn(0., 1., (3, 8, 4), &device)?.to_dtype(DType::F32)?; let b = Tensor::randn(0., 1., (3, 2, 4), &device)?.to_dtype(DType::F32)?; let c = Tensor::randn(0., 1., (3, 2, 8), &device)?.to_dtype(DType::F32)?; let bias = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?; let cublaslt = CublasLt::new(&device)?; let res = fused_batch_matmul( &a, &b, Some(&c), None, Some(1.0), Some(&bias), None, cublaslt, )?; let expected = (b.matmul(&a.t()?)?.add(&c)? + bias.broadcast_left((3, 2))?)?; assert_eq!( to_vec3_round(res.to_dtype(DType::F32)?, 4)?, to_vec3_round(expected.to_dtype(DType::F32)?, 4)? ); Ok(()) } }