in src/lib.rs [33:199]
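/// Launches the fused CUDA kernel behind this op: an optional residual add of
/// `r` into `x`, followed by layer-norm (or RMS-norm, per `self.is_rms_norm`)
/// scaled by `gamma` and optionally shifted by `beta`.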
fn fwd<
    T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
>(
    &self,
    x: &candle::CudaStorage,
    x_l: &Layout,
    r: Option<&candle::CudaStorage>,
    r_l: Option<&Layout>,
) -> Result<(candle::CudaStorage, Shape)> {
    // Assume all tensors are on the same device and take the device of x
    let dev = x.device();
    // Get the internal layer-norm type id for the given dtype
    let layer_norm_type = layer_norm_internal_type(x.dtype())?;
    // Make sure that gamma is a CUDA tensor and get the underlying storage
    let (g, g_l) = self.gamma.storage_and_layout();
    let g = match &*g {
        Storage::Cuda(g) => g,
        _ => candle::bail!("gamma must be a cuda tensor"),
    };
    // Get cuda slices for all tensors
    let x = x.as_cuda_slice::<T>()?;
    let g = g.as_cuda_slice::<T>()?;
    // Get cuda views for all tensors
    let x = x.slice(x_l.start_offset()..);
    let g = g.slice(g_l.start_offset()..);
    // Input matrix layout: the kernel expects a rank-2 (rows, cols) matrix,
    // so check the rank before indexing into the dims
    let x_stride = x_l.stride();
    let g_stride = g_l.stride();
    let x_rank = x_stride.len();
    let g_rank = g_stride.len();
    if x_rank != 2 {
        candle::bail!("layer-norm expects input tensors of rank 2. Found: {x_rank}")
    }
    let rows = x_l.dims()[0];
    let cols = x_l.dims()[1];
    if !(cols % 8 == 0 && cols <= 8192) {
        candle::bail!("hidden size must be a multiple of 8 and <= 8192, got {cols}")
    }
    if x_stride[x_rank - 1] != 1 {
        candle::bail!("the last dim of x must be contiguous {x_stride:?}")
    }
    if g_stride[g_rank - 1] != 1 {
        candle::bail!("the last dim of g must be contiguous {g_stride:?}")
    }
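    // Both contiguity checks reflect how the kernel walks each row as one flat
    // slice; the multiple-of-8 constraint above plausibly matches its
    // vectorized loads, though the kernel source is the authority here.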
    // Round cols up so we dispatch to the matching kernel specialization
    let cols_rounded = if cols <= 1536 {
        round_multiple(cols, 256)
    } else if cols <= 3072 {
        round_multiple(cols, 512)
    } else {
        round_multiple(cols, 1024)
    };
    let is_rms_norm = if self.is_rms_norm { 1 } else { 0 };
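    // e.g. cols = 1000 gives cols_rounded = 1024, assuming `round_multiple`
    // rounds its first argument up to a multiple of its second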
    // If beta is set, get its device pointer
    let b_ptr = if let Some(beta) = &self.beta {
        // Make sure that beta is a CUDA tensor and get the underlying storage
        let (b, b_l) = beta.storage_and_layout();
        let b = match &*b {
            Storage::Cuda(b) => b,
            _ => candle::bail!("beta must be a cuda tensor"),
        };
        let b = b.as_cuda_slice::<T>()?;
        let b = b.slice(b_l.start_offset()..);
        let b_stride = b_l.stride();
        let b_rank = b_stride.len();
        if b_stride[b_rank - 1] != 1 {
            candle::bail!("the last dim of b must be contiguous {b_stride:?}")
        }
        *b.device_ptr() as *const core::ffi::c_void
    } else {
        ptr::null() as *const core::ffi::c_void
    };
    // If residual is set, get its device pointer
    let r_ptr = if let (Some(r), Some(r_l)) = (r, r_l) {
        // Check shape
        let expected_shape = x_l.shape().dims2()?;
        if r_l.shape().dims2()? != expected_shape {
            candle::bail!("shape mismatch x {:?} and r {:?}", x_l.shape(), r_l.shape());
        }
        let r = r.as_cuda_slice::<T>()?;
        let r = r.slice(r_l.start_offset()..);
        let r_stride = r_l.stride();
        let r_rank = r_stride.len();
        if r_rank != 2 {
            candle::bail!("layer-norm expects input tensors of rank 2. Found: {r_rank}")
        }
        if r_stride[r_rank - 1] != 1 {
            candle::bail!("the last dim of r must be contiguous {r_stride:?}")
        }
        *r.device_ptr() as *const core::ffi::c_void
    } else {
        ptr::null() as *const core::ffi::c_void
    };
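    // The kernel treats null bias / residual pointers as "absent", which is why
    // the two branches above fall back to `ptr::null()`.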
    // We will store the results of the residual add next to the main results,
    // so `out` holds twice as many rows as the input: the first `rows * cols`
    // elements are the normalized output, the next `rows * cols` the residual add
    let out_shape = Shape::from((rows * 2, cols));
    let out = unsafe { dev.alloc::<T>(out_shape.elem_count()) }.w()?;
    let dst = out.slice(..rows * cols);
    let dst_add = out.slice(rows * cols..);
    // Alloc internal buffers for the per-row mean and reciprocal standard deviation
    let mu = unsafe { dev.alloc::<f32>(rows) }.w()?;
    let rsigma = unsafe { dev.alloc::<f32>(rows) }.w()?;
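    // Callers get the normalized rows and the residual-add rows stacked in one
    // (rows * 2, cols) tensor and are expected to split it back apart.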
    // Get cuda device pointers from cuda slices
    let x_ptr = *x.device_ptr() as *const core::ffi::c_void;
    let g_ptr = *g.device_ptr() as *const core::ffi::c_void;
    let dst_add_ptr = *dst_add.device_ptr() as *const core::ffi::c_void;
    let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
    let mu_ptr = *mu.device_ptr() as *const core::ffi::c_void;
    let rsigma_ptr = *rsigma.device_ptr() as *const core::ffi::c_void;
    let multi_processors_count = dev
        .attribute(CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT)
        .unwrap();
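    // The SM count is forwarded so the kernel can size its launch grid to the
    // device; the exact policy lives on the C side of `ffi::run_ln`.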
    unsafe {
        // Launch Kernel
        ffi::run_ln(
            x_ptr,
            r_ptr,
            g_ptr,
            b_ptr,
            dst_add_ptr,
            dst_ptr,
            mu_ptr,
            rsigma_ptr,
            self.epsilon,
            cols_rounded as u32,
            rows as u32,
            cols as u32,
            multi_processors_count,
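            // Four type-id slots follow; passing the same id for each suggests
            // weight/input/residual/output all share the tensor dtype, while
            // the literal `2` below selects the compute type (the slot names
            // are an assumption; the C signature is authoritative).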
            layer_norm_type,
            layer_norm_type,
            layer_norm_type,
            layer_norm_type,
            2,
            is_rms_norm,
        )
    }
    let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone());
    Ok((out, out_shape))
}
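// Shape bookkeeping for a hypothetical caller: with x of shape (32, 768) and a
// residual of the same shape, `fwd` hands back a (64, 768) storage whose first
// 32 rows are the normalized output (`dst`) and whose last 32 rows are the
// residual add (`dst_add`).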