in src/lib.rs [13:63]
/// Runs the `ffi::silu` kernel over a contiguous CUDA tensor and returns the result in
/// freshly allocated storage.
///
/// `T` is the element type used to view `x` and to allocate the output; `as_cuda_slice::<T>`
/// fails if the storage cannot be viewed as `T`.
///
/// # Errors
///
/// Returns an error when:
/// - the dtype of `x` is not `F16`, `BF16` or `F32`;
/// - `x_l` is not contiguous;
/// - the element count does not fit in a `c_int` (the kernel takes its size as a `c_int`);
/// - the storage cannot be viewed as `T`, or the device allocation fails.
fn fwd<
    T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
>(
    &self,
    x: &candle::CudaStorage,
    x_l: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
    /// Threads per block requested from the kernel launcher.
    const NUM_THREADS: c_int = 1024;

    let dev = x.device();
    let dtype = x.dtype();
    // Integer dtype tag forwarded to the kernel: 0 = F16, 1 = BF16, 2 = F32.
    let internal_type = match dtype {
        DType::F16 => 0,
        DType::BF16 => 1,
        DType::F32 => 2,
        dtype => candle::bail!("dtype {dtype:?} is not supported"),
    };
    if !x_l.is_contiguous() {
        candle::bail!("x must be contiguous");
    }

    // Typed view of the input, skipping everything before the layout's start offset.
    let x = x.as_cuda_slice::<T>()?;
    let x = x.slice(x_l.start_offset()..);

    let dst_shape = x_l.shape().clone();
    let elems = dst_shape.elem_count();

    // The kernel receives the element count and grid size as `c_int`. A plain `as` cast
    // would silently truncate larger tensors, so reject them before allocating anything.
    if elems > c_int::MAX as usize {
        candle::bail!("x has {elems} elements, which does not fit in a c_int");
    }
    // Lossless thanks to the guard above.
    let num_elems = elems as c_int;
    // Ceiling division done in `usize` so that adding `NUM_THREADS - 1` cannot overflow
    // `c_int`; the quotient is never larger than `elems`, so narrowing is lossless.
    let threads = NUM_THREADS as usize;
    let num_blocks = ((elems + threads - 1) / threads) as c_int;

    // SAFETY: `alloc` hands back uninitialized device memory. It is passed only to the
    // kernel below as its output buffer (presumably fully written by it — verify against
    // the kernel source) before being wrapped and returned.
    let dst = unsafe { dev.alloc::<T>(elems) }.w()?;

    let x_ptr = *x.device_ptr() as *const core::ffi::c_void;
    let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;

    // SAFETY: both pointers come from live device buffers on `dev` holding `elems` elements
    // of `T`, and `internal_type` was derived from the storage dtype above.
    unsafe {
        ffi::silu(
            x_ptr,
            dst_ptr,
            num_blocks,
            NUM_THREADS,
            num_elems,
            internal_type,
        )
    }

    let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev.clone());
    Ok((dst, dst_shape))
}