in backends/candle/src/models/gte.rs [442:475]
fn inner_load(
vb: VarBuilder,
config: >EConfig,
) -> Result<(Embedding, Option<Embedding>, GTEEncoder, LayerNorm)> {
let word_embeddings = Embedding::new(
vb.pp("embeddings.word_embeddings")
.get((config.vocab_size, config.hidden_size), "weight")?,
config.hidden_size,
);
let token_type_embeddings = if config.type_vocab_size > 0 {
Some(Embedding::new(
vb.pp("embeddings.token_type_embeddings")
.get((config.type_vocab_size, config.hidden_size), "weight")?,
config.hidden_size,
))
} else {
None
};
let encoder = GTEEncoder::load(vb.pp("encoder"), config)?;
let embeddings_norm = LayerNorm::load(
vb.pp("embeddings.LayerNorm"),
config.hidden_size,
config.layer_norm_eps,
)?;
Ok((
word_embeddings,
token_type_embeddings,
encoder,
embeddings_norm,
))
}