fn get_backend_model_type()

in router/src/lib.rs [350:409]


/// Decides which backend model type to load: a classifier, or an embedding model
/// together with its pooling strategy.
///
/// Resolution order:
/// 1. Each entry of `config.architectures` is inspected in order. An architecture
///    ending in `MaskedLM` selects Splade embeddings when `pooling` is Splade; an
///    architecture ending in `Classification` selects the classifier (any
///    `pooling` value is ignored with a warning).
/// 2. Otherwise the model is an embedding model. `pooling` is used when given;
///    if not, `1_Pooling/config.json` under `model_root` is read; if that file
///    cannot be read and `config.model_type` contains "bert", CLS pooling is used.
///
/// # Errors
///
/// Returns an error when:
/// - Splade pooling was requested but no architecture ends in `MaskedLM`;
/// - `1_Pooling/config.json` exists but cannot be parsed or converted to a pool;
/// - `pooling` is unset, the pooling config cannot be read, and the model type
///   does not contain "bert".
fn get_backend_model_type(
    config: &ModelConfig,
    model_root: &Path,
    pooling: Option<text_embeddings_backend::Pool>,
) -> Result<text_embeddings_backend::ModelType> {
    let splade_requested = pooling == Some(text_embeddings_backend::Pool::Splade);

    for architecture in &config.architectures {
        // Edge case affecting `Alibaba-NLP/gte-multilingual-base` and possibly other
        // fine-tunes of the same base model: the config advertises a token
        // classification head without shipping the label mappings.
        // More context at https://huggingface.co/Alibaba-NLP/gte-multilingual-base/discussions/7
        let missing_label_mapping = config.id2label.is_none() || config.label2id.is_none();
        if architecture == "NewForTokenClassification" && missing_label_mapping {
            tracing::warn!("Provided `--model-id` is likely an AlibabaNLP GTE model, but the `config.json` contains the architecture `NewForTokenClassification` but it doesn't contain the `id2label` and `label2id` mapping, so `NewForTokenClassification` architecture will be ignored.");
            continue;
        }

        // Splade takes precedence over the classifier check for this architecture.
        if splade_requested && architecture.ends_with("MaskedLM") {
            return Ok(text_embeddings_backend::ModelType::Embedding(
                text_embeddings_backend::Pool::Splade,
            ));
        }

        if architecture.ends_with("Classification") {
            if pooling.is_some() {
                tracing::warn!(
                    "`--pooling` arg is set but model is a classifier. Ignoring `--pooling` arg."
                );
            }
            return Ok(text_embeddings_backend::ModelType::Classifier);
        }
    }

    // Reaching this point with Splade requested means no `MaskedLM` architecture matched.
    if splade_requested {
        return Err(anyhow!(
            "Splade pooling is not supported: model is not a ForMaskedLM model"
        ));
    }

    // An explicit `--pooling` value wins over any on-disk pooling configuration.
    if let Some(pool) = pooling {
        return Ok(text_embeddings_backend::ModelType::Embedding(pool));
    }

    // No explicit pooling: fall back to the pooling config shipped with the model.
    let pool = match fs::read_to_string(model_root.join("1_Pooling/config.json")) {
        Ok(raw) => {
            let parsed: PoolConfig =
                serde_json::from_str(&raw).context("Failed to parse `1_Pooling/config.json`")?;
            Pool::try_from(parsed)?
        }
        // The config could not be read, but BERT variants get a CLS default.
        Err(_) if config.model_type.to_lowercase().contains("bert") => {
            tracing::warn!("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model but the model is a BERT variant. Defaulting to `CLS` pooling.");
            text_embeddings_backend::Pool::Cls
        }
        Err(err) => {
            return Err(err).context("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model.");
        }
    };

    Ok(text_embeddings_backend::ModelType::Embedding(pool))
}