in crates/ratchet-models/src/phi2/attn.rs [53:89]
/// Builds the attention module from named tensors and header metadata.
///
/// `lt` is called once per tensor name, in this order: `attn_q`, `attn_k`,
/// `attn_v`, `attn_output` (weight first, then bias for each). The head counts
/// are read from `header.metadata` under `phi2.attention.head_count` and
/// `phi2.attention.head_count_kv`.
///
/// # Errors
///
/// Returns an error if `lt` fails for any tensor, if either head-count key is
/// absent from the header metadata, or if a head-count value cannot be
/// converted with `to_u32`.
fn load_inner<F>(header: &Header, mut lt: F, device: &Device) -> anyhow::Result<Self>
where
    F: FnMut(&str) -> anyhow::Result<Tensor>,
{
    let q = Linear::new(lt("attn_q.weight")?, Some(lt("attn_q.bias")?));
    let k = Linear::new(lt("attn_k.weight")?, Some(lt("attn_k.bias")?));
    let v = Linear::new(lt("attn_v.weight")?, Some(lt("attn_v.bias")?));
    let o = Linear::new(lt("attn_output.weight")?, Some(lt("attn_output.bias")?));

    // A missing key is a property of the input file, not a programming error,
    // so report it through the Result instead of panicking.
    let read_u32 = |key: &str| -> anyhow::Result<u32> {
        let value = header
            .metadata
            .get(key)
            .ok_or_else(|| anyhow::anyhow!("missing header metadata key `{}`", key))?;
        Ok(value.to_u32()?)
    };
    let n_heads = read_u32("phi2.attention.head_count")?;
    let n_kv_heads = read_u32("phi2.attention.head_count_kv")?;

    //TODO: hardcoded for Phi2, should read from meta
    // NOTE(review): 80 equals 2560 / 32 used for `dim` below; presumably the
    // per-head dimension (hidden size / head count) — confirm.
    let scale_val = 1.0 / 80_f32.sqrt();
    let softmax_scale = Tensor::from_data([scale_val], shape![1], device.clone());

    let base = 10000.0;
    // NOTE(review): 0.4 looks like a partial-rotary fraction of the per-head
    // dimension — confirm against the Phi-2 config. Evaluates to 32.
    let dim = (0.4 * (2560f64 / 32f64)) as usize;
    let rope = RotaryEmbedding::new(dim, false, base, 1.0);

    Ok(Self {
        q,
        k,
        v,
        o,
        rope,
        n_heads,
        softmax_scale,
        n_kv_heads,
    })
}