fn spawn_webserver()

in launcher/src/main.rs [1827:2009]


fn spawn_webserver(
    num_shard: usize,
    args: Args,
    max_input_tokens: Option<usize>,
    max_total_tokens: Option<usize>,
    max_batch_prefill_tokens: u32,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Child, LauncherError> {
    // All shard started
    // Start webserver
    tracing::info!("Starting Webserver");
    let mut router_args = vec![
        "--max-client-batch-size".to_string(),
        args.max_client_batch_size.to_string(),
        "--max-concurrent-requests".to_string(),
        args.max_concurrent_requests.to_string(),
        "--max-best-of".to_string(),
        args.max_best_of.to_string(),
        "--max-stop-sequences".to_string(),
        args.max_stop_sequences.to_string(),
        "--max-top-n-tokens".to_string(),
        args.max_top_n_tokens.to_string(),
        "--max-batch-prefill-tokens".to_string(),
        max_batch_prefill_tokens.to_string(),
        "--waiting-served-ratio".to_string(),
        args.waiting_served_ratio.to_string(),
        "--max-waiting-tokens".to_string(),
        args.max_waiting_tokens.to_string(),
        "--validation-workers".to_string(),
        args.validation_workers.to_string(),
        "--hostname".to_string(),
        args.hostname.to_string(),
        "--port".to_string(),
        args.port.to_string(),
        "--prometheus-port".to_string(),
        args.prometheus_port.to_string(),
        "--master-shard-uds-path".to_string(),
        format!("{}-0", args.shard_uds_path),
        "--tokenizer-name".to_string(),
        args.model_id,
        "--payload-limit".to_string(),
        args.payload_limit.to_string(),
    ];
    if let Some(max_input_tokens) = max_input_tokens {
        router_args.extend_from_slice(&[
            "--max-input-tokens".to_string(),
            max_input_tokens.to_string(),
        ]);
    }
    if let Some(max_total_tokens) = max_total_tokens {
        router_args.extend_from_slice(&[
            "--max-total-tokens".to_string(),
            max_total_tokens.to_string(),
        ]);
    }

    // Pass usage stats flags to router
    router_args.push("--usage-stats".to_string());
    router_args.push(args.usage_stats.to_string());

    // Grammar support
    if args.disable_grammar_support {
        router_args.push("--disable-grammar-support".to_string());
    }

    // Tokenizer config path
    if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {
        router_args.push("--tokenizer-config-path".to_string());
        router_args.push(tokenizer_config_path.to_string());
    }

    // Model optional max batch total tokens
    if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
        router_args.push("--max-batch-total-tokens".to_string());
        router_args.push(max_batch_total_tokens.to_string());
    }

    // Router optional max batch size
    if let Some(max_batch_size) = args.max_batch_size {
        router_args.push("--max-batch-size".to_string());
        router_args.push(max_batch_size.to_string());
    }

    // Model optional revision
    if let Some(ref revision) = args.revision {
        router_args.push("--revision".to_string());
        router_args.push(revision.to_string())
    }

    if args.trust_remote_code {
        router_args.push("--trust-remote-code".to_string());
    }

    if args.json_output {
        router_args.push("--json-output".to_string());
    }

    // OpenTelemetry
    if let Some(otlp_endpoint) = args.otlp_endpoint {
        router_args.push("--otlp-endpoint".to_string());
        router_args.push(otlp_endpoint);
    }

    // OpenTelemetry
    let otlp_service_name = args.otlp_service_name;
    router_args.push("--otlp-service-name".to_string());
    router_args.push(otlp_service_name);

    // CORS origins
    for origin in args.cors_allow_origin.into_iter() {
        router_args.push("--cors-allow-origin".to_string());
        router_args.push(origin);
    }

    // API Key
    if let Some(api_key) = args.api_key {
        router_args.push("--api-key".to_string());
        router_args.push(api_key);
    }
    // Ngrok
    if args.ngrok {
        router_args.push("--ngrok".to_string());
        router_args.push("--ngrok-authtoken".to_string());
        router_args.push(args.ngrok_authtoken.unwrap());
        router_args.push("--ngrok-edge".to_string());
        router_args.push(args.ngrok_edge.unwrap());
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HF_TOKEN".into(), api_token.into()))
    };

    // Parse Compute type
    if let Ok(compute_type) = env::var("COMPUTE_TYPE") {
        envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
    } else if let Some(compute_type) = compute_type(num_shard) {
        envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
    }

    let mut webserver = match Command::new("text-generation-router")
        .args(router_args)
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            tracing::error!("Failed to start webserver: {}", err);
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-router not found in PATH");
                tracing::error!("Please install it with `make install-router`")
            } else {
                tracing::error!("{}", err);
            }

            shutdown_shards(shutdown, shutdown_receiver);
            return Err(LauncherError::WebserverCannotStart);
        }
    };

    // Redirect STDOUT and STDERR to the console
    let webserver_stdout = webserver.stdout.take().unwrap();
    let webserver_stderr = webserver.stderr.take().unwrap();

    thread::spawn(move || {
        let stdout = BufReader::new(webserver_stdout);
        let stderr = BufReader::new(webserver_stderr);
        for line in stdout.lines() {
            println!("{}", line.unwrap());
        }
        for line in stderr.lines() {
            println!("{}", line.unwrap());
        }
    });
    Ok(webserver)
}