fn load_inner()

in crates/ratchet-models/src/whisper/residual_block.rs [89:160]
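
load_inner builds one residual attention block from tensors fetched through the caller-supplied lt closure: a self-attention sub-block with its layer norm, an optional cross-attention sub-block that only decoder layers carry, and a final layer norm followed by the MLP.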


    fn load_inner<F>(mut lt: F, prefix: &str, n_heads: usize) -> anyhow::Result<Self>
    where
        F: FnMut(&str) -> anyhow::Result<Tensor>,
    {
        let attn_ln = LayerNorm::new(
            lt("self_attn_layer_norm.weight")?,
            Some(lt("self_attn_layer_norm.bias")?),
            1e-5,
        );
        // `lt` resolves a per-layer suffix such as "self_attn.v_proj.weight" to the full
        // checkpoint key, e.g. "model.encoder.layers.0.self_attn.v_proj.weight".
        let attn = MultiHeadAttention::new(
            Linear::new(
                lt("self_attn.q_proj.weight")?,
                Some(lt("self_attn.q_proj.bias")?),
            ),
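            // The key projection has no bias in Whisper, so no bias tensor is loaded here.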
            Linear::new(lt("self_attn.k_proj.weight")?, None),
            Linear::new(
                lt("self_attn.v_proj.weight")?,
                Some(lt("self_attn.v_proj.bias")?),
            ),
            Linear::new(
                lt("self_attn.out_proj.weight")?,
                Some(lt("self_attn.out_proj.bias")?),
            ),
            n_heads,
        );

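        // Cross-attention over the encoder output exists only in decoder blocks.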
        let (x_attn_ln, x_attn) = if prefix == "decoder" {
            let x_attn_ln = LayerNorm::new(
                lt("encoder_attn_layer_norm.weight")?,
                Some(lt("encoder_attn_layer_norm.bias")?),
                1e-5,
            );
            let x_attn = MultiHeadAttention::new(
                Linear::new(
                    lt("encoder_attn.q_proj.weight")?,
                    Some(lt("encoder_attn.q_proj.bias")?),
                ),
                Linear::new(lt("encoder_attn.k_proj.weight")?, None),
                Linear::new(
                    lt("encoder_attn.v_proj.weight")?,
                    Some(lt("encoder_attn.v_proj.bias")?),
                ),
                Linear::new(
                    lt("encoder_attn.out_proj.weight")?,
                    Some(lt("encoder_attn.out_proj.bias")?),
                ),
                n_heads,
            );
            (Some(x_attn_ln), Some(x_attn))
        } else {
            (None, None)
        };

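        // Feed-forward sub-block: final layer norm followed by a two-layer MLP.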
        let mlp_ln = LayerNorm::new(
            lt("final_layer_norm.weight")?,
            Some(lt("final_layer_norm.bias")?),
            1e-5,
        );
        let mlp = MLP::new(
            Linear::new(lt("fc1.weight")?, Some(lt("fc1.bias")?)),
            Linear::new(lt("fc2.weight")?, Some(lt("fc2.bias")?)),
        );
        Ok(Self {
            attn_ln,
            attn,
            x_attn_ln,
            x_attn,
            mlp_ln,
            mlp,
        })
    }
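
The caller supplies lt as an FnMut closure that resolves these suffixes for one concrete layer. The sketch below is illustrative only, not code from the ratchet repository: it assumes the surrounding struct is named ResidualAttentionBlock, that load_inner is visible from the calling module, and that the checkpoint has already been read into a HashMap keyed by full names such as "model.decoder.layers.0.self_attn.q_proj.weight".

use std::collections::HashMap;

// Hypothetical helper (not part of ratchet): load the layer_idx-th decoder block
// by resolving each suffix to a full checkpoint key and handing the tensor over.
fn load_decoder_block(
    mut tensors: HashMap<String, Tensor>,
    layer_idx: usize,
    n_heads: usize,
) -> anyhow::Result<ResidualAttentionBlock> {
    let prefix = "decoder";
    let lt = |name: &str| -> anyhow::Result<Tensor> {
        // e.g. "model.decoder.layers.0.self_attn.q_proj.weight"
        let key = format!("model.{}.layers.{}.{}", prefix, layer_idx, name);
        tensors
            .remove(&key)
            .ok_or_else(|| anyhow::anyhow!("missing tensor: {}", key))
    };
    ResidualAttentionBlock::load_inner(lt, prefix, n_heads)
}

Removing each tensor from the map (rather than cloning it) keeps the closure an FnMut, avoids copies, and makes a missing or doubly-requested key fail loudly with the offending name.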