in launcher/src/main.rs [1430:1572]
fn download_convert_model(
model_id: &str,
revision: Option<&str>,
trust_remote_code: bool,
huggingface_hub_cache: Option<&str>,
weights_cache_override: Option<&str>,
running: Arc<AtomicBool>,
merge_lora: bool,
) -> Result<(), LauncherError> {
// Enter download tracing span
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
let mut download_args = vec![
"download-weights".to_string(),
model_id.to_string(),
"--extension".to_string(),
".safetensors".to_string(),
"--logger-level".to_string(),
"INFO".to_string(),
"--json-output".to_string(),
];
if merge_lora {
download_args.push("--merge-lora".to_string());
}
// Model optional revision
if let Some(revision) = &revision {
download_args.push("--revision".to_string());
download_args.push(revision.to_string())
}
// Trust remote code for automatic peft fusion
if trust_remote_code {
download_args.push("--trust-remote-code".to_string());
}
// Copy current process env
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
// Remove LOG_LEVEL if present
envs.retain(|(name, _)| name != "LOG_LEVEL");
// Disable progress bar
envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));
// If huggingface_hub_cache is set, pass it to the download process
// Useful when running inside a docker container
if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
};
// Enable hf transfer for insane download speeds
let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
envs.push((
"HF_HUB_ENABLE_HF_TRANSFER".into(),
enable_hf_transfer.into(),
));
// Parse Inference API token
if let Ok(api_token) = env::var("HF_API_TOKEN") {
envs.push(("HF_TOKEN".into(), api_token.into()))
};
// If args.weights_cache_override is some, pass it to the download process
// Useful when running inside a HuggingFace Inference Endpoint
if let Some(weights_cache_override) = &weights_cache_override {
envs.push((
"WEIGHTS_CACHE_OVERRIDE".into(),
weights_cache_override.into(),
));
};
// Start process
tracing::info!("Starting check and download process for {model_id}");
let mut download_process = match Command::new("text-generation-server")
.args(download_args)
.env_clear()
.envs(envs)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.process_group(0)
.spawn()
{
Ok(p) => p,
Err(err) => {
if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-server not found in PATH");
tracing::error!("Please install it with `make install-server`")
} else {
tracing::error!("{}", err);
}
return Err(LauncherError::DownloadError);
}
};
let download_stdout = BufReader::new(download_process.stdout.take().unwrap());
thread::spawn(move || {
log_lines(download_stdout);
});
let download_stderr = BufReader::new(download_process.stderr.take().unwrap());
// We read stderr in another thread as it seems that lines() can block in some cases
let (err_sender, err_receiver) = mpsc::channel();
thread::spawn(move || {
for line in download_stderr.lines().map_while(Result::ok) {
err_sender.send(line).unwrap_or(());
}
});
loop {
if let Some(status) = download_process.try_wait().unwrap() {
if status.success() {
tracing::info!("Successfully downloaded weights for {model_id}");
break;
}
let mut err = String::new();
while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
err = err + "\n" + &line;
}
if let Some(signal) = status.signal() {
tracing::error!(
"Download process was signaled to shutdown with signal {signal}: {err}"
);
} else {
tracing::error!("Download encountered an error: {err}");
}
return Err(LauncherError::DownloadError);
}
if !running.load(Ordering::SeqCst) {
terminate("download", download_process, Duration::from_secs(10)).unwrap();
return Ok(());
}
sleep(Duration::from_millis(100));
}
Ok(())
}