fn rope()

in crates/ratchet-core/src/cpu/rope.rs [63:143]


/// CPU reference implementation of rotary position embeddings (RoPE).
///
/// `shape` is destructured as `[batches, num_heads, seq_len, head_dim]`, and `src` is read
/// through `Strides::from(shape)` (presumably a contiguous row-major layout — TODO confirm).
///
/// Only the first `dim` features of the last axis are rotated, using the "rotate half" layout.
/// Feature `j` (for `j < dim / 2`) is paired with feature `j + dim / 2`, and each pair
/// `(a, b)` becomes `(a * cos - b * sin, a * sin + b * cos)`. Features in `dim..head_dim` are
/// copied through unchanged.
///
/// The angles come from `compute_theta(dim, seq_len, base, offset)`. Their sine/cosine tables
/// are cycled over the flattened rotated halves, so the same table is reused for every
/// `(batch, head)` pair.
///
/// NOTE(review): this assumes `compute_theta` yields `seq_len * (dim / 2)` angles laid out by
/// position then feature. It also assumes `dim` is even and `dim <= head_dim`; neither
/// precondition is checked here — confirm against callers.
///
/// # Errors
/// Propagates any error returned by `compute_theta` or `concat`.
///
/// # Panics
/// Panics if `shape` cannot be converted into four dimensions.
fn rope(
    src: Vec<f32>,
    shape: &Shape,
    dim: usize,
    base: f32,
    offset: usize,
) -> Result<Vec<f32>, OperationError> {
    let [batches, num_heads, seq_len, head_dim] = shape
        .try_into()
        .expect("rope expects a rank-4 [batches, num_heads, seq_len, head_dim] shape");

    let half_dim = dim / 2;
    let theta = compute_theta(dim, seq_len, base, offset)?;
    let (sin, cos): (Vec<f32>, Vec<f32>) = theta.iter().map(|t| t.sin_cos()).unzip();

    let src_strides = Strides::from(shape);
    // First and second halves of the rotated span of the last axis.
    let x1 = slice(
        &src,
        &src_strides,
        &[0, 0, 0, 0],
        &[batches, num_heads, seq_len, half_dim],
    );
    let x2 = slice(
        &src,
        &src_strides,
        &[0, 0, 0, half_dim],
        &[batches, num_heads, seq_len, dim],
    );

    // Rotate every (x1, x2) pair in a single pass instead of materialising four
    // intermediate products. The per-element arithmetic is the same as the
    // unfused form (two products, then one add/sub), so results are bit-identical.
    //
    // The (cos, sin) table is cycled so that it is reused for every (batch, head).
    let (r1, r2): (Vec<f32>, Vec<f32>) = x1
        .iter()
        .zip(x2.iter())
        .zip(cos.iter().zip(sin.iter()).cycle())
        .map(|((a, b), (c, s))| (a * c - b * s, a * s + b * c))
        .unzip();

    // Each buffer's length matches the shape it is paired with; no padding is needed.
    let mut to_cat = vec![
        (shape![batches, num_heads, seq_len, half_dim], r1),
        (shape![batches, num_heads, seq_len, half_dim], r2),
    ];
    if dim < head_dim {
        // Features past `dim` are not rotated; carry them through unchanged.
        let passthrough = slice(
            &src,
            &src_strides,
            &[0, 0, 0, dim],
            &[batches, num_heads, seq_len, head_dim],
        );
        to_cat.push((
            shape![batches, num_heads, seq_len, head_dim - dim],
            passthrough,
        ));
    }

    let dst_shape = shape![batches, num_heads, seq_len, head_dim];
    let mut dst = vec![0.0f32; dst_shape.numel()];
    concat(to_cat.as_slice(), 3, &dst_shape, &mut dst)?;
    Ok(dst)
}