backends/candle/src/layers/rotary.rs (66 lines of code) (raw):
use candle::{DType, Device, Result, Tensor, D};
use serde::Deserialize;
/// Parameters for NTK-style rotary embedding scaling.
///
/// Deserialized via serde; consumed by `get_inv_freqs` when wrapped in
/// `RopeScaling::Ntk`.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct NTKScaling {
/// Scaling factor. `get_inv_freqs` multiplies the rotary base by this value
/// and divides the resulting inverse frequencies by `factor^(2 / dim)`.
pub factor: f32,
}
/// Rotary embedding scaling strategy.
///
/// Deserialized as an internally tagged enum: the `type` field selects the
/// variant, with variant names written in kebab-case (so `Ntk` is `"ntk"`).
/// The remaining fields of the same object populate the variant's payload.
#[derive(Debug, Clone, PartialEq, Deserialize)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum RopeScaling {
/// NTK scaling, parameterized by [`NTKScaling`].
Ntk(NTKScaling),
}
/// Computes the rotary inverse-frequency row vector.
///
/// For every even index `i` in `0..dim` the entry is `1 / base^(i / dim)`.
/// The result is an `f32` tensor of shape `(1, n)` on `device`, where `n` is
/// the number of even indices in `0..dim`.
///
/// When `rope_scaling` is `Some(RopeScaling::Ntk(..))`, the frequencies are
/// built with `base * factor` instead of `base`, and the whole tensor is then
/// divided by `factor^(2 / dim)`.
///
/// # Errors
///
/// Propagates any error returned by tensor construction or by the division.
pub fn get_inv_freqs(
    dim: usize,
    base: f32,
    device: &Device,
    rope_scaling: Option<&RopeScaling>,
) -> Result<Tensor> {
    // Builds the unscaled `(1, n)` frequency row for the given base.
    let build_row = |effective_base: f32| -> Result<Tensor> {
        let mut values = Vec::new();
        for i in (0..dim).step_by(2) {
            let exponent = i as f32 / dim as f32;
            values.push(1f32 / effective_base.powf(exponent));
        }
        let count = values.len();
        Tensor::from_vec(values, (1, count), device)
    };

    match rope_scaling {
        Some(RopeScaling::Ntk(ntk)) => {
            let row = build_row(base * ntk.factor)?;
            let divisor = ntk.factor.powf(2.0 / dim as f32) as f64;
            row / divisor
        }
        None => build_row(base),
    }
}
/// Builds the cosine and sine tables for positions `0..length`.
///
/// Positions are laid out as an `f32` column of shape `(length, 1)` and
/// matrix-multiplied with `inv_freqs`, giving one row of angles per position.
/// `inv_freqs` is presumably the `(1, n)` row produced by `get_inv_freqs`
/// (the matmul requires a compatible shape — verify against callers).
///
/// If `repeat_freqs` is true, the angle matrix is concatenated with itself
/// along dimension 1 before the trigonometric functions are applied.
///
/// Returns `(cos, sin)`, both converted to `dtype`.
///
/// # Errors
///
/// Propagates any error from the underlying tensor operations.
pub fn get_cos_sin(
    length: usize,
    inv_freqs: &Tensor,
    dtype: DType,
    repeat_freqs: bool,
) -> Result<(Tensor, Tensor)> {
    let positions = Tensor::arange(0u32, length as u32, inv_freqs.device())?;
    let positions = positions.to_dtype(DType::F32)?;
    let positions = positions.reshape((length, 1))?;

    let angles = positions.matmul(inv_freqs)?;
    let angles = if repeat_freqs {
        Tensor::cat(&[&angles, &angles], 1)?
    } else {
        angles
    };

    Ok((
        angles.cos()?.to_dtype(dtype)?,
        angles.sin()?.to_dtype(dtype)?,
    ))
}
/// Applies a rotary embedding to `x` using precomputed `cos` and `sin` tables.
///
/// The last dimension of `x` is split at `attention_head_size / 2` into a
/// first half `a` and a second half `b`, each `attention_head_size / 2` wide.
/// The result is `x * cos + [-b, a] * sin`, where `[-b, a]` is the
/// concatenation along the last dimension and both products broadcast.
///
/// # Errors
///
/// Propagates any error from the underlying tensor operations, for example
/// shape mismatches during narrowing or broadcasting.
pub fn apply_rotary(
    x: &Tensor,
    cos: &Tensor,
    sin: &Tensor,
    attention_head_size: usize,
) -> Result<Tensor> {
    let half = attention_head_size / 2;
    let first_half = x.narrow(D::Minus1, 0, half)?;
    let second_half = x.narrow(D::Minus1, half, half)?;

    // Rotated counterpart of `x`: (a, b) -> (-b, a).
    let negated_second = second_half.neg()?;
    let rotated = Tensor::cat(&[&negated_second, &first_half], D::Minus1)?;

    let cos_term = x.broadcast_mul(cos)?;
    let sin_term = rotated.broadcast_mul(sin)?;
    cos_term + sin_term
}