fn shard_manager()

in launcher/src/main.rs [923:1249]


fn shard_manager(
    model_id: String,
    revision: Option<String>,
    quantize: Option<Quantization>,
    speculate: Option<usize>,
    dtype: Option<Dtype>,
    kv_cache_dtype: Option<KVCacheDtype>,
    trust_remote_code: bool,
    uds_path: String,
    rank: usize,
    world_size: usize,
    master_addr: String,
    master_port: usize,
    huggingface_hub_cache: Option<String>,
    weights_cache_override: Option<String>,
    disable_custom_kernels: bool,
    watermark_gamma: Option<f32>,
    watermark_delta: Option<f32>,
    cuda_graphs: Vec<usize>,
    cuda_memory_fraction: f32,
    rope_scaling: Option<RopeScaling>,
    rope_factor: Option<f32>,
    max_total_tokens: Option<usize>,
    max_batch_size: Option<usize>,
    max_input_tokens: Option<usize>,
    lora_adapters: Option<String>,
    enable_prefill_logprobs: bool,
    otlp_endpoint: Option<String>,
    otlp_service_name: String,
    log_level: LevelFilter,
    status_sender: mpsc::Sender<ShardStatus>,
    shutdown: Arc<AtomicBool>,
    graceful_termination_timeout: u64,
    _shutdown_sender: mpsc::Sender<()>,
) {
    // Enter shard-manager tracing span
    let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();

    // Get UDS path
    let uds_string = format!("{uds_path}-{rank}");
    let uds = Path::new(&uds_string);
    // Clean previous runs
    if uds.exists() {
        fs::remove_file(uds).unwrap();
    }

    // Process args
    let mut shard_args = vec![
        "serve".to_string(),
        model_id,
        "--uds-path".to_string(),
        uds_path,
        "--logger-level".to_string(),
        log_level.to_string().to_uppercase(),
        "--json-output".to_string(),
    ];

    // Activate trust remote code
    if trust_remote_code {
        shard_args.push("--trust-remote-code".to_string());
    }

    // Activate tensor parallelism
    if world_size > 1 {
        shard_args.push("--sharded".to_string());
    }

    if let Some(quantize) = quantize {
        shard_args.push("--quantize".to_string());
        shard_args.push(quantize.to_string())
    }

    if let Some(speculate) = speculate {
        shard_args.push("--speculate".to_string());
        shard_args.push(speculate.to_string())
    }

    if let Some(dtype) = dtype {
        shard_args.push("--dtype".to_string());
        shard_args.push(dtype.to_string())
    }

    if let Some(kv_cache_dtype) = kv_cache_dtype {
        shard_args.push("--kv-cache-dtype".to_string());
        shard_args.push(kv_cache_dtype.to_string())
    }

    // Model optional revision
    if let Some(revision) = revision {
        shard_args.push("--revision".to_string());
        shard_args.push(revision)
    }

    let rope = match (rope_scaling, rope_factor) {
        (None, None) => None,
        (Some(scaling), None) => Some((scaling, 1.0)),
        (Some(scaling), Some(factor)) => Some((scaling, factor)),
        (None, Some(factor)) => Some((RopeScaling::Linear, factor)),
    };

    // OpenTelemetry Endpoint
    if let Some(otlp_endpoint) = otlp_endpoint {
        shard_args.push("--otlp-endpoint".to_string());
        shard_args.push(otlp_endpoint);
    }

    // OpenTelemetry Service Name
    shard_args.push("--otlp-service-name".to_string());
    shard_args.push(otlp_service_name);

    // In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
    if let Some(max_input_tokens) = max_input_tokens {
        shard_args.push("--max-input-tokens".to_string());
        shard_args.push(max_input_tokens.to_string());
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Remove LOG_LEVEL if present
    envs.retain(|(name, _)| name != "LOG_LEVEL");

    // Torch Distributed Env vars
    envs.push(("RANK".into(), rank.to_string().into()));
    envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
    envs.push(("MASTER_ADDR".into(), master_addr.into()));
    envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
    envs.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into()));

    // CUDA memory fraction
    envs.push((
        "CUDA_MEMORY_FRACTION".into(),
        cuda_memory_fraction.to_string().into(),
    ));

    // Safetensors load fast
    envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));

    // Disable progress bar
    envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    envs.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HF_TOKEN".into(), api_token.into()))
    };

    // Detect rope scaling
    // Sending as env instead of CLI args to not bloat everything
    // those only can be used by RoPE models, so passing information around
    // for all models will complexify code unnecessarily
    if let Some((scaling, factor)) = rope {
        envs.push(("ROPE_SCALING".into(), scaling.to_string().into()));
        envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
    }

    if let Some(max_total_tokens) = max_total_tokens {
        envs.push((
            "MAX_TOTAL_TOKENS".into(),
            max_total_tokens.to_string().into(),
        ));
    }
    if let Some(max_batch_size) = max_batch_size {
        envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
    }

    // Lora Adapters
    if let Some(lora_adapters) = lora_adapters {
        envs.push(("LORA_ADAPTERS".into(), lora_adapters.into()));
    }

    // Logprobs
    if enable_prefill_logprobs {
        envs.push(("REQUEST_LOGPROBS".into(), "1".into()));
    }

    // If huggingface_hub_cache is some, pass it to the shard
    // Useful when running inside a docker container
    if let Some(huggingface_hub_cache) = huggingface_hub_cache {
        envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // If weights_cache_override is some, pass it to the shard
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = weights_cache_override {
        envs.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };

    // Enable experimental support for cuda graphs
    if !cuda_graphs.is_empty() {
        envs.push((
            "CUDA_GRAPHS".into(),
            cuda_graphs
                .into_iter()
                .map(|c| c.to_string())
                .collect::<Vec<_>>()
                .join(",")
                .into(),
        ));
    }

    // If disable_custom_kernels is true, pass it to the shard as an env var
    if disable_custom_kernels {
        envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
    }

    // Watermark Gamma
    if let Some(watermark_gamma) = watermark_gamma {
        envs.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
    }

    // Watermark Delta
    if let Some(watermark_delta) = watermark_delta {
        envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
    }

    // Start process
    tracing::info!("Starting shard");
    let mut p = match Command::new("text-generation-server")
        .args(shard_args)
        .env_clear()
        .envs(envs)
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-server not found in PATH");
                tracing::error!("Please install it with `make install-server`")
            }
            {
                tracing::error!("{}", err);
            }

            status_sender.send(ShardStatus::Failed(rank)).unwrap();
            return;
        }
    };

    // Redirect STDOUT to the console
    let mut pstdin = p.stdin.take().unwrap();
    let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
    let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());

    //stdout tracing thread
    thread::spawn(move || {
        log_lines(shard_stdout_reader);
    });
    // We read stderr in another thread as it seems that lines() can block in some cases
    let (err_sender, err_receiver) = mpsc::channel();
    thread::spawn(move || {
        for line in shard_stderr_reader.lines().map_while(Result::ok) {
            err_sender.send(line).unwrap_or(());
        }
    });
    // We read stdin in another thread as it seems that lines() can block in some cases
    if LevelFilter::current() >= tracing::Level::DEBUG {
        thread::spawn(move || {
            let mut stdin = io::stdin(); // We get `Stdin` here.
            loop {
                let mut buffer = vec![0; 4096];
                if let Ok(n) = stdin.read(&mut buffer) {
                    if n > 0 {
                        let _ = pstdin.write_all(&buffer[..n]);
                    }
                }
            }
        });
    }

    let mut ready = false;
    let start_time = Instant::now();
    let mut wait_time = Instant::now();
    loop {
        // Process exited
        if let Some(exit_status) = p.try_wait().unwrap() {
            let mut err = String::new();
            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
                err = err + "\n" + &line;
            }

            tracing::error!("Shard complete standard error output:\n{err}");

            if let Some(signal) = exit_status.signal() {
                tracing::error!("Shard process was signaled to shutdown with signal {signal}");
            }

            status_sender.send(ShardStatus::Failed(rank)).unwrap();
            return;
        }

        // We received a shutdown signal
        if shutdown.load(Ordering::SeqCst) {
            terminate(
                "shard",
                p,
                Duration::from_secs(graceful_termination_timeout),
            )
            .unwrap();
            return;
        }

        // Shard is ready
        if uds.exists() && !ready {
            tracing::info!("Shard ready in {:?}", start_time.elapsed());
            status_sender.send(ShardStatus::Ready).unwrap();
            ready = true;
        } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
            tracing::info!("Waiting for shard to be ready...");
            wait_time = Instant::now();
        }
        sleep(Duration::from_millis(100));
    }
}