in src/lib.rs [21:230]
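// CUDA forward pass for paged attention: validates the query / KV-cache layouts,
// then dispatches to the paged_attention_v1 or paged_attention_v2 kernel.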
fn cuda_fwd_t<
T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
>(
&self,
q: &CudaStorage,
q_l: &Layout,
) -> Result<(CudaStorage, Shape)> {
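// Map the candle dtype to the integer tag passed through the FFI so the kernels
// can dispatch on element type.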
let dtype = q.dtype();
let internal_type = match dtype {
DType::F16 => 0,
DType::BF16 => 1,
DType::F32 => 2,
dtype => candle::bail!("dtype {dtype:?} is not supported"),
};
let dev = q.device();
let out_shape = q_l.shape().clone();
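// Fetch storage and layout for the KV cache and metadata tensors; all of them must be CUDA tensors.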
let (kc, kc_l) = self.key_cache.storage_and_layout();
let kc = match &*kc {
Storage::Cuda(kc) => kc,
_ => candle::bail!("key_cache must be a cuda tensor"),
};
let (vc, vc_l) = self.value_cache.storage_and_layout();
let vc = match &*vc {
Storage::Cuda(vc) => vc,
_ => candle::bail!("value_cache must be a cuda tensor"),
};
let (bt, bt_l) = self.block_tables.storage_and_layout();
let bt = match &*bt {
Storage::Cuda(bt) => bt,
_ => candle::bail!("block_tables must be a cuda tensor"),
};
let (cl, cl_l) = self.context_lens.storage_and_layout();
let cl = match &*cl {
Storage::Cuda(cl) => cl,
_ => candle::bail!("context_lens must be a cuda tensor"),
};
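// Validate expected ranks: q is (num_seqs, num_heads, head_size),
// key_cache is (num_blocks, num_kv_heads, head_size / x, block_size, x),
// value_cache is (num_blocks, num_kv_heads, head_size, block_size).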
let q_rank = q_l.stride().len();
let kc_rank = kc_l.stride().len();
let vc_rank = vc_l.stride().len();
if q_rank != 3 {
candle::bail!(
"paged-attention expects `q` tensor to be of rank 3 \
(q: {q_l:?})"
)
}
if kc_rank != 5 {
candle::bail!(
"paged-attention expects `key_cache` tensor to be of rank 5 \
(key_cache: {kc_l:?})"
)
}
if vc_rank != 4 {
candle::bail!(
"paged-attention expects `value_cache` tensor to be of rank 4 \
(value_cache: {vc_l:?})"
)
}
// Get cuda slices for all tensors
let q = q.as_cuda_slice::<T>()?;
let kc = kc.as_cuda_slice::<T>()?;
let vc = vc.as_cuda_slice::<T>()?;
let cl = cl.as_cuda_slice::<u32>()?; // The kernels expect i32; stored as u32 and reinterpreted via the pointer cast below.
let bt = bt.as_cuda_slice::<u32>()?; // Same here: reinterpreted as i32 on the FFI side.
// Get cuda views for all tensors
let q = q.slice(q_l.start_offset()..);
let kc = kc.slice(kc_l.start_offset()..);
let vc = vc.slice(vc_l.start_offset()..);
let cl = cl.slice(cl_l.start_offset()..);
let bt = bt.slice(bt_l.start_offset()..);
let (num_seqs, num_heads, head_size) = q_l.shape().dims3()?;
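// The kernels only support a fixed set of head sizes.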
if !matches!(head_size, 64 | 80 | 96 | 112 | 128 | 256) {
candle::bail!("`head_size` must be one of 64, 80, 96, 112, 128 or 256, got {head_size}");
}
let (num_seqs_bt, max_num_blocks_per_seq) = bt_l.shape().dims2()?;
if num_seqs_bt != num_seqs {
candle::bail!(
"shape mismatch block_tables {:?}, expected {:?}",
bt_l.shape(),
(num_seqs, max_num_blocks_per_seq)
)
}
let (num_blocks, num_kv_heads, head_size_kc, block_size, x) = kc_l.shape().dims5()?;
if head_size_kc != head_size / x {
candle::bail!(
"shape mismatch key_cache {:?}, expected {:?}",
kc_l.shape(),
(num_blocks, num_kv_heads, head_size / x, block_size, x)
)
}
if (num_blocks, num_kv_heads, head_size, block_size) != vc_l.shape().dims4()? {
candle::bail!(
"shape mismatch key_cache {:?} and value_cache {:?}",
kc_l.shape(),
vc_l.shape()
)
}
if num_seqs != cl_l.shape().dims1()? {
candle::bail!(
"shape mismatch context_lens {:?}, expected {:?}",
cl_l.shape(),
num_seqs
)
}
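// Strides (in elements) the kernels use to index the query and the paged KV cache.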
let q_stride = q_l.stride()[0];
let kv_block_stride = kc_l.stride()[0];
let kv_head_stride = kc_l.stride()[1];
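// The v2 kernel splits each sequence's context into partitions of 512 tokens,
// while v1 processes the whole context in one pass. Prefer v1 when a single
// partition suffices or when num_seqs * num_heads already provides enough parallelism.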
let partition_size = 512;
let max_num_partitions = (self.max_context_len + partition_size - 1) / partition_size;
let use_v1 = (max_num_partitions == 1 || num_seqs * num_heads > 512)
&& partition_size % block_size == 0;
let elem_count = out_shape.elem_count();
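// Allocate the output buffer (uninitialized; the kernel writes it).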
let out = unsafe { dev.alloc::<T>(elem_count) }.w()?;
let out_ptr = *out.device_ptr() as *const core::ffi::c_void;
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
let kc_ptr = *kc.device_ptr() as *const core::ffi::c_void;
let vc_ptr = *vc.device_ptr() as *const core::ffi::c_void;
let bt_ptr = *bt.device_ptr() as *const core::ffi::c_int;
let cl_ptr = *cl.device_ptr() as *const core::ffi::c_int;
if use_v1 {
unsafe {
ffi::paged_attention_v1(
out_ptr,
q_ptr,
kc_ptr,
vc_ptr,
num_kv_heads as c_int,
self.softmax_scale,
bt_ptr,
cl_ptr,
block_size as c_int,
self.max_context_len as c_int,
num_seqs as c_int,
num_heads as c_int,
head_size as c_int,
max_num_blocks_per_seq as c_int,
q_stride as c_int,
kv_block_stride as c_int,
kv_head_stride as c_int,
internal_type,
)
}
} else {
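// The v2 path needs per-partition scratch buffers: partial outputs plus the
// exp-sums and max-logits used to combine the partitions into the final output.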
let tmp_out_shape = Shape::from((num_seqs, num_heads, max_num_partitions, head_size));
let exp_sums_shape = Shape::from((num_seqs, num_heads, max_num_partitions));
let tmp_out = unsafe { dev.alloc::<T>(tmp_out_shape.elem_count()) }.w()?;
let exp_sums = unsafe { dev.alloc::<f32>(exp_sums_shape.elem_count()) }.w()?;
let max_logits = unsafe { dev.alloc::<f32>(exp_sums_shape.elem_count()) }.w()?;
let tmp_out_ptr = *tmp_out.device_ptr() as *const core::ffi::c_void;
let exp_sums_ptr = *exp_sums.device_ptr() as *const f32;
let max_logits_ptr = *max_logits.device_ptr() as *const f32;
unsafe {
ffi::paged_attention_v2(
out_ptr,
exp_sums_ptr,
max_logits_ptr,
tmp_out_ptr,
q_ptr,
kc_ptr,
vc_ptr,
num_kv_heads as c_int,
self.softmax_scale,
bt_ptr,
cl_ptr,
block_size as c_int,
self.max_context_len as c_int,
num_seqs as c_int,
num_heads as c_int,
head_size as c_int,
max_num_blocks_per_seq as c_int,
q_stride as c_int,
kv_block_stride as c_int,
kv_head_stride as c_int,
internal_type,
)
}
}
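// Wrap the raw CUDA buffer back into a candle storage and return it with the original query shape.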
let out = CudaStorage::wrap_cuda_slice(out, dev.clone());
Ok((out, out_shape))
}