fn compute_theta()

in crates/ratchet-core/src/cpu/rope.rs [23:61]


/// Computes the rotary-embedding angle table for `seq_len` consecutive
/// positions starting at `offset`.
///
/// Two operands are built and handed to `gemm`:
/// - a `(seq_len, 1)` column of positions `offset, offset + 1, ..., offset + seq_len - 1`;
/// - a `(1, dim / 2)` row of inverse frequencies `exp(-i * ln(base) / (dim / 2))`,
///   i.e. `base^(-i / (dim / 2))`, for `i in 0..dim / 2`.
///
/// Assuming `gemm` multiplies these as an ordinary matrix product (TODO confirm
/// against its definition), entry `[p][i]` of the result is
/// `(offset + p) * base^(-i / (dim / 2))`. The output uses the strides that
/// `Strides::from` derives for a `(seq_len, dim / 2)` shape.
///
/// `dim` is halved with integer division, so an odd `dim` contributes only
/// `dim / 2` frequencies.
///
/// # Errors
///
/// Propagates any error returned by `gemm`.
fn compute_theta(
    dim: usize,
    seq_len: usize,
    base: f32,
    offset: usize,
) -> Result<Vec<f32>, OperationError> {
    let half_dim = dim / 2;
    let log_base = base.ln();

    // Inverse frequencies, evaluated in log space: exp(-i * ln(base) / half_dim).
    let freqs: Vec<f32> = (0..half_dim)
        .map(|i| (-(i as f32) * log_base / half_dim as f32).exp())
        .collect();

    // Absolute positions of the rows, shifted by `offset`.
    let pos: Vec<f32> = (offset..offset + seq_len).map(|p| p as f32).collect();

    let pos_shape = shape!(seq_len, 1);
    let freq_shape = shape!(1, half_dim);
    let out_shape = shape!(seq_len, half_dim);

    let pos_strides = Strides::from(&pos_shape);
    let freq_strides = Strides::from(&freq_shape);
    let out_strides = Strides::from(&out_shape);

    // NOTE(review): the trailing scalars look like (batch, m, n, k) =
    // (1, seq_len, half_dim, 1) for an outer product — confirm against gemm's signature.
    let table = gemm(
        &pos,
        &pos_shape,
        &pos_strides,
        &freqs,
        &freq_shape,
        &freq_strides,
        &out_strides,
        1,
        seq_len,
        half_dim,
        1,
    )?;

    Ok(table)
}