in backends/candle/src/models/flash_gte.rs [228:263]
fn inner_load(
vb: VarBuilder,
config: >EConfig,
) -> Result<(Embedding, Option<Embedding>, Vec<GTELayer>, LayerNorm)> {
let word_embeddings = Embedding::new(
vb.pp("embeddings.word_embeddings")
.get((config.vocab_size, config.hidden_size), "weight")?,
config.hidden_size,
);
let token_type_embeddings = if config.type_vocab_size > 0 {
Some(Embedding::new(
vb.pp("embeddings.token_type_embeddings")
.get((config.type_vocab_size, config.hidden_size), "weight")?,
config.hidden_size,
))
} else {
None
};
let layers = (0..config.num_hidden_layers)
.map(|index| GTELayer::load(vb.pp(format!("encoder.layer.{index}")), config))
.collect::<Result<Vec<_>>>()?;
let embeddings_norm = LayerNorm::load(
vb.pp("embeddings.LayerNorm"),
config.hidden_size,
config.layer_norm_eps,
)?;
Ok((
word_embeddings,
token_type_embeddings,
layers,
embeddings_norm,
))
}