fn download_convert_model()

in launcher/src/main.rs [1430:1572]


fn download_convert_model(
    model_id: &str,
    revision: Option<&str>,
    trust_remote_code: bool,
    huggingface_hub_cache: Option<&str>,
    weights_cache_override: Option<&str>,
    running: Arc<AtomicBool>,
    merge_lora: bool,
) -> Result<(), LauncherError> {
    // Enter download tracing span
    let _span = tracing::span!(tracing::Level::INFO, "download").entered();

    let mut download_args = vec![
        "download-weights".to_string(),
        model_id.to_string(),
        "--extension".to_string(),
        ".safetensors".to_string(),
        "--logger-level".to_string(),
        "INFO".to_string(),
        "--json-output".to_string(),
    ];

    if merge_lora {
        download_args.push("--merge-lora".to_string());
    }

    // Model optional revision
    if let Some(revision) = &revision {
        download_args.push("--revision".to_string());
        download_args.push(revision.to_string())
    }

    // Trust remote code for automatic peft fusion
    if trust_remote_code {
        download_args.push("--trust-remote-code".to_string());
    }

    // Copy current process env
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Remove LOG_LEVEL if present
    envs.retain(|(name, _)| name != "LOG_LEVEL");

    // Disable progress bar
    envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));

    // If huggingface_hub_cache is set, pass it to the download process
    // Useful when running inside a docker container
    if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
        envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    envs.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HF_TOKEN".into(), api_token.into()))
    };

    // If args.weights_cache_override is some, pass it to the download process
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = &weights_cache_override {
        envs.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };

    // Start process
    tracing::info!("Starting check and download process for {model_id}");
    let mut download_process = match Command::new("text-generation-server")
        .args(download_args)
        .env_clear()
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-server not found in PATH");
                tracing::error!("Please install it with `make install-server`")
            } else {
                tracing::error!("{}", err);
            }

            return Err(LauncherError::DownloadError);
        }
    };

    let download_stdout = BufReader::new(download_process.stdout.take().unwrap());

    thread::spawn(move || {
        log_lines(download_stdout);
    });

    let download_stderr = BufReader::new(download_process.stderr.take().unwrap());

    // We read stderr in another thread as it seems that lines() can block in some cases
    let (err_sender, err_receiver) = mpsc::channel();
    thread::spawn(move || {
        for line in download_stderr.lines().map_while(Result::ok) {
            err_sender.send(line).unwrap_or(());
        }
    });

    loop {
        if let Some(status) = download_process.try_wait().unwrap() {
            if status.success() {
                tracing::info!("Successfully downloaded weights for {model_id}");
                break;
            }

            let mut err = String::new();
            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
                err = err + "\n" + &line;
            }

            if let Some(signal) = status.signal() {
                tracing::error!(
                    "Download process was signaled to shutdown with signal {signal}: {err}"
                );
            } else {
                tracing::error!("Download encountered an error: {err}");
            }

            return Err(LauncherError::DownloadError);
        }
        if !running.load(Ordering::SeqCst) {
            terminate("download", download_process, Duration::from_secs(10)).unwrap();
            return Ok(());
        }
        sleep(Duration::from_millis(100));
    }
    Ok(())
}