in launcher/src/main.rs [2034:2358]
fn main() -> Result<(), LauncherError> {
// Pattern match configuration
let args: Args = Args::parse();
let graceful_termination_timeout = args.graceful_termination_timeout;
// Filter events with LOG_LEVEL
let varname = "LOG_LEVEL";
let env_filter = if let Ok(log_level) = std::env::var(varname) {
// Override to avoid simple logs to be spammed with tokio level informations
let log_level = match &log_level[..] {
"warn" => "text_generation_launcher=warn,text_generation_router=warn",
"info" => "text_generation_launcher=info,text_generation_router=info",
"debug" => "text_generation_launcher=debug,text_generation_router=debug",
log_level => log_level,
};
EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.parse_lossy(log_level)
} else {
EnvFilter::new("info")
};
let max_log_level = env_filter.max_level_hint().unwrap_or(LevelFilter::INFO);
if args.json_output {
tracing_subscriber::fmt()
.with_env_filter(env_filter)
.json()
.init();
} else {
tracing_subscriber::fmt()
.with_env_filter(env_filter)
.compact()
.init();
}
if args.env {
let env_runtime = env_runtime::Env::new();
tracing::info!("{}", env_runtime);
}
tracing::info!("{:#?}", args);
let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
let quantize = config.as_ref().and_then(|c| c.quantize);
// Quantization usually means you're even more RAM constrained.
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
std::env::set_var("PREFIX_CACHING", prefix_caching);
std::env::set_var("ATTENTION", attention);
let num_shard = find_num_shards(args.sharded, args.num_shard)?;
if num_shard > 1 {
if matches!(args.quantize, Some(Quantization::Exl2)) {
return Err(LauncherError::ArgumentValidation(
"Sharding is currently not supported with `exl2` quantization".into(),
));
}
tracing::info!("Sharding model on {num_shard} processes");
}
let max_input_tokens = {
match (args.max_input_tokens, args.max_input_length) {
(Some(max_input_tokens), Some(max_input_length)) => {
return Err(LauncherError::ArgumentValidation(
format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.",
)));
}
(Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => {
Some(max_input_tokens)
}
(None, None) => None,
}
};
let max_total_tokens = args.max_total_tokens;
let max_batch_prefill_tokens = {
match args.max_batch_prefill_tokens {
Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
None => {
let compute_type = compute_type(num_shard);
let compute_optimal = compute_optimal(config.as_ref(), compute_type.as_ref());
// TODO: remove this when we correctly esimate the flops for VLMs
// this is a short term temporary fix to enable vlms to avoid rejecting images
let default_optimal = match config {
Some(ref config) => match config.model_type.as_deref() {
Some("qwen2_vl") | Some("qwen2_5_vl") => 10_000,
Some("gemma3") => 8000,
_ => 4096,
},
None => 4096,
};
let default = compute_optimal.unwrap_or(default_optimal);
let vram_maximum = vram_maximum(
config.as_ref(),
compute_type.as_ref(),
args.cuda_memory_fraction,
);
let max_position_embeddings = config.and_then(|c| c.max_position_embeddings);
let value = if let Some(max_position_embeddings) = max_position_embeddings {
default.min(max_position_embeddings)
} else {
default
};
let value = if let Some(vram_maximum) = vram_maximum {
if vram_maximum < value {
tracing::warn!("Reducing the max batch prefill from {default} to {vram_maximum} because there is not enough VRAM to support it.");
}
value.min(vram_maximum)
} else {
value
};
tracing::info!("Default `max_batch_prefill_tokens` to {value}");
value as u32
}
}
};
// Validate args
if let (Some(max_input_tokens), Some(max_total_tokens)) = (max_input_tokens, max_total_tokens) {
if max_input_tokens >= max_total_tokens {
return Err(LauncherError::ArgumentValidation(
format!("`max_input_tokens`({max_input_tokens}) must be < `max_total_tokens`({max_total_tokens})"),
));
}
}
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
}
let quantize = args.quantize.or(quantize);
let cuda_graphs = match (&args.cuda_graphs, &quantize) {
(Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
#[allow(deprecated)]
(None, Some(Quantization::Bitsandbytes)) => {
tracing::warn!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
vec![]
}
(None, Some(Quantization::Exl2)) => {
tracing::warn!("Exl2 doesn't work with cuda graphs, deactivating them");
vec![]
}
_ => {
let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
tracing::info!("Using default cuda graphs {cuda_graphs:?}");
cuda_graphs
}
};
if args.validation_workers == 0 {
return Err(LauncherError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
));
}
if args.trust_remote_code {
tracing::warn!(
"`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.",
args.model_id
);
}
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
if let Some(max_total_tokens) = max_total_tokens {
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
max_total_tokens, max_batch_total_tokens
)));
}
}
}
if args.ngrok {
if args.ngrok_authtoken.is_none() {
return Err(LauncherError::ArgumentValidation(
"`ngrok-authtoken` must be set when using ngrok tunneling".to_string(),
));
}
if args.ngrok_edge.is_none() {
return Err(LauncherError::ArgumentValidation(
"`ngrok-edge` must be set when using ngrok tunneling".to_string(),
));
}
}
// Signal handler
let running = Arc::new(AtomicBool::new(true));
let r = running.clone();
ctrlc::set_handler(move || {
r.store(false, Ordering::SeqCst);
})
.expect("Error setting Ctrl-C handler");
// Download and convert model weights
download_convert_model(
&args.model_id,
args.revision.as_deref(),
args.trust_remote_code,
args.huggingface_hub_cache.as_deref(),
args.weights_cache_override.as_deref(),
running.clone(),
true, // if its only a lora model - we should merge the lora adapters
)?;
// Download and convert lora adapters if any
if let Some(lora_adapters) = &args.lora_adapters {
for adapter in lora_adapters.split(',') {
// skip download if a path is provided
if adapter.contains('=') {
continue;
}
let adapter = adapter.trim();
// check if adapter has more than 1 '@'
if adapter.matches('@').count() > 1 {
return Err(LauncherError::ArgumentValidation(format!(
"Invalid LoRA adapter format: {}",
adapter
)));
}
// capture adapter_id, path, revision in format of adapter_id=path@revision
// path is disabled beforehand.
let mut splits = adapter.split("@");
let adapter_id = splits.next().ok_or_else(|| {
LauncherError::ArgumentValidation("Missing adapter id".to_string())
})?;
let revision = splits.next();
download_convert_model(
adapter_id,
revision,
args.trust_remote_code,
args.huggingface_hub_cache.as_deref(),
args.weights_cache_override.as_deref(),
running.clone(),
false, // avoid merging lora adapters if using multi-lora
)?;
}
}
if !running.load(Ordering::SeqCst) {
// Launcher was asked to stop
return Ok(());
}
// Shared shutdown bool
let shutdown = Arc::new(AtomicBool::new(false));
// Shared shutdown channel
// When shutting down, the main thread will wait for all senders to be dropped
let (shutdown_sender, shutdown_receiver) = mpsc::channel();
// Shared channel to track shard status
let (status_sender, status_receiver) = mpsc::channel();
spawn_shards(
num_shard,
&args,
cuda_graphs,
max_total_tokens,
max_input_tokens,
quantize,
max_log_level,
shutdown.clone(),
&shutdown_receiver,
shutdown_sender,
&status_receiver,
status_sender,
running.clone(),
graceful_termination_timeout,
)?;
// We might have received a termination signal
if !running.load(Ordering::SeqCst) {
shutdown_shards(shutdown, &shutdown_receiver);
return Ok(());
}
let mut webserver = spawn_webserver(
num_shard,
args,
max_input_tokens,
max_total_tokens,
max_batch_prefill_tokens,
shutdown.clone(),
&shutdown_receiver,
)
.inspect_err(|_| {
shutdown_shards(shutdown.clone(), &shutdown_receiver);
})?;
// Default exit code
let mut exit_code = Ok(());
while running.load(Ordering::SeqCst) {
if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
tracing::error!("Shard {rank} crashed");
exit_code = Err(LauncherError::ShardFailed);
break;
};
match webserver.try_wait().unwrap() {
Some(_) => {
tracing::error!("Webserver Crashed");
shutdown_shards(shutdown, &shutdown_receiver);
return Err(LauncherError::WebserverFailed);
}
None => {
sleep(Duration::from_millis(100));
}
};
}
// Graceful termination
terminate(
"webserver",
webserver,
Duration::from_secs(graceful_termination_timeout),
)
.unwrap();
shutdown_shards(shutdown, &shutdown_receiver);
exit_code
}