in launcher/src/main.rs [923:1249]
/// Spawns and supervises a single `text-generation-server` shard process.
///
/// The shard is started with CLI arguments and environment variables derived from the
/// parameters:
///
/// * CLI arguments: model id, UDS path, log level, quantization, speculation, dtypes,
///   revision, OpenTelemetry settings and max input tokens.
/// * Environment: torch distributed settings, CUDA options, Hugging Face Hub options,
///   RoPE scaling, LoRA adapters, watermarking, and so on.
///
/// The child's stdout is handed to `log_lines` and its stderr is collected on a
/// background thread. The function then polls roughly every 100ms:
///
/// * If the child exits, its collected stderr (and the terminating signal, if any) is
///   logged, `ShardStatus::Failed(rank)` is sent on `status_sender` and the function
///   returns.
/// * If `shutdown` is set, the child is stopped via `terminate`, with
///   `graceful_termination_timeout` (in seconds) as the grace period, and the function
///   returns.
/// * The first time the socket `{uds_path}-{rank}` exists, `ShardStatus::Ready` is
///   sent. This happens at most once.
///
/// If the child cannot be spawned, the error is logged and `ShardStatus::Failed(rank)`
/// is sent immediately.
///
/// `_shutdown_sender` is never used directly; it is only kept alive for the duration of
/// this function. NOTE(review): presumably its drop is what signals the receiver —
/// confirm with the caller.
///
/// # Panics
///
/// Panics in any of these cases:
///
/// * removing a stale socket file fails;
/// * the receiver of `status_sender` has been dropped;
/// * polling the child with `try_wait` fails;
/// * `terminate` returns an error.
fn shard_manager(
    model_id: String,
    revision: Option<String>,
    quantize: Option<Quantization>,
    speculate: Option<usize>,
    dtype: Option<Dtype>,
    kv_cache_dtype: Option<KVCacheDtype>,
    trust_remote_code: bool,
    uds_path: String,
    rank: usize,
    world_size: usize,
    master_addr: String,
    master_port: usize,
    huggingface_hub_cache: Option<String>,
    weights_cache_override: Option<String>,
    disable_custom_kernels: bool,
    watermark_gamma: Option<f32>,
    watermark_delta: Option<f32>,
    cuda_graphs: Vec<usize>,
    cuda_memory_fraction: f32,
    rope_scaling: Option<RopeScaling>,
    rope_factor: Option<f32>,
    max_total_tokens: Option<usize>,
    max_batch_size: Option<usize>,
    max_input_tokens: Option<usize>,
    lora_adapters: Option<String>,
    enable_prefill_logprobs: bool,
    otlp_endpoint: Option<String>,
    otlp_service_name: String,
    log_level: LevelFilter,
    status_sender: mpsc::Sender<ShardStatus>,
    shutdown: Arc<AtomicBool>,
    graceful_termination_timeout: u64,
    _shutdown_sender: mpsc::Sender<()>,
) {
    // Enter shard-manager tracing span
    let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();

    // Per-rank socket path; its existence is used below as the readiness signal.
    let uds_string = format!("{uds_path}-{rank}");
    let uds = Path::new(&uds_string);
    // Remove a socket left over from a previous run so it is not mistaken for readiness.
    if uds.exists() {
        fs::remove_file(uds).unwrap();
    }

    // Process args. The base path (without rank) is passed; the server derives the
    // per-rank socket itself.
    let mut shard_args = vec![
        "serve".to_string(),
        model_id,
        "--uds-path".to_string(),
        uds_path,
        "--logger-level".to_string(),
        log_level.to_string().to_uppercase(),
        "--json-output".to_string(),
    ];
    // Activate trust remote code
    if trust_remote_code {
        shard_args.push("--trust-remote-code".to_string());
    }
    // Activate tensor parallelism
    if world_size > 1 {
        shard_args.push("--sharded".to_string());
    }
    if let Some(quantize) = quantize {
        shard_args.push("--quantize".to_string());
        shard_args.push(quantize.to_string())
    }
    if let Some(speculate) = speculate {
        shard_args.push("--speculate".to_string());
        shard_args.push(speculate.to_string())
    }
    if let Some(dtype) = dtype {
        shard_args.push("--dtype".to_string());
        shard_args.push(dtype.to_string())
    }
    if let Some(kv_cache_dtype) = kv_cache_dtype {
        shard_args.push("--kv-cache-dtype".to_string());
        shard_args.push(kv_cache_dtype.to_string())
    }
    // Model optional revision
    if let Some(revision) = revision {
        shard_args.push("--revision".to_string());
        shard_args.push(revision)
    }
    // Resolve RoPE scaling:
    // * a scaling type without a factor defaults to a factor of 1.0;
    // * a factor without a scaling type defaults to linear scaling.
    let rope = match (rope_scaling, rope_factor) {
        (None, None) => None,
        (Some(scaling), None) => Some((scaling, 1.0)),
        (Some(scaling), Some(factor)) => Some((scaling, factor)),
        (None, Some(factor)) => Some((RopeScaling::Linear, factor)),
    };
    // OpenTelemetry Endpoint
    if let Some(otlp_endpoint) = otlp_endpoint {
        shard_args.push("--otlp-endpoint".to_string());
        shard_args.push(otlp_endpoint);
    }
    // OpenTelemetry Service Name
    shard_args.push("--otlp-service-name".to_string());
    shard_args.push(otlp_service_name);
    // When the model uses a sliding window, some backends use the max input length to
    // decide whether the sliding window can be ignored for flash attention.
    if let Some(max_input_tokens) = max_input_tokens {
        shard_args.push("--max-input-tokens".to_string());
        shard_args.push(max_input_tokens.to_string());
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
    // Remove LOG_LEVEL if present
    envs.retain(|(name, _)| name != "LOG_LEVEL");
    // Torch Distributed Env vars
    envs.push(("RANK".into(), rank.to_string().into()));
    envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
    envs.push(("MASTER_ADDR".into(), master_addr.into()));
    envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
    envs.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into()));
    // CUDA memory fraction
    envs.push((
        "CUDA_MEMORY_FRACTION".into(),
        cuda_memory_fraction.to_string().into(),
    ));
    // Safetensors load fast
    envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
    // Disable progress bar
    envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));
    // Enable hf_transfer unless the user explicitly set HF_HUB_ENABLE_HF_TRANSFER
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    envs.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));
    // Forward HF_API_TOKEN to the shard as HF_TOKEN
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HF_TOKEN".into(), api_token.into()))
    };
    // RoPE scaling is sent via env instead of CLI args: it only applies to RoPE
    // models, and threading it through every model's arguments would add complexity
    if let Some((scaling, factor)) = rope {
        envs.push(("ROPE_SCALING".into(), scaling.to_string().into()));
        envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
    }
    if let Some(max_total_tokens) = max_total_tokens {
        envs.push((
            "MAX_TOTAL_TOKENS".into(),
            max_total_tokens.to_string().into(),
        ));
    }
    if let Some(max_batch_size) = max_batch_size {
        envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
    }
    // Lora Adapters
    if let Some(lora_adapters) = lora_adapters {
        envs.push(("LORA_ADAPTERS".into(), lora_adapters.into()));
    }
    // Logprobs
    if enable_prefill_logprobs {
        envs.push(("REQUEST_LOGPROBS".into(), "1".into()));
    }
    // If huggingface_hub_cache is some, pass it to the shard
    // Useful when running inside a docker container
    if let Some(huggingface_hub_cache) = huggingface_hub_cache {
        envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };
    // If weights_cache_override is some, pass it to the shard
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = weights_cache_override {
        envs.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };
    // Enable experimental support for cuda graphs (comma-separated list of sizes)
    if !cuda_graphs.is_empty() {
        envs.push((
            "CUDA_GRAPHS".into(),
            cuda_graphs
                .into_iter()
                .map(|c| c.to_string())
                .collect::<Vec<_>>()
                .join(",")
                .into(),
        ));
    }
    // If disable_custom_kernels is true, pass it to the shard as an env var
    if disable_custom_kernels {
        envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
    }
    // Watermark Gamma
    if let Some(watermark_gamma) = watermark_gamma {
        envs.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
    }
    // Watermark Delta
    if let Some(watermark_delta) = watermark_delta {
        envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
    }

    // Start process. The environment is cleared and replaced by `envs` (a copy of the
    // current environment plus the overrides above); later duplicates take precedence.
    // The child is placed in its own process group.
    tracing::info!("Starting shard");
    let mut p = match Command::new("text-generation-server")
        .args(shard_args)
        .env_clear()
        .envs(envs)
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-server not found in PATH");
                tracing::error!("Please install it with `make install-server`")
            }
            // Always log the underlying error, with or without the install hint.
            tracing::error!("{}", err);
            status_sender.send(ShardStatus::Failed(rank)).unwrap();
            return;
        }
    };

    // Redirect STDOUT to the console
    let mut pstdin = p.stdin.take().unwrap();
    let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
    let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
    //stdout tracing thread
    thread::spawn(move || {
        log_lines(shard_stdout_reader);
    });
    // Read stderr on its own thread, because lines() can block in some cases. Lines are
    // only dumped if the shard exits.
    let (err_sender, err_receiver) = mpsc::channel();
    thread::spawn(move || {
        for line in shard_stderr_reader.lines().map_while(Result::ok) {
            err_sender.send(line).unwrap_or(());
        }
    });
    // With DEBUG logging enabled, forward the launcher's stdin to the shard on a
    // separate thread. The thread stops on EOF, on a non-retryable read error, or once
    // the shard's stdin can no longer be written to. Returning drops `pstdin` and closes
    // the child's stdin.
    // Without these stops, a closed stdin (for example /dev/null) would make `read`
    // return Ok(0) forever and spin a core.
    // NOTE(review): each shard spawns its own forwarder, so with several shards they
    // compete for the same stdin.
    if LevelFilter::current() >= tracing::Level::DEBUG {
        thread::spawn(move || {
            let mut stdin = io::stdin(); // We get `Stdin` here.
            let mut buffer = [0u8; 4096];
            loop {
                match stdin.read(&mut buffer) {
                    // EOF: no more input will ever arrive.
                    Ok(0) => break,
                    Ok(n) => {
                        // Writing failed: the shard's end of the pipe is gone.
                        if pstdin.write_all(&buffer[..n]).is_err() {
                            break;
                        }
                    }
                    // Retry reads interrupted by a signal.
                    Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
                    Err(_) => break,
                }
            }
        });
    }

    let mut ready = false;
    let start_time = Instant::now();
    let mut wait_time = Instant::now();
    loop {
        // Process exited
        if let Some(exit_status) = p.try_wait().unwrap() {
            // Drain stderr lines received so far, waiting at most 10ms for each.
            let mut err = String::new();
            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
                err.push('\n');
                err.push_str(&line);
            }
            tracing::error!("Shard complete standard error output:\n{err}");
            if let Some(signal) = exit_status.signal() {
                tracing::error!("Shard process was signaled to shutdown with signal {signal}");
            }
            status_sender.send(ShardStatus::Failed(rank)).unwrap();
            return;
        }
        // We received a shutdown signal
        if shutdown.load(Ordering::SeqCst) {
            terminate(
                "shard",
                p,
                Duration::from_secs(graceful_termination_timeout),
            )
            .unwrap();
            return;
        }
        // The shard is ready once its socket appears; report it once. Until then, log
        // a progress message at most every 10 seconds.
        if uds.exists() && !ready {
            tracing::info!("Shard ready in {:?}", start_time.elapsed());
            status_sender.send(ShardStatus::Ready).unwrap();
            ready = true;
        } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
            tracing::info!("Waiting for shard to be ready...");
            wait_time = Instant::now();
        }
        sleep(Duration::from_millis(100));
    }
}