in launcher/src/main.rs [131:200]
/// Resolves the attention backend and the prefix-caching setting for the launcher.
///
/// Values supplied through the `ATTENTION` and `PREFIX_CACHING` environment variables are
/// never overridden. Only unset values are filled in, based on:
/// - the model `config`,
/// - whether `lora_adapters` are present,
/// - the detected CUDA compute capability.
///
/// Returns `(prefix_caching, attention)`, in that order.
/// - When this function disables prefix caching, the value is `"0"`.
/// - If nothing sets prefix caching, it resolves to `"true"`.
/// - If no fallback is forced, attention resolves to `"flashinfer"`.
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
    // Value written when this function turns prefix caching off.
    const DISABLED: &str = "0";

    let compute_capability = gpu::get_cuda_capability();
    let mut prefix_caching: Option<String> = std::env::var("PREFIX_CACHING").ok();
    let mut attention: Option<String> = std::env::var("ATTENTION").ok();

    if let Some(config) = config {
        if prefix_caching.is_none() {
            if config.vision_config.is_some() {
                tracing::info!("Disabling prefix caching because of VLM model");
                prefix_caching = Some(DISABLED.to_string());
            } else if config.is_encoder_decoder {
                tracing::info!("Disabling prefix caching because of seq2seq model");
                prefix_caching = Some(DISABLED.to_string());
            }
        }

        // With no detected compute capability, or a major version below 8, fall back
        // to paged attention. Otherwise fall back to flashdecoding.
        let fallback_attention = match compute_capability {
            Some((major, _)) if major >= 8 => "flashdecoding",
            _ => "paged",
        };

        match config.get_head_dim() {
            // Head dims accepted for flashinfer; anything else forces the fallback below.
            Some(64 | 128 | 256) => {
                if lora_adapters.is_some() && prefix_caching.is_none() {
                    tracing::info!("Disabling prefix caching because of lora adapters");
                    prefix_caching = Some(DISABLED.to_string());
                }
                if let Some(model_type @ ("falcon" | "deepseek_v2")) =
                    config.model_type.as_deref()
                {
                    // These model types are forced onto the fallback attention backend.
                    // NOTE(review): the precise flashinfer incompatibility for them is not
                    // established in this function — confirm before extending the list.
                    if attention.is_none() {
                        tracing::info!(
                            "Forcing attention to '{fallback_attention}' because model {model_type} requires it"
                        );
                        attention = Some(fallback_attention.to_string());
                    }
                    // Checks the fallback rather than the resolved `attention`, so an
                    // explicit ATTENTION on such a GPU still disables prefix caching.
                    if fallback_attention == "paged" && prefix_caching.is_none() {
                        tracing::info!("Disabling prefix caching because it is not supported with 'paged' attention");
                        prefix_caching = Some(DISABLED.to_string());
                    }
                }
            }
            _ => {
                if attention.is_none() {
                    tracing::info!("Forcing attention to '{fallback_attention}' because head dim is not supported by flashinfer");
                    attention = Some(fallback_attention.to_string());
                }
                // Logged separately so the message is accurate when only one of the
                // two settings was left unset by the user.
                if prefix_caching.is_none() {
                    tracing::info!("Disabling prefix caching because head dim is not supported by flashinfer");
                    prefix_caching = Some(DISABLED.to_string());
                }
            }
        }
    }

    // Prefix caching is not supported with paged attention, however it was selected.
    if attention.as_deref() == Some("paged") && prefix_caching.is_none() {
        tracing::info!("Disabling prefix caching on paged attention");
        prefix_caching = Some(DISABLED.to_string());
    }

    let attention = attention.unwrap_or_else(|| "flashinfer".to_string());
    let prefix_caching = prefix_caching.unwrap_or_else(|| "true".to_string());
    (prefix_caching, attention)
}