fn fwd<T>

in candle-layer-norm/src/lib.rs [33:199]


    fn fwd<
        T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
    >(
        &self,
        x: &candle::CudaStorage,
        x_l: &Layout,
        r: Option<&candle::CudaStorage>,
        r_l: Option<&Layout>,
    ) -> Result<(candle::CudaStorage, Shape)> {
        // Assume all tensors are on the same device and take device of x
        let dev = x.device();

        // Get internal layer norm type id for the given dtype
        let layer_norm_type = layer_norm_internal_type(x.dtype())?;

        // Make sure that gamma is a CUDA tensor and get the underlying storage
        let (g, g_l) = self.gamma.storage_and_layout();
        let g = match &*g {
            Storage::Cuda(g) => g,
            _ => candle::bail!("gamma must be a cuda tensor"),
        };

        // Get cuda slices for all tensors
        let x = x.as_cuda_slice::<T>()?;
        let g = g.as_cuda_slice::<T>()?;

        // Get cuda views for all tensors
        let x = x.slice(x_l.start_offset()..);
        let g = g.slice(g_l.start_offset()..);
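        // (A layout can have a nonzero start offset when the tensor is a
        // view into a larger buffer, hence the re-slicing above.)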

        // Input matrix layout
        let rows = x_l.dims()[0];
        let cols = x_l.dims()[1];

        if !(cols % 8 == 0 && cols <= 8192) {
            candle::bail!("hidden size must be % 8 and <= 8192")
        }
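        // e.g. cols = 768 or 4096 pass this check, while cols = 770 (not a
        // multiple of 8) or cols = 16384 (> 8192) are rejected.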

        let x_stride = x_l.stride();
        let g_stride = g_l.stride();

        let x_rank = x_stride.len();
        let g_rank = g_stride.len();

        if x_rank != 2 {
            candle::bail!("layer-norm expects input tensors of rank 2. Found: {x_rank}")
        }
        if x_stride[x_rank - 1] != 1 {
            candle::bail!("the last dim of x must be contiguous {x_stride:?}")
        }
        if g_stride[g_rank - 1] != 1 {
            candle::bail!("the last dim of g must be contiguous {g_stride:?}")
        }

        // Round cols to match with the correct kernel
        let cols_rounded = if cols <= 1536 {
            round_multiple(cols, 256)
        } else if cols <= 3072 {
            round_multiple(cols, 512)
        } else {
            round_multiple(cols, 1024)
        };
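        // Assuming round_multiple rounds up, cols = 1000 gives cols_rounded =
        // 1024, cols = 2000 gives 2048 and cols = 5000 gives 5120, matching
        // the sizes the kernels are specialized for.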

        let is_rms_norm = if self.is_rms_norm { 1 } else { 0 };

        // If beta is set, get its device pointer
        let b_ptr = if let Some(beta) = &self.beta {
            // Make sure that beta is a CUDA tensor and get the underlying storage
            let (b, b_l) = beta.storage_and_layout();
            let b = match &*b {
                Storage::Cuda(b) => b,
                _ => candle::bail!("gamma must be a cuda tensor"),
            };

            let b = b.as_cuda_slice::<T>()?;
            let b = b.slice(b_l.start_offset()..);

            let b_stride = b_l.stride();
            let b_rank = b_stride.len();

            if b_stride[b_rank - 1] != 1 {
                candle::bail!("the last dim of b must be contiguous {b_stride:?}")
            }
            *b.device_ptr() as *const core::ffi::c_void
        } else {
            ptr::null() as *const core::ffi::c_void
        };

        // If residual is set, get its device pointer
        let r_ptr = if let (Some(r), Some(r_l)) = (r, r_l) {
            // Check shape
            let expected_shape = x_l.shape().dims2()?;
            if r_l.shape().dims2()? != expected_shape {
                candle::bail!("shape mismatch x {:?} and r {:?}", x_l.shape(), r_l.shape());
            }

            let r = r.as_cuda_slice::<T>()?;
            let r = r.slice(r_l.start_offset()..);

            let r_stride = r_l.stride();
            let r_rank = r_stride.len();

            if r_rank != 2 {
                candle::bail!("layer-norm expects input tensors of rank 2. Found: {r_rank}")
            }

            if r_stride[r_rank - 1] != 1 {
                candle::bail!("the last dim of r must be contiguous {r_stride:?}")
            }
            *r.device_ptr() as *const core::ffi::c_void
        } else {
            ptr::null() as *const core::ffi::c_void
        };
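        // As with beta, a null residual pointer presumably tells the kernel
        // to skip the fused residual add.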

        // We store the result of the residual add next to the main result,
        // so out holds twice as many rows as the input: (rows * 2, cols)
        let out_shape = Shape::from((rows * 2, cols));

        let out = unsafe { dev.alloc::<T>(out_shape.elem_count()) }.w()?;
        let dst = out.slice(..rows * cols);
        let dst_add = out.slice(rows * cols..);

        // Alloc internal buffers
        let mu = unsafe { dev.alloc::<f32>(rows) }.w()?;
        let rsigma = unsafe { dev.alloc::<f32>(rows) }.w()?;
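        // mu and rsigma hold the per-row mean and reciprocal standard
        // deviation in f32; kernels of this style typically keep them for
        // the backward pass.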

        // Get cuda device pointers from cuda slices
        let x_ptr = *x.device_ptr() as *const core::ffi::c_void;
        let g_ptr = *g.device_ptr() as *const core::ffi::c_void;
        let dst_add_ptr = *dst_add.device_ptr() as *const core::ffi::c_void;
        let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
        let mu_ptr = *mu.device_ptr() as *const core::ffi::c_void;
        let rsigma_ptr = *rsigma.device_ptr() as *const core::ffi::c_void;

        let multi_processors_count = dev
            .attribute(CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT)
            .unwrap();

        unsafe {
            // Launch Kernel
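            // The four repeated type ids presumably select the weight, input,
            // residual and output dtypes (all identical here); the literal 2
            // below is assumed to be the f32 compute type id, matching the
            // f32 mu/rsigma buffers.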
            ffi::run_ln(
                x_ptr,
                r_ptr,
                g_ptr,
                b_ptr,
                dst_add_ptr,
                dst_ptr,
                mu_ptr,
                rsigma_ptr,
                self.epsilon,
                cols_rounded as u32,
                rows as u32,
                cols as u32,
                multi_processors_count,
                layer_norm_type,
                layer_norm_type,
                layer_norm_type,
                layer_norm_type,
                2,
                is_rms_norm,
            )
        }

        let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone());

        Ok((out, out_shape))
    }
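
A minimal usage sketch, assuming the surrounding impl (outside this excerpt)
wires fwd into candle's CustomOp2 cuda_fwd and that the op struct is named
LayerNorm with the gamma/beta/epsilon/is_rms_norm fields used above; the
struct name and the example function are illustrative, not taken from the
crate:

    use candle::{DType, Device, Result, Tensor};

    fn layer_norm_example(dev: &Device) -> Result<(Tensor, Tensor)> {
        let (rows, cols) = (4usize, 768usize); // cols must be a multiple of 8 and <= 8192
        let x = Tensor::randn(0f32, 1f32, (rows, cols), dev)?;
        let residual = Tensor::randn(0f32, 1f32, (rows, cols), dev)?;

        // Hypothetical construction; the field names follow the code above.
        let op = LayerNorm {
            gamma: Tensor::ones(cols, DType::F32, dev)?,
            beta: None,
            epsilon: 1e-5,
            is_rms_norm: false,
        };

        // The op returns a (rows * 2, cols) tensor: the normalized rows
        // first, then the raw residual-add rows.
        let out = x.apply_op2(&residual, op)?;
        let normed = out.narrow(0, 0, rows)?;
        let added = out.narrow(0, rows, rows)?;
        Ok((normed, added))
    }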