fn apply_rotary_<()

in src/lib.rs [8:130]


/// Validates the rotary-embedding inputs and hands them to the
/// `ffi::rotary_embedding` CUDA kernel.
///
/// `T` is the element type used to view the CUDA storage of all four tensors.
///
/// # Arguments
///
/// * `query` - rank-3 tensor, read as `(num_tokens, num_heads, head_size)`.
/// * `key` - rank-3 tensor, read as `(num_tokens, num_kv_heads, head_size)`;
///   its first and last dimensions must match those of `query`.
/// * `cos_cache` / `sin_cache` - rank-2 tensors of shape `(num_tokens, rot_dim)`,
///   where `rot_dim` is taken from the second dimension of `cos_cache`.
/// * `is_neox` - forwarded to the kernel as `1` (true) or `0` (false).
///
/// Nothing is returned besides `Ok(())`, so the kernel presumably rewrites
/// `query` and `key` in place — confirm against the kernel source.
///
/// # Errors
///
/// Fails when the four tensors do not share one dtype, when that dtype is not
/// `F16`, `BF16` or `F32`, when any tensor is not backed by CUDA storage, when
/// a rank or shape check fails, or when `as_cuda_slice::<T>()` rejects the
/// storage (presumably when `T` does not match the tensor dtype).
fn apply_rotary_<
    T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
>(
    query: &Tensor,
    key: &Tensor,
    cos_cache: &Tensor,
    sin_cache: &Tensor,
    is_neox: bool,
) -> Result<()> {
    let dtype = query.dtype();
    if key.dtype() != dtype || cos_cache.dtype() != dtype || sin_cache.dtype() != dtype {
        candle::bail!("apply-rotary expects all tensors to have the same dtype");
    }

    // Integer dtype tag passed to the FFI entry point as its last argument.
    let internal_type = match dtype {
        DType::F16 => 0,
        DType::BF16 => 1,
        DType::F32 => 2,
        dtype => candle::bail!("dtype {dtype:?} is not supported"),
    };

    // The storage guards returned here stay alive (shadowed, not dropped) until
    // the end of the function, i.e. past the FFI call below.
    let (q, q_l) = query.storage_and_layout();
    let q = match &*q {
        Storage::Cuda(q) => q,
        _ => candle::bail!("query must be a cuda tensor"),
    };

    let (k, k_l) = key.storage_and_layout();
    let k = match &*k {
        Storage::Cuda(k) => k,
        _ => candle::bail!("key must be a cuda tensor"),
    };

    let (cc, cc_l) = cos_cache.storage_and_layout();
    let cc = match &*cc {
        Storage::Cuda(cc) => cc,
        _ => candle::bail!("cos_cache must be a cuda tensor"),
    };

    let (sc, sc_l) = sin_cache.storage_and_layout();
    let sc = match &*sc {
        Storage::Cuda(sc) => sc,
        _ => candle::bail!("sin_cache must be a cuda tensor"),
    };

    let q_rank = q_l.stride().len();
    let k_rank = k_l.stride().len();
    let cc_rank = cc_l.stride().len();
    let sc_rank = sc_l.stride().len();

    if q_rank != 3 || k_rank != 3 {
        candle::bail!("apply-rotary expects input tensors of rank 3 (q: {q_l:?}, k: {k_l:?})");
    }

    if cc_rank != 2 || sc_rank != 2 {
        candle::bail!(
            "apply-rotary expects cache tensors of rank 2 (cos: {cc_l:?}, sin: {sc_l:?})"
        );
    }

    // Get cuda slices for all tensors
    let q = q.as_cuda_slice::<T>()?;
    let k = k.as_cuda_slice::<T>()?;
    let cc = cc.as_cuda_slice::<T>()?;
    let sc = sc.as_cuda_slice::<T>()?;

    // Get cuda views for all tensors, skipping each layout's start offset
    let q = q.slice(q_l.start_offset()..);
    let k = k.slice(k_l.start_offset()..);
    let cc = cc.slice(cc_l.start_offset()..);
    let sc = sc.slice(sc_l.start_offset()..);

    let (num_tokens, num_heads, head_size) = q_l.shape().dims3()?;
    let (num_tokens_kv, num_kv_heads, head_size_kv) = k_l.shape().dims3()?;

    if (num_tokens, head_size) != (num_tokens_kv, head_size_kv) {
        candle::bail!("shape mismatch q {:?} and k {:?}", q_l.shape(), k_l.shape())
    }

    // Indexing is in bounds: both caches were checked to be rank 2 above.
    let rot_dim = cc_l.dims()[1];
    if (num_tokens, rot_dim) != cc_l.shape().dims2()? {
        candle::bail!(
            "shape mismatch cos_cache {:?}, expected {:?}",
            cc_l.shape(),
            (num_tokens, rot_dim)
        )
    }

    if (num_tokens, rot_dim) != sc_l.shape().dims2()? {
        candle::bail!(
            "shape mismatch sin_cache {:?}, expected {:?}",
            sc_l.shape(),
            (num_tokens, rot_dim)
        )
    }

    // NOTE(review): only the outermost stride of `query`/`key` is forwarded, and
    // no stride at all for the caches, so the kernel presumably assumes the
    // remaining dimensions are densely packed — confirm, and consider rejecting
    // other layouts here.
    let query_stride = q_l.stride()[0];
    let key_stride = k_l.stride()[0];

    let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
    let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
    let cc_ptr = *cc.device_ptr() as *const core::ffi::c_void;
    let sc_ptr = *sc.device_ptr() as *const core::ffi::c_void;

    let neox = if is_neox { 1 } else { 0 };

    // NOTE(review): the `as` casts below truncate silently if a dimension or
    // stride exceeds the range of the C type; tensor sizes are presumably far
    // below that — confirm.
    //
    // SAFETY: every pointer comes from a CUDA slice whose storage guard is still
    // held in this scope, all four tensors were checked to share the dtype
    // encoded in `internal_type`, and the ranks/shapes passed alongside were
    // validated above. Soundness additionally relies on the kernel honouring
    // those sizes, which cannot be verified from this side of the FFI boundary.
    unsafe {
        ffi::rotary_embedding(
            q_ptr,
            k_ptr,
            cc_ptr,
            sc_ptr,
            neox,
            head_size as c_int,
            num_tokens as c_long,
            rot_dim as c_int,
            num_heads as c_int,
            num_kv_heads as c_int,
            query_stride as c_long,
            key_stride as c_long,
            internal_type,
        )
    }
    Ok(())
}