in router/src/lib.rs [350:409]
/// Decides which backend model type (classifier or embedding with a pooling
/// strategy) should be used for the model described by `config`.
///
/// Resolution order:
/// 1. Walk `config.architectures`. An architecture ending in `MaskedLM` combined
///    with an explicit SPLADE pooling request yields a SPLADE embedding model; an
///    architecture ending in `Classification` yields a classifier (any `pooling`
///    value is ignored with a warning).
/// 2. If SPLADE was requested but no architecture matched, fail.
/// 3. Otherwise use the explicit `pooling` value when given, else the pooling
///    read from `1_Pooling/config.json` under `model_root`, else fall back to
///    `CLS` pooling when `config.model_type` contains "bert".
///
/// # Errors
///
/// Returns an error when SPLADE pooling is requested for a model with no
/// `MaskedLM` architecture, when `1_Pooling/config.json` exists but cannot be
/// parsed or converted into a pool, or when no pooling could be determined for
/// a model whose `model_type` does not contain "bert".
fn get_backend_model_type(
    config: &ModelConfig,
    model_root: &Path,
    pooling: Option<text_embeddings_backend::Pool>,
) -> Result<text_embeddings_backend::ModelType> {
    // Both facts are independent of the architecture being inspected, so they
    // are computed once up front.
    let splade_requested = pooling == Some(text_embeddings_backend::Pool::Splade);
    let label_maps_incomplete = !(config.id2label.is_some() && config.label2id.is_some());

    for arch in &config.architectures {
        // Edge case affecting `Alibaba-NLP/gte-multilingual-base` and possibly other fine-tunes
        // of the same base model: the config advertises a token-classification architecture
        // without the label mappings. More context at
        // https://huggingface.co/Alibaba-NLP/gte-multilingual-base/discussions/7
        if arch == "NewForTokenClassification" && label_maps_incomplete {
            tracing::warn!("Provided `--model-id` is likely an AlibabaNLP GTE model, but the `config.json` contains the architecture `NewForTokenClassification` but it doesn't contain the `id2label` and `label2id` mapping, so `NewForTokenClassification` architecture will be ignored.");
            continue;
        }

        // SPLADE takes precedence over the classifier check for this architecture.
        if splade_requested && arch.ends_with("MaskedLM") {
            return Ok(text_embeddings_backend::ModelType::Embedding(
                text_embeddings_backend::Pool::Splade,
            ));
        }

        if arch.ends_with("Classification") {
            if pooling.is_some() {
                tracing::warn!(
                    "`--pooling` arg is set but model is a classifier. Ignoring `--pooling` arg."
                );
            }
            return Ok(text_embeddings_backend::ModelType::Classifier);
        }
    }

    // No architecture matched: SPLADE cannot be honoured without a MaskedLM head.
    if splade_requested {
        return Err(anyhow!(
            "Splade pooling is not supported: model is not a ForMaskedLM model"
        ));
    }

    // An explicit `--pooling` value wins over anything stored alongside the model.
    if let Some(requested_pool) = pooling {
        return Ok(text_embeddings_backend::ModelType::Embedding(requested_pool));
    }

    // Otherwise look for the pooling configuration next to the model files.
    let pooling_config_path = model_root.join("1_Pooling/config.json");
    let resolved_pool = match fs::read_to_string(pooling_config_path) {
        Ok(raw) => {
            let parsed: PoolConfig =
                serde_json::from_str(&raw).context("Failed to parse `1_Pooling/config.json`")?;
            Pool::try_from(parsed)?
        }
        // Any read failure on a BERT-like model falls back to CLS pooling.
        Err(_) if config.model_type.to_lowercase().contains("bert") => {
            tracing::warn!("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model but the model is a BERT variant. Defaulting to `CLS` pooling.");
            text_embeddings_backend::Pool::Cls
        }
        Err(err) => {
            return Err(err).context("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model.");
        }
    };

    Ok(text_embeddings_backend::ModelType::Embedding(resolved_pool))
}