in launcher/src/main.rs [1575:1685]
/// Spawns one `shard_manager` thread per shard rank, then waits for all of them to report ready.
///
/// Each rank in `0..num_shard` gets its own thread, its own copies of the relevant `args`
/// fields, and its own clones of the status sender, shutdown flag and shutdown sender.
/// Afterwards `status_receiver` is polled until `num_shard` `ShardStatus::Ready` messages
/// have been received, or until `running` is cleared.
///
/// # Errors
///
/// * `LauncherError::ShardCannotStart` if any shard reports `ShardStatus::Failed`.
/// * `LauncherError::ShardDisconnected` if the status channel is disconnected.
///
/// In both cases `shutdown_shards` is called before returning.
///
/// Note that `Ok(())` is also returned when `running` becomes `false` before every shard
/// has reported ready.
fn spawn_shards(
    num_shard: usize,
    args: &Args,
    cuda_graphs: Vec<usize>,
    max_total_tokens: Option<usize>,
    max_input_tokens: Option<usize>,
    quantize: Option<Quantization>,
    max_log_level: LevelFilter,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
    shutdown_sender: mpsc::Sender<()>,
    status_receiver: &mpsc::Receiver<ShardStatus>,
    status_sender: mpsc::Sender<ShardStatus>,
    running: Arc<AtomicBool>,
    graceful_termination_timeout: u64,
) -> Result<(), LauncherError> {
    // Launch one manager thread per shard rank.
    for rank in 0..num_shard {
        // Per-shard handles on the shared channels and the shutdown flag.
        let shard_status_tx = status_sender.clone();
        let shard_shutdown = Arc::clone(&shutdown);
        let shard_shutdown_tx = shutdown_sender.clone();

        // Owned copies of the settings that have to be cloned for the thread.
        let model_id = args.model_id.clone();
        let revision = args.revision.clone();
        let shard_uds_path = args.shard_uds_path.clone();
        let master_addr = args.master_addr.clone();
        let huggingface_hub_cache = args.huggingface_hub_cache.clone();
        let weights_cache_override = args.weights_cache_override.clone();
        let otlp_endpoint = args.otlp_endpoint.clone();
        let otlp_service_name = args.otlp_service_name.clone();
        let lora_adapters = args.lora_adapters.clone();
        let shard_cuda_graphs = cuda_graphs.clone();

        // Settings read straight out of `args` by value.
        let speculate = args.speculate;
        let dtype = args.dtype;
        let kv_cache_dtype = args.kv_cache_dtype;
        let trust_remote_code = args.trust_remote_code;
        let master_port = args.master_port;
        let disable_custom_kernels = args.disable_custom_kernels;
        let watermark_gamma = args.watermark_gamma;
        let watermark_delta = args.watermark_delta;
        let cuda_memory_fraction = args.cuda_memory_fraction;
        let rope_scaling = args.rope_scaling;
        let rope_factor = args.rope_factor;
        let max_batch_size = args.max_batch_size;
        let enable_prefill_logprobs = args.enable_prefill_logprobs;

        // The join handle is intentionally discarded; progress is reported through
        // the status channel instead.
        thread::spawn(move || {
            shard_manager(
                model_id,
                revision,
                quantize,
                speculate,
                dtype,
                kv_cache_dtype,
                trust_remote_code,
                shard_uds_path,
                rank,
                num_shard,
                master_addr,
                master_port,
                huggingface_hub_cache,
                weights_cache_override,
                disable_custom_kernels,
                watermark_gamma,
                watermark_delta,
                shard_cuda_graphs,
                cuda_memory_fraction,
                rope_scaling,
                rope_factor,
                max_total_tokens,
                max_batch_size,
                max_input_tokens,
                lora_adapters,
                enable_prefill_logprobs,
                otlp_endpoint,
                otlp_service_name,
                max_log_level,
                shard_status_tx,
                shard_shutdown,
                graceful_termination_timeout,
                shard_shutdown_tx,
            )
        });
    }

    // Release this function's own shutdown sender; only the clones moved into the shard
    // threads remain. NOTE(review): presumably this lets the shutdown receiver observe
    // disconnection once every shard thread has exited — confirm in `shutdown_shards`.
    drop(shutdown_sender);

    // Poll for status messages until every shard is ready or the launcher stops running.
    let mut ready_count = 0;
    while running.load(Ordering::SeqCst) {
        // First resolve the receive outcome, then react to the shard status itself.
        let status = match status_receiver.try_recv() {
            Ok(status) => status,
            Err(TryRecvError::Empty) => {
                // Nothing queued yet: back off briefly before polling again.
                sleep(Duration::from_millis(100));
                continue;
            }
            Err(TryRecvError::Disconnected) => {
                tracing::error!("Shard status channel disconnected");
                shutdown_shards(shutdown, shutdown_receiver);
                return Err(LauncherError::ShardDisconnected);
            }
        };

        match status {
            ShardStatus::Ready => {
                ready_count += 1;
                if ready_count == num_shard {
                    break;
                }
            }
            ShardStatus::Failed(rank) => {
                tracing::error!("Shard {rank} failed to start");
                shutdown_shards(shutdown, shutdown_receiver);
                return Err(LauncherError::ShardCannotStart);
            }
        }
    }

    Ok(())
}