fn load_inner()

in crates/ratchet-models/src/moondream/model.rs [49:198]

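Builds the complete model from a tensor-loading closure: the text decoder plus the vision encoder (projection + transformer). `lt` resolves each weight by its checkpoint name; the `Header` argument is currently unused (note the `_header` binding).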

    fn load_inner<F>(_header: &Header, mut lt: F, device: &Device) -> anyhow::Result<Self>
    where
        F: FnMut(&str) -> Tensor,
    {
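        // Text-decoder hyperparameters (Phi-style tensor naming): 24 layers,
        // model dim 2048, 32 heads (head_dim 64; no grouped-query attention,
        // since n_kv_heads == n_heads), with rotary embeddings applied to the
        // first 32 dims of each head.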
        let n_layers = 24_i32;
        let dim = 2048_f32;
        let n_heads = 32_u32;
        let n_kv_heads = 32_u32;
        let rope_base = 10000.0f32;
        let rope_dim = 32_u32;
        let ln_eps = 1e-05;
        let hdim = dim / n_heads as f32;
        let softmax_scale = Tensor::from_data([1.0 / hdim.sqrt()], shape![1], device.clone());
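        // KV cache layout: [batch=1, n_heads=32, max_seq_len=4096, head_dim=64].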
        let cache_shape = shape![1, 32, 4096, 64];

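        // Allocate one cache per layer in the device's compute precision.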
        let kv_cache = match device.compute_precision() {
            DType::F16 => KVCache::new::<f16>(n_layers as _, cache_shape, device),
            DType::F32 => KVCache::new::<f32>(n_layers as _, cache_shape, device),
            _ => unimplemented!(),
        };

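        // Assemble the text decoder: token embedding, 24 decoder layers (each
        // with a LayerNorm, fused-QKV self-attention, and a two-layer MLP),
        // then the LM head's LayerNorm and output projection.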
        let text_model = TextModel::new(
            Embedding::new(lt("text_model.transformer.embd.wte.weight")),
            (0..n_layers)
                .map(|i| {
                    DecoderLayer::new(
                        LayerNorm::new(
                            lt(&format!("text_model.transformer.h.{}.ln.weight", i)),
                            Some(lt(&format!("text_model.transformer.h.{}.ln.bias", i))),
                            ln_eps,
                        ),
                        SelfAttention::new(
                            Linear::new(
                                lt(&format!("text_model.transformer.h.{}.mixer.Wqkv.weight", i)),
                                Some(lt(&format!(
                                    "text_model.transformer.h.{}.mixer.Wqkv.bias",
                                    i
                                ))),
                            ),
                            Linear::new(
                                lt(&format!(
                                    "text_model.transformer.h.{}.mixer.out_proj.weight",
                                    i
                                )),
                                Some(lt(&format!(
                                    "text_model.transformer.h.{}.mixer.out_proj.bias",
                                    i
                                ))),
                            ),
                            RotaryEmbedding::new(rope_dim as usize, false, rope_base, 1.0),
                            n_heads,
                            softmax_scale.clone(),
                            n_kv_heads,
                        ),
                        MLP::new(
                            Linear::new(
                                lt(&format!("text_model.transformer.h.{}.mlp.fc1.weight", i)),
                                Some(lt(&format!("text_model.transformer.h.{}.mlp.fc1.bias", i))),
                            ),
                            Linear::new(
                                lt(&format!("text_model.transformer.h.{}.mlp.fc2.weight", i)),
                                Some(lt(&format!("text_model.transformer.h.{}.mlp.fc2.bias", i))),
                            ),
                        ),
                    )
                })
                .collect(),
            LayerNorm::new(
                lt("text_model.lm_head.ln.weight"),
                Some(lt("text_model.lm_head.ln.bias")),
                ln_eps,
            ),
            Linear::new(
                lt("text_model.lm_head.linear.weight"),
                Some(lt("text_model.lm_head.linear.bias")),
            ),
            kv_cache,
            device.clone(),
        );

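        // Vision projection: a two-layer MLP mapping vision features into the
        // text model's embedding space.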
        let projection = VisionProjection::new(MLP::new(
            Linear::new(
                lt("vision_encoder.projection.mlp.fc1.weight"),
                Some(lt("vision_encoder.projection.mlp.fc1.bias")),
            ),
            Linear::new(
                lt("vision_encoder.projection.mlp.fc2.weight"),
                Some(lt("vision_encoder.projection.mlp.fc2.bias")),
            ),
        ));

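        // Vision transformer: linear patch embedding, learned position
        // embeddings, 27 ViT blocks (width 1152, 16 heads; SigLIP-style
        // dimensions), and a final LayerNorm.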
        let transformer = VisionTransformer::new(
            LinearPatchEmbedding::new(
                Linear::new(
                    lt("vision_encoder.encoder.model.visual.patch_embed.linear.weight"),
                    Some(lt("vision_encoder.encoder.model.visual.patch_embed.linear.bias")),
                ),
            ),
            lt("vision_encoder.encoder.model.visual.pos_embed"),
            (0..27)
                .map(|layer| {
                    let qkvw = lt(&format!("vision_encoder.encoder.model.visual.blocks.{}.attn.qkv.weight", layer));
                    let qkvb = lt(&format!("vision_encoder.encoder.model.visual.blocks.{}.attn.qkv.bias", layer));

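                    // Vision attention scale: dim 1152 over 16 heads gives
                    // head_dim 72, so scores are scaled by 1 / sqrt(72).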
                    let n_heads = 16;
                    let dim = 1152;
                    let h_dim = dim / n_heads;
                    let scale_factor =
                        Tensor::from_data([1.0 / (h_dim as f32).sqrt()], shape![1], device.clone());

                    VitBlock::new(
                        1152,
                        Attention::new(
                            n_heads,
                            dim,
                            Linear::new(qkvw, Some(qkvb)),
                            Linear::new(
                                lt(&format!("vision_encoder.encoder.model.visual.blocks.{}.attn.proj.weight", layer)),
                                Some(lt(&format!("vision_encoder.encoder.model.visual.blocks.{}.attn.proj.bias", layer))),
                            ),
                            scale_factor,
                        ),
                        MLP::new(
                            Linear::new(
                                lt(&format!("vision_encoder.encoder.model.visual.blocks.{}.mlp.fc1.weight", layer)),
                                Some(lt(&format!("vision_encoder.encoder.model.visual.blocks.{}.mlp.fc1.bias", layer))),
                            ),
                            Linear::new(
                                lt(&format!("vision_encoder.encoder.model.visual.blocks.{}.mlp.fc2.weight", layer)),
                                Some(lt(&format!("vision_encoder.encoder.model.visual.blocks.{}.mlp.fc2.bias", layer))),
                            ),
                        ),
                        LayerNorm::new(
                            lt(&format!("vision_encoder.encoder.model.visual.blocks.{}.norm1.weight", layer)),
                            Some(lt(&format!("vision_encoder.encoder.model.visual.blocks.{}.norm1.bias", layer))),
                            ln_eps,
                        ),
                        LayerNorm::new(
                            lt(&format!("vision_encoder.encoder.model.visual.blocks.{}.norm2.weight", layer)),
                            Some(lt(&format!("vision_encoder.encoder.model.visual.blocks.{}.norm2.bias", layer))),
                            ln_eps,
                        ),
                    )
                }).collect::<Vec<_>>(),
            LayerNorm::new(
                lt("vision_encoder.encoder.model.visual.norm.weight"),
                Some(lt("vision_encoder.encoder.model.visual.norm.bias")),
                ln_eps,
            ),
        );

        let vision_encoder = VisionEncoder::new(projection, transformer);
        Ok(Self {
            vision_encoder,
            text_model,
        })
    }
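
For context, `load_inner` never reads the `Header`; every weight is resolved by name through the `lt` closure. Below is a minimal sketch of how a caller in the same module might back that closure with a preloaded name-to-tensor map. The `load_from_map` helper, the `tensors` map, and the `Moondream` type name are assumptions for illustration, not ratchet's actual loading path:

    use std::collections::HashMap;

    /// Hypothetical helper: resolve tensors from a preloaded map and delegate
    /// to `load_inner`. The names here are assumptions, not ratchet's API.
    fn load_from_map(
        header: &Header,
        tensors: HashMap<String, Tensor>,
        device: &Device,
    ) -> anyhow::Result<Moondream> {
        // The closure satisfies `FnMut(&str) -> Tensor`; clone is assumed to be
        // cheap (i.e. the Tensor handle is reference-counted).
        let lt = |name: &str| -> Tensor {
            tensors
                .get(name)
                .unwrap_or_else(|| panic!("missing tensor: {name}"))
                .clone()
        };
        Moondream::load_inner(header, lt, device)
    }

Because the closure only borrows `tensors`, the map can be dropped as soon as `load_inner` returns.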