in src/lib.rs [8:130]
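/// Applies rotary position embeddings to `query` and `key` in place on the GPU,
/// driving a custom CUDA kernel through FFI. `query` is expected as
/// `(num_tokens, num_heads, head_size)`, `key` as
/// `(num_tokens, num_kv_heads, head_size)`, and the `cos`/`sin` caches as
/// `(num_tokens, rot_dim)`; all four tensors must share one of the supported
/// dtypes (f16, bf16, f32) and live on a CUDA device.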
fn apply_rotary_<
    T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
>(
    query: &Tensor,
    key: &Tensor,
    cos_cache: &Tensor,
    sin_cache: &Tensor,
    is_neox: bool,
) -> Result<()> {
    let dtype = query.dtype();
    if key.dtype() != dtype || cos_cache.dtype() != dtype || sin_cache.dtype() != dtype {
        candle::bail!("apply-rotary expects all tensors to have the same dtype");
    }
    // Map the dtype onto the integer tag the CUDA kernel dispatches on.
    let internal_type = match dtype {
        DType::F16 => 0,
        DType::BF16 => 1,
        DType::F32 => 2,
        dtype => candle::bail!("dtype {dtype:?} is not supported"),
    };
    // Unwrap the CUDA storage and layout for each tensor.
    let (q, q_l) = query.storage_and_layout();
    let q = match &*q {
        Storage::Cuda(q) => q,
        _ => candle::bail!("query must be a cuda tensor"),
    };
    let (k, k_l) = key.storage_and_layout();
    let k = match &*k {
        Storage::Cuda(k) => k,
        _ => candle::bail!("key must be a cuda tensor"),
    };
    let (cc, cc_l) = cos_cache.storage_and_layout();
    let cc = match &*cc {
        Storage::Cuda(cc) => cc,
        _ => candle::bail!("cos_cache must be a cuda tensor"),
    };
    let (sc, sc_l) = sin_cache.storage_and_layout();
    let sc = match &*sc {
        Storage::Cuda(sc) => sc,
        _ => candle::bail!("sin_cache must be a cuda tensor"),
    };
    let q_rank = q_l.stride().len();
    let k_rank = k_l.stride().len();
    let cc_rank = cc_l.stride().len();
    let sc_rank = sc_l.stride().len();
    if q_rank != 3 || k_rank != 3 {
        candle::bail!("apply-rotary expects input tensors of rank 3 (q: {q_l:?}, k: {k_l:?})")
    }
    if cc_rank != 2 || sc_rank != 2 {
        candle::bail!("apply-rotary expects cache tensors of rank 2 (cos: {cc_l:?}, sin: {sc_l:?})")
    }
    // Get cuda slices for all tensors
    let q = q.as_cuda_slice::<T>()?;
    let k = k.as_cuda_slice::<T>()?;
    let cc = cc.as_cuda_slice::<T>()?;
    let sc = sc.as_cuda_slice::<T>()?;
    // Get cuda views for all tensors
    let q = q.slice(q_l.start_offset()..);
    let k = k.slice(k_l.start_offset()..);
    let cc = cc.slice(cc_l.start_offset()..);
    let sc = sc.slice(sc_l.start_offset()..);
    let (num_tokens, num_heads, head_size) = q_l.shape().dims3()?;
    let (num_tokens_kv, num_kv_heads, head_size_kv) = k_l.shape().dims3()?;
    if (num_tokens, head_size) != (num_tokens_kv, head_size_kv) {
        candle::bail!("shape mismatch q {:?} and k {:?}", q_l.shape(), k_l.shape())
    }
    let rot_dim = cc_l.dims()[1];
    if (num_tokens, rot_dim) != cc_l.shape().dims2()? {
        candle::bail!(
            "shape mismatch cos_cache {:?}, expected {:?}",
            cc_l.shape(),
            (num_tokens, rot_dim)
        )
    }
    if (num_tokens, rot_dim) != sc_l.shape().dims2()? {
        candle::bail!(
            "shape mismatch sin_cache {:?}, expected {:?}",
            sc_l.shape(),
            (num_tokens, rot_dim)
        )
    }
    // Per-token strides; the kernel presumably assumes the head and head-size
    // dimensions are contiguous within each token row.
    let query_stride = q_l.stride()[0];
    let key_stride = k_l.stride()[0];
    let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
    let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
    let cc_ptr = *cc.device_ptr() as *const core::ffi::c_void;
    let sc_ptr = *sc.device_ptr() as *const core::ffi::c_void;
    // 1 selects the NeoX-style (rotate-half) layout, 0 the interleaved one.
    let neox = if is_neox { 1 } else { 0 };
    // Launch the kernel; it rotates `query` and `key` in place on the device.
    unsafe {
        ffi::rotary_embedding(
            q_ptr,
            k_ptr,
            cc_ptr,
            sc_ptr,
            neox,
            head_size as c_int,
            num_tokens as c_long,
            rot_dim as c_int,
            num_heads as c_int,
            num_kv_heads as c_int,
            query_stride as c_long,
            key_stride as c_long,
            internal_type,
        )
    }
    Ok(())
}
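
// For context, a minimal sketch of the `ffi` binding this call site assumes.
// This is a hypothetical declaration, not the crate's actual `ffi` module: the
// parameter names and order simply mirror the call above, `internal_type` is
// the 0/1/2 dtype tag, and every pointer is a CUDA device pointer.
mod ffi {
    use core::ffi::{c_int, c_long, c_void};

    extern "C" {
        pub fn rotary_embedding(
            query: *const c_void,
            key: *const c_void,
            cos_cache: *const c_void,
            sin_cache: *const c_void,
            is_neox: c_int,
            head_size: c_int,
            num_tokens: c_long,
            rot_dim: c_int,
            num_heads: c_int,
            num_kv_heads: c_int,
            query_stride: c_long,
            key_stride: c_long,
            internal_type: u32,
        );
    }
}

// And a sketch of how a dtype-dispatching public wrapper might look; the name
// `apply_rotary` and its surface are assumptions, not the crate's actual API.
// It picks the `T` for `as_cuda_slice::<T>` to match the runtime dtype.
pub fn apply_rotary(
    query: &Tensor,
    key: &Tensor,
    cos_cache: &Tensor,
    sin_cache: &Tensor,
    is_neox: bool,
) -> Result<()> {
    match query.dtype() {
        DType::F16 => apply_rotary_::<half::f16>(query, key, cos_cache, sin_cache, is_neox),
        DType::BF16 => apply_rotary_::<half::bf16>(query, key, cos_cache, sin_cache, is_neox),
        DType::F32 => apply_rotary_::<f32>(query, key, cos_cache, sin_cache, is_neox),
        dt => candle::bail!("apply-rotary is not implemented for {dt:?}"),
    }
}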