in crates/ratchet-core/src/cpu/rope.rs [63:143]
/// Applies rotary position embedding (RoPE) to an `f32` tensor on the CPU.
///
/// The last axis of `shape` is split into three parts:
/// - `x1 = [0, dim/2)` and `x2 = [dim/2, dim)` are rotated pairwise, producing
///   `x1 * cos - x2 * sin` in the first half and `x1 * sin + x2 * cos` in the
///   second half (the "rotate-half" layout, not the interleaved one);
/// - `[dim, head_dim)` is copied through unchanged.
///
/// # Arguments
/// - `src`: input data laid out according to `shape`; it is indexed via
///   `Strides::from(shape)` (presumably contiguous row-major — confirm).
/// - `shape`: must be rank 4, destructured as `[batches, num_heads, seq_len, head_dim]`.
/// - `dim`: number of leading features on the last axis that are rotated.
/// - `base`, `offset`: forwarded to `compute_theta` to build the rotation angles.
///
/// # Errors
/// Propagates any error returned by `compute_theta` or `concat`.
///
/// # Panics
/// Panics if `shape` is not rank 4, or if `dim` is odd or larger than `head_dim`
/// (either would otherwise silently misalign the two rotated halves).
fn rope(
    src: Vec<f32>,
    shape: &Shape,
    dim: usize,
    base: f32,
    offset: usize,
) -> Result<Vec<f32>, OperationError> {
    let [batches, num_heads, seq_len, head_dim] = shape
        .try_into()
        .expect("rope expects a rank-4 shape [batches, num_heads, seq_len, head_dim]");
    assert!(
        dim % 2 == 0 && dim <= head_dim,
        "rope: rotary dim ({dim}) must be even and no larger than head_dim ({head_dim})"
    );
    let half_dim = dim / 2;

    let theta = compute_theta(dim, seq_len, base, offset)?;
    let (sin, cos): (Vec<f32>, Vec<f32>) = theta.iter().map(|t| t.sin_cos()).unzip();

    let src_strides = Strides::from(shape);
    let x1 = slice(
        &src,
        &src_strides,
        &[0, 0, 0, 0],
        &[batches, num_heads, seq_len, half_dim],
    );
    let x2 = slice(
        &src,
        &src_strides,
        &[0, 0, 0, half_dim],
        &[batches, num_heads, seq_len, dim],
    );

    // The (cos, sin) tables are cycled over the flattened halves, so flattened element `k`
    // uses angle `theta[k % theta.len()]`. This acts as a broadcast over batches and heads,
    // presumably because `compute_theta` yields `seq_len * half_dim` angles — TODO confirm.
    // Both rotated halves are computed in a single pass without intermediate products.
    let (r1, r2): (Vec<f32>, Vec<f32>) = x1
        .iter()
        .zip(x2.iter())
        .zip(cos.iter().zip(sin.iter()).cycle())
        .map(|((&a, &b), (&c, &s))| (a * c - b * s, a * s + b * c))
        .unzip();

    // Each part holds exactly as many elements as its declared shape, just like the
    // unpadded pass-through slice below — no zero padding is needed.
    let mut to_cat = vec![
        (shape![batches, num_heads, seq_len, half_dim], r1),
        (shape![batches, num_heads, seq_len, half_dim], r2),
    ];
    if dim < head_dim {
        // Features beyond `dim` are not rotated; copy them through unchanged.
        let passthrough = slice(
            &src,
            &src_strides,
            &[0, 0, 0, dim],
            &[batches, num_heads, seq_len, head_dim],
        );
        to_cat.push((
            shape![batches, num_heads, seq_len, head_dim - dim],
            passthrough,
        ));
    }

    let dst_shape = shape![batches, num_heads, seq_len, head_dim];
    let mut dst = vec![0.0f32; dst_shape.numel()];
    concat(to_cat.as_slice(), 3, &dst_shape, &mut dst)?;
    Ok(dst)
}