in crates/ratchet-models/src/moondream/model.rs [49:198]
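    /// Builds the Moondream module graph from checkpoint tensors. `lt` ("load
    /// tensor") resolves a tensor by its checkpoint name; the hyperparameters
    /// below are hardcoded to match the shipped Moondream weights rather than
    /// read from `_header`.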
    fn load_inner<F>(_header: &Header, mut lt: F, device: &Device) -> anyhow::Result<Self>
    where
        F: FnMut(&str) -> Tensor,
    {
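        // Text decoder config, matching the Phi-1.5-style architecture Moondream
        // uses for its text model: 24 layers, hidden dim 2048, 32 heads with
        // n_kv_heads == n_heads (plain multi-head attention), and partial RoPE
        // over the first 32 dims of each 64-dim head.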
        let n_layers = 24_i32;
        let dim = 2048_f32;
        let n_heads = 32_u32;
        let n_kv_heads = 32_u32;
        let rope_base = 10000.0f32;
        let rope_dim = 32_u32;
        let ln_eps = 1e-05;
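        // Attention scale 1 / sqrt(head_dim) = 1 / sqrt(2048 / 32) = 0.125,
        // materialized once as a tensor and cloned into every layer.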
        let hdim = dim / n_heads as f32;
        let softmax_scale = Tensor::from_data([1.0 / hdim.sqrt()], shape![1], device.clone());
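        // Per-layer KV cache laid out as [batch=1, n_heads=32, max_seq=4096,
        // head_dim=64], allocated in the device's compute precision.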
        let cache_shape = shape![1, 32, 4096, 64];
        let kv_cache = match device.compute_precision() {
            DType::F16 => KVCache::new::<f16>(n_layers as _, cache_shape, device),
            DType::F32 => KVCache::new::<f32>(n_layers as _, cache_shape, device),
            _ => unimplemented!(),
        };
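        // Text model: token embedding, 24 decoder layers, then an LM head. Each
        // decoder layer is built from one LayerNorm (Phi-style shared norm), a
        // fused-QKV self-attention, and a two-layer MLP, with weights looked up
        // by their `text_model.transformer.h.{i}.*` checkpoint names.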
        let text_model = TextModel::new(
            Embedding::new(lt("text_model.transformer.embd.wte.weight")),
            (0..n_layers)
                .map(|i| {
                    DecoderLayer::new(
                        LayerNorm::new(
                            lt(&format!("text_model.transformer.h.{}.ln.weight", i)),
                            Some(lt(&format!("text_model.transformer.h.{}.ln.bias", i))),
                            ln_eps,
                        ),
                        SelfAttention::new(
                            Linear::new(
                                lt(&format!("text_model.transformer.h.{}.mixer.Wqkv.weight", i)),
                                Some(lt(&format!("text_model.transformer.h.{}.mixer.Wqkv.bias", i))),
                            ),
                            Linear::new(
                                lt(&format!("text_model.transformer.h.{}.mixer.out_proj.weight", i)),
                                Some(lt(&format!("text_model.transformer.h.{}.mixer.out_proj.bias", i))),
                            ),
                            RotaryEmbedding::new(rope_dim as usize, false, rope_base, 1.0),
                            n_heads,
                            softmax_scale.clone(),
                            n_kv_heads,
                        ),
                        MLP::new(
                            Linear::new(
                                lt(&format!("text_model.transformer.h.{}.mlp.fc1.weight", i)),
                                Some(lt(&format!("text_model.transformer.h.{}.mlp.fc1.bias", i))),
                            ),
                            Linear::new(
                                lt(&format!("text_model.transformer.h.{}.mlp.fc2.weight", i)),
                                Some(lt(&format!("text_model.transformer.h.{}.mlp.fc2.bias", i))),
                            ),
                        ),
                    )
                })
                .collect(),
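            // LM head: its own LayerNorm followed by a linear projection to
            // vocabulary logits.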
            LayerNorm::new(
                lt("text_model.lm_head.ln.weight"),
                Some(lt("text_model.lm_head.ln.bias")),
                ln_eps,
            ),
            Linear::new(
                lt("text_model.lm_head.linear.weight"),
                Some(lt("text_model.lm_head.linear.bias")),
            ),
            kv_cache,
            device.clone(),
        );
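        // Vision projection: a two-layer MLP that maps ViT features into the
        // text model's embedding space.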
        let projection = VisionProjection::new(MLP::new(
            Linear::new(
                lt("vision_encoder.projection.mlp.fc1.weight"),
                Some(lt("vision_encoder.projection.mlp.fc1.bias")),
            ),
            Linear::new(
                lt("vision_encoder.projection.mlp.fc2.weight"),
                Some(lt("vision_encoder.projection.mlp.fc2.bias")),
            ),
        ));
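        // Vision tower: a linear patch embedding followed by 27 transformer
        // blocks of width 1152 with 16 heads, the SigLIP (so400m-class) ViT
        // configuration Moondream builds on.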
        let transformer = VisionTransformer::new(
            LinearPatchEmbedding::new(Linear::new(
                lt("vision_encoder.encoder.model.visual.patch_embed.linear.weight"),
                Some(lt("vision_encoder.encoder.model.visual.patch_embed.linear.bias")),
            )),
            lt("vision_encoder.encoder.model.visual.pos_embed"),
            (0..27)
                .map(|layer| {
                    let qkvw = lt(&format!(
                        "vision_encoder.encoder.model.visual.blocks.{}.attn.qkv.weight",
                        layer
                    ));
                    let qkvb = lt(&format!(
                        "vision_encoder.encoder.model.visual.blocks.{}.attn.qkv.bias",
                        layer
                    ));
                    let vit_heads = 16;
                    let vit_dim = 1152;
                    let vit_hdim = vit_dim / vit_heads;
                    let scale_factor = Tensor::from_data(
                        [1.0 / (vit_hdim as f32).sqrt()],
                        shape![1],
                        device.clone(),
                    );
                    VitBlock::new(
                        vit_dim,
                        Attention::new(
                            vit_heads,
                            vit_dim,
                            Linear::new(qkvw, Some(qkvb)),
                            Linear::new(
                                lt(&format!(
                                    "vision_encoder.encoder.model.visual.blocks.{}.attn.proj.weight",
                                    layer
                                )),
                                Some(lt(&format!(
                                    "vision_encoder.encoder.model.visual.blocks.{}.attn.proj.bias",
                                    layer
                                ))),
                            ),
                            scale_factor,
                        ),
                        MLP::new(
                            Linear::new(
                                lt(&format!(
                                    "vision_encoder.encoder.model.visual.blocks.{}.mlp.fc1.weight",
                                    layer
                                )),
                                Some(lt(&format!(
                                    "vision_encoder.encoder.model.visual.blocks.{}.mlp.fc1.bias",
                                    layer
                                ))),
                            ),
                            Linear::new(
                                lt(&format!(
                                    "vision_encoder.encoder.model.visual.blocks.{}.mlp.fc2.weight",
                                    layer
                                )),
                                Some(lt(&format!(
                                    "vision_encoder.encoder.model.visual.blocks.{}.mlp.fc2.bias",
                                    layer
                                ))),
                            ),
                        ),
                        LayerNorm::new(
                            lt(&format!(
                                "vision_encoder.encoder.model.visual.blocks.{}.norm1.weight",
                                layer
                            )),
                            Some(lt(&format!(
                                "vision_encoder.encoder.model.visual.blocks.{}.norm1.bias",
                                layer
                            ))),
                            ln_eps,
                        ),
                        LayerNorm::new(
                            lt(&format!(
                                "vision_encoder.encoder.model.visual.blocks.{}.norm2.weight",
                                layer
                            )),
                            Some(lt(&format!(
                                "vision_encoder.encoder.model.visual.blocks.{}.norm2.bias",
                                layer
                            ))),
                            ln_eps,
                        ),
                    )
                })
                .collect::<Vec<_>>(),
            LayerNorm::new(
                lt("vision_encoder.encoder.model.visual.norm.weight"),
                Some(lt("vision_encoder.encoder.model.visual.norm.bias")),
                ln_eps,
            ),
        );
        let vision_encoder = VisionEncoder::new(projection, transformer);
        Ok(Self {
            vision_encoder,
            text_model,
        })
    }
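    // A minimal sketch of how `load_inner` might be driven (hypothetical names,
    // assuming a name -> Tensor map produced by the crate's checkpoint reader):
    //
    //     let mut tensors: HashMap<String, Tensor> = read_tensors(&path, &device)?;
    //     let model = Self::load_inner(&header, |name| tensors.remove(name).unwrap(), &device)?;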