fn main()

in launcher/src/main.rs [2034:2358]


fn main() -> Result<(), LauncherError> {
    // Parse the command line arguments
    let args: Args = Args::parse();

    let graceful_termination_timeout = args.graceful_termination_timeout;

    // Filter events with LOG_LEVEL
    let varname = "LOG_LEVEL";
    let env_filter = if let Ok(log_level) = std::env::var(varname) {
        // Override so that simple logs are not spammed with tokio-level information
        let log_level = match &log_level[..] {
            "warn" => "text_generation_launcher=warn,text_generation_router=warn",
            "info" => "text_generation_launcher=info,text_generation_router=info",
            "debug" => "text_generation_launcher=debug,text_generation_router=debug",
            log_level => log_level,
        };
        EnvFilter::builder()
            .with_default_directive(LevelFilter::INFO.into())
            .parse_lossy(log_level)
    } else {
        EnvFilter::new("info")
    };
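    // Capture the effective max log level before the filter is consumed; it is forwarded to the shards via `spawn_shards`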
    let max_log_level = env_filter.max_level_hint().unwrap_or(LevelFilter::INFO);

    if args.json_output {
        tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .json()
            .init();
    } else {
        tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .compact()
            .init();
    }

    if args.env {
        let env_runtime = env_runtime::Env::new();
        tracing::info!("{}", env_runtime);
    }

    tracing::info!("{:#?}", args);

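    // Best-effort load of the model config; it is used below to pick quantization, attention and prefill defaults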
    let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
    let quantize = config.as_ref().and_then(|c| c.quantize);
    // Quantization usually means you're even more RAM constrained.

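    // Choose the attention backend and prefix-caching behaviour, and expose the choice to the shards through environment variables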
    let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
    tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
    std::env::set_var("PREFIX_CACHING", prefix_caching);
    std::env::set_var("ATTENTION", attention);

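    // Work out how many shard processes to launch and reject unsupported combinations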
    let num_shard = find_num_shards(args.sharded, args.num_shard)?;
    if num_shard > 1 {
        if matches!(args.quantize, Some(Quantization::Exl2)) {
            return Err(LauncherError::ArgumentValidation(
                "Sharding is currently not supported with `exl2` quantization".into(),
            ));
        }
        tracing::info!("Sharding model on {num_shard} processes");
    }

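    // Resolve `max_input_tokens`, still accepting the deprecated `max_input_length` alias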
    let max_input_tokens = {
        match (args.max_input_tokens, args.max_input_length) {
            (Some(max_input_tokens), Some(max_input_length)) => {
                return Err(LauncherError::ArgumentValidation(
                    format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.",
                )));
            }
            (Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => {
                Some(max_input_tokens)
            }
            (None, None) => None,
        }
    };
    let max_total_tokens = args.max_total_tokens;
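    // Resolve `max_batch_prefill_tokens`: use the CLI value if given, otherwise derive a default from the model config, compute type and available VRAM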
    let max_batch_prefill_tokens = {
        match args.max_batch_prefill_tokens {
            Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
            None => {
                let compute_type = compute_type(num_shard);
                let compute_optimal = compute_optimal(config.as_ref(), compute_type.as_ref());
                // TODO: remove this when we correctly estimate the flops for VLMs
                // this is a short-term fix so VLMs do not reject images
                let default_optimal = match config {
                    Some(ref config) => match config.model_type.as_deref() {
                        Some("qwen2_vl") | Some("qwen2_5_vl") => 10_000,
                        Some("gemma3") => 8000,
                        _ => 4096,
                    },
                    None => 4096,
                };
                let default = compute_optimal.unwrap_or(default_optimal);
                let vram_maximum = vram_maximum(
                    config.as_ref(),
                    compute_type.as_ref(),
                    args.cuda_memory_fraction,
                );
                let max_position_embeddings = config.and_then(|c| c.max_position_embeddings);
                let value = if let Some(max_position_embeddings) = max_position_embeddings {
                    default.min(max_position_embeddings)
                } else {
                    default
                };
                let value = if let Some(vram_maximum) = vram_maximum {
                    if vram_maximum < value {
                        tracing::warn!("Reducing the max batch prefill from {default} to {vram_maximum} because there is not enough VRAM to support it.");
                    }
                    value.min(vram_maximum)
                } else {
                    value
                };
                tracing::info!("Default `max_batch_prefill_tokens` to {value}");
                value as u32
            }
        }
    };

    // Validate args
    if let (Some(max_input_tokens), Some(max_total_tokens)) = (max_input_tokens, max_total_tokens) {
        if max_input_tokens >= max_total_tokens {
            return Err(LauncherError::ArgumentValidation(
                    format!("`max_input_tokens`({max_input_tokens}) must be < `max_total_tokens`({max_total_tokens})"),
                ));
        }
    }

    if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
        tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
    }
    let quantize = args.quantize.or(quantize);
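    // Decide which CUDA graph batch sizes to capture; some quantization schemes disable them entirely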
    let cuda_graphs = match (&args.cuda_graphs, &quantize) {
        (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
        #[allow(deprecated)]
        (None, Some(Quantization::Bitsandbytes)) => {
            tracing::warn!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
            vec![]
        }
        (None, Some(Quantization::Exl2)) => {
            tracing::warn!("Exl2 doesn't work with cuda graphs, deactivating them");
            vec![]
        }
        _ => {
            let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
            tracing::info!("Using default cuda graphs {cuda_graphs:?}");
            cuda_graphs
        }
    };

    if args.validation_workers == 0 {
        return Err(LauncherError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }
    if args.trust_remote_code {
        tracing::warn!(
            "`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.",
            args.model_id
        );
    }

    if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
        if let Some(max_total_tokens) = max_total_tokens {
            if max_total_tokens as u32 > *max_batch_total_tokens {
                return Err(LauncherError::ArgumentValidation(format!(
                    "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
                    max_total_tokens, max_batch_total_tokens
                )));
            }
        }
    }

    if args.ngrok {
        if args.ngrok_authtoken.is_none() {
            return Err(LauncherError::ArgumentValidation(
                "`ngrok-authtoken` must be set when using ngrok tunneling".to_string(),
            ));
        }

        if args.ngrok_edge.is_none() {
            return Err(LauncherError::ArgumentValidation(
                "`ngrok-edge` must be set when using ngrok tunneling".to_string(),
            ));
        }
    }

    // Signal handler
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();
    ctrlc::set_handler(move || {
        r.store(false, Ordering::SeqCst);
    })
    .expect("Error setting Ctrl-C handler");

    // Download and convert model weights
    download_convert_model(
        &args.model_id,
        args.revision.as_deref(),
        args.trust_remote_code,
        args.huggingface_hub_cache.as_deref(),
        args.weights_cache_override.as_deref(),
        running.clone(),
        true, // if it's only a LoRA model, we should merge the LoRA adapters
    )?;

    // Download and convert lora adapters if any
    if let Some(lora_adapters) = &args.lora_adapters {
        for adapter in lora_adapters.split(',') {
            // skip download if a path is provided
            if adapter.contains('=') {
                continue;
            }

            let adapter = adapter.trim();

            // check if adapter has more than 1 '@'
            if adapter.matches('@').count() > 1 {
                return Err(LauncherError::ArgumentValidation(format!(
                    "Invalid LoRA adapter format: {}",
                    adapter
                )));
            }

            // capture adapter_id and revision in the `adapter_id@revision` format
            // (adapters that provide an explicit path via `=` were skipped above)
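            // e.g. `org/my-adapter@main` (hypothetical id) pins the adapter to the `main` revision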
            let mut splits = adapter.split("@");
            let adapter_id = splits.next().ok_or_else(|| {
                LauncherError::ArgumentValidation("Missing adapter id".to_string())
            })?;
            let revision = splits.next();
            download_convert_model(
                adapter_id,
                revision,
                args.trust_remote_code,
                args.huggingface_hub_cache.as_deref(),
                args.weights_cache_override.as_deref(),
                running.clone(),
                false, // avoid merging lora adapters if using multi-lora
            )?;
        }
    }

    if !running.load(Ordering::SeqCst) {
        // Launcher was asked to stop
        return Ok(());
    }

    // Shared shutdown bool
    let shutdown = Arc::new(AtomicBool::new(false));
    // Shared shutdown channel
    // When shutting down, the main thread will wait for all senders to be dropped
    let (shutdown_sender, shutdown_receiver) = mpsc::channel();

    // Shared channel to track shard status
    let (status_sender, status_receiver) = mpsc::channel();

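    // Launch the shard processes; their status and shutdown are coordinated through the channels above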
    spawn_shards(
        num_shard,
        &args,
        cuda_graphs,
        max_total_tokens,
        max_input_tokens,
        quantize,
        max_log_level,
        shutdown.clone(),
        &shutdown_receiver,
        shutdown_sender,
        &status_receiver,
        status_sender,
        running.clone(),
        graceful_termination_timeout,
    )?;

    // We might have received a termination signal
    if !running.load(Ordering::SeqCst) {
        shutdown_shards(shutdown, &shutdown_receiver);
        return Ok(());
    }

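    // Start the webserver; if it fails to come up, bring the shards down before returning the error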
    let mut webserver = spawn_webserver(
        num_shard,
        args,
        max_input_tokens,
        max_total_tokens,
        max_batch_prefill_tokens,
        shutdown.clone(),
        &shutdown_receiver,
    )
    .inspect_err(|_| {
        shutdown_shards(shutdown.clone(), &shutdown_receiver);
    })?;

    // Default exit code
    let mut exit_code = Ok(());

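    // Supervision loop: exit if a shard or the webserver crashes, or when a termination signal is received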
    while running.load(Ordering::SeqCst) {
        if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
            tracing::error!("Shard {rank} crashed");
            exit_code = Err(LauncherError::ShardFailed);
            break;
        };

        match webserver.try_wait().unwrap() {
            Some(_) => {
                tracing::error!("Webserver Crashed");
                shutdown_shards(shutdown, &shutdown_receiver);
                return Err(LauncherError::WebserverFailed);
            }
            None => {
                sleep(Duration::from_millis(100));
            }
        };
    }

    // Graceful termination
    terminate(
        "webserver",
        webserver,
        Duration::from_secs(graceful_termination_timeout),
    )
    .unwrap();
    shutdown_shards(shutdown, &shutdown_receiver);

    exit_code
}