fn resolve_attention()

in launcher/src/main.rs [131:200]

Chooses the attention backend ("flashinfer", "flashdecoding" or "paged") and whether prefix caching is enabled, based on the model config, any LoRA adapters, the GPU's CUDA compute capability, and the PREFIX_CACHING / ATTENTION environment variables. Returns the pair (prefix_caching, attention) as strings, defaulting to ("true", "flashinfer").

fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
    let compute_capability = gpu::get_cuda_capability();
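    // PREFIX_CACHING / ATTENTION from the environment are explicit user overrides;
    // the logic below only fills in whichever of them is still unset.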
    let mut prefix_caching: Option<String> = std::env::var("PREFIX_CACHING").ok();
    let mut attention: Option<String> = std::env::var("ATTENTION").ok();
    if let Some(config) = config {
        if prefix_caching.is_none() {
            if config.vision_config.is_some() {
                tracing::info!("Disabling prefix caching because of VLM model");
                prefix_caching = Some("0".to_string());
            } else if config.is_encoder_decoder {
                tracing::info!("Disabling prefix caching because of seq2seq model");
                prefix_caching = Some("0".to_string());
            }
        }

        // Pick the fallback backend: paged attention when the CUDA capability is unknown
        // or below 8.0 (pre-Ampere), flashdecoding otherwise.
        let fallback_attention = if compute_capability.is_none()
            || matches!(compute_capability, Some((major, _)) if major < 8)
        {
            "paged"
        } else {
            "flashdecoding"
        };

        // flashinfer only handles head dims of 64, 128 and 256; any other (or unknown)
        // head dim forces the fallback backend and disables prefix caching.
        match config.get_head_dim() {
            Some(h) if h == 64 || h == 128 || h == 256 => {
                if lora_adapters.is_some() && prefix_caching.is_none() {
                    tracing::info!("Disabling prefix caching because of lora adapters");
                    prefix_caching = Some("0".to_string());
                }
                match config.model_type.as_deref() {
                    Some("falcon") | Some("deepseek_v2") => {
                        // These architectures require the fallback attention backend;
                        // only force it when ATTENTION was not set explicitly.
                        if attention.is_none() {
                            tracing::info!(
                                "Forcing attention to '{fallback_attention}' because model {} requires it",
                                config.model_type.as_ref().unwrap()
                            );
                            attention = Some(fallback_attention.to_string());
                        }
                        if fallback_attention == "paged" && prefix_caching.is_none() {
                            tracing::info!("Disabling prefix caching because it is not supported with 'paged' attention");
                            prefix_caching = Some("0".to_string());
                        }
                    }
                    Some("t5") => {}
                    _ => {}
                }
            }
            _ => {
                if attention.is_none() {
                    tracing::info!("Forcing attention to '{fallback_attention}' because head dim is not supported by flashinfer, also disabling prefix caching");
                    attention = Some(fallback_attention.to_string());
                }
                if prefix_caching.is_none() {
                    prefix_caching = Some("0".to_string());
                }
            }
        }
    }
    if attention == Some("paged".to_string()) && prefix_caching.is_none() {
        tracing::info!("Disabling prefix caching on paged attention");
        prefix_caching = Some("0".to_string());
    }

    let attention = attention.unwrap_or("flashinfer".to_string());
    let prefix_caching = prefix_caching.unwrap_or("true".to_string());

    (prefix_caching, attention)
}
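
For context, here is a minimal sketch of how the resolved pair might be consumed. apply_attention_settings is a hypothetical helper written for this page (the real call site in main() sits outside the [131:200] range above), and the re-export of the values through the same PREFIX_CACHING / ATTENTION environment variables is an assumption for illustration.

// Hypothetical helper, for illustration only: resolve the settings, log them, and
// export them via the same environment variables so spawned child processes inherit them.
fn apply_attention_settings(config: &Option<Config>, lora_adapters: &Option<String>) {
    let (prefix_caching, attention) = resolve_attention(config, lora_adapters);
    tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
    std::env::set_var("PREFIX_CACHING", prefix_caching);
    std::env::set_var("ATTENTION", attention);
}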