fn load_inner()

in crates/ratchet-models/src/phi2/attn.rs [53:89]


    /// Builds the attention block from tensors supplied by `lt` and head counts
    /// read from the header metadata.
    ///
    /// `lt` is called with each tensor name (`attn_q`, `attn_k`, `attn_v`,
    /// `attn_output`, each with a `.weight` and a `.bias` suffix) and must return
    /// the corresponding tensor. `device` is where the softmax scale tensor is
    /// created.
    ///
    /// # Errors
    ///
    /// Returns an error if `lt` fails for any tensor, if either
    /// `phi2.attention.head_count` or `phi2.attention.head_count_kv` is absent
    /// from `header.metadata`, or if either value cannot be converted with
    /// `to_u32`.
    fn load_inner<F>(header: &Header, mut lt: F, device: &Device) -> anyhow::Result<Self>
    where
        F: FnMut(&str) -> anyhow::Result<Tensor>,
    {
        // Weight is requested before bias for each projection, in q, k, v, o order.
        let q = Linear::new(lt("attn_q.weight")?, Some(lt("attn_q.bias")?));
        let k = Linear::new(lt("attn_k.weight")?, Some(lt("attn_k.bias")?));
        let v = Linear::new(lt("attn_v.weight")?, Some(lt("attn_v.bias")?));
        let o = Linear::new(lt("attn_output.weight")?, Some(lt("attn_output.bias")?));

        // A missing metadata key is a malformed-input condition, not a bug in this
        // code, so report it through the Result instead of panicking.
        let n_heads = header
            .metadata
            .get("phi2.attention.head_count")
            .ok_or_else(|| anyhow::anyhow!("missing metadata key `phi2.attention.head_count`"))?
            .to_u32()?;
        let n_kv_heads = header
            .metadata
            .get("phi2.attention.head_count_kv")
            .ok_or_else(|| anyhow::anyhow!("missing metadata key `phi2.attention.head_count_kv`"))?
            .to_u32()?;

        // NOTE(review): 80 equals the 2560 / 32 quotient used for the rotary
        // dimension below, presumably the per-head dimension. Confirm before
        // reusing this loader for another model.
        let scale_val = 1.0 / 80_f32.sqrt();
        let softmax_scale = Tensor::from_data([scale_val], shape![1], device.clone());
        //TODO: hardcoded for Phi2, should read from meta
        let base = 10000.0;
        // 0.4 * (2560 / 32) evaluates to 32 after the cast.
        let dim = (0.4 * (2560f64 / 32f64)) as usize;
        let rope = RotaryEmbedding::new(dim, false, base, 1.0);
        Ok(Self {
            q,
            k,
            v,
            o,
            rope,
            n_heads,
            softmax_scale,
            n_kv_heads,
        })
    }