backends/src/lib.rs (505 lines of code) (raw):

mod dtype; use hf_hub::api::tokio::{ApiError, ApiRepo}; use rand::Rng; use std::cmp::{max, min}; use std::env; use std::path::PathBuf; use std::process::Command; use std::sync::Arc; use std::thread::JoinHandle; use std::time::{Duration, Instant}; use text_embeddings_backend_core::{Backend as CoreBackend, Predictions}; use tokio::sync::{mpsc, oneshot, watch}; use tracing::{instrument, Span}; pub use crate::dtype::DType; pub use text_embeddings_backend_core::{ BackendError, Batch, Embedding, Embeddings, ModelType, Pool, }; #[cfg(feature = "candle")] use text_embeddings_backend_candle::CandleBackend; #[cfg(feature = "ort")] use text_embeddings_backend_ort::OrtBackend; #[cfg(feature = "python")] use text_embeddings_backend_python::PythonBackend; fn powers_of_two(max_value: usize) -> Vec<usize> { let mut result = Vec::new(); let mut power: usize = 1; while power <= max_value { result.push(power); power *= 2; } result } fn generate_bucket_sizes(bucket_size: usize, max_s: usize, base_exp: usize) -> Vec<usize> { let mut sizes = Vec::new(); let mut current = bucket_size; while current <= max_s { sizes.push(current); match current.checked_mul(base_exp) { Some(next) => current = next, None => break, } } sizes } fn is_hpu() -> bool { match Command::new("hl-smi") .args(["-Q", "name", "-f", "csv"]) .output() { Ok(output) => output.status.success(), Err(_) => false, } } #[derive(Debug, Clone)] pub struct Backend { /// Channel to communicate with the background thread backend_sender: mpsc::Sender<BackendCommand>, /// Health status health_receiver: watch::Receiver<bool>, _backend_thread: Arc<BackendThread>, pub padded_model: bool, pub max_batch_size: Option<usize>, pub model_type: ModelType, } impl Backend { pub async fn new( model_path: PathBuf, api_repo: Option<ApiRepo>, dtype: DType, model_type: ModelType, uds_path: String, otlp_endpoint: Option<String>, otlp_service_name: String, ) -> Result<Self, BackendError> { let (backend_sender, backend_receiver) = mpsc::channel(8); let backend = init_backend( model_path, api_repo, dtype, model_type.clone(), uds_path, otlp_endpoint, otlp_service_name, ) .await?; let padded_model = backend.is_padded(); let max_batch_size = backend.max_batch_size(); let (health_sender, health_receiver) = watch::channel(false); let _backend_thread = Arc::new(BackendThread::new(backend, backend_receiver, health_sender)); Ok(Self { backend_sender, health_receiver, _backend_thread, padded_model, max_batch_size, model_type, }) } #[instrument(skip(self))] pub async fn warmup_hpu( &self, mut max_input_length: usize, max_token: usize, max_bs: Option<usize>, ) -> Result<(), BackendError> { let read_env_var = |key: &str, default: usize| -> usize { env::var(key) .ok() .map_or(default, |value| value.parse::<usize>().unwrap()) }; let seq_bucket_size: usize = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128); let max_warmup_length: usize = read_env_var("MAX_WARMUP_SEQUENCE_LENGTH", 1024); let seq_len_exp_base: usize = read_env_var("SEQ_LEN_EXPONENT_BASE", 2); let max_batch_size = max_bs.unwrap_or_else(|| read_env_var("MAX_WARMUP_BATCH_SIZE", 8)); let mut batch_sizes: Vec<usize> = powers_of_two(max_batch_size); if let Some(&last) = batch_sizes.last() { if last < max_batch_size { batch_sizes.push(max_batch_size); } } if max_warmup_length > max_input_length { return Err(BackendError::Start( format!("max_warmup_length ({max_warmup_length}) exceeds model's max_input_length ({max_input_length}), you can modify this value adding `-e MAX_WARMUP_SEQUENCE_LENGTH=<new_warmup_length>` to your Docker run command") )); } if seq_bucket_size > max_warmup_length { return Err(BackendError::Start( format!("PAD_SEQUENCE_TO_MULTIPLE_OF ({seq_bucket_size}) exceeds model's max warmup length ({max_warmup_length}), you can modify these values adding `-e PAD_SEQUENCE_TO_MULTIPLE_OF=<new_value>` or `-e MAX_WARMUP_SEQUENCE_LENGTH=<new_value> to your Docker run command`") )); } max_input_length = std::cmp::min(max_input_length, max_warmup_length); let mut seq_lengths: Vec<usize> = generate_bucket_sizes( seq_bucket_size, max_input_length, seq_len_exp_base, ); if let Some(&last) = seq_lengths.last() { if last < max_input_length { seq_lengths.push(max_input_length); } } let mut shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len() * seq_lengths.len()); for batch_size in &batch_sizes { for seq_length in &seq_lengths { shapes.push((*batch_size as u32, *seq_length as u32)); } } for shape in shapes.iter() { let batch = self.create_warmup_batch(*shape, max_token as u32); match &self.model_type { ModelType::Classifier => self.predict(batch).await.map(|_| ()), ModelType::Embedding(_) => self.embed(batch).await.map(|_| ()), }?; tracing::info!("finish warmup for batch: {}, length: {}", shape.0, shape.1); } Ok(()) } #[instrument(skip_all)] pub fn create_warmup_batch(&self, shape: (u32, u32), max_token: u32) -> Batch { let (batch_size, length) = shape; let mut batched_input_ids = Vec::new(); let mut batched_token_type_ids = Vec::new(); let mut batched_position_ids = Vec::new(); let mut cumulative_seq_lengths = Vec::with_capacity(batch_size as usize + 1); let mut pooled_indices = Vec::with_capacity(batch_size as usize); cumulative_seq_lengths.push(0); let input_ids: Vec<u32> = (0..length) .map(|_| rand::rng().random_range(0..max_token)) .collect(); let token_type_ids: Vec<u32> = vec![0; length as usize]; let position_ids: Vec<u32> = (0..length).collect(); let mut current_length = 0; for batch_id in 0..batch_size { batched_input_ids.extend(input_ids.iter().cloned()); batched_token_type_ids.extend(token_type_ids.iter().cloned()); batched_position_ids.extend(position_ids.iter().cloned()); current_length += input_ids.len(); cumulative_seq_lengths.push(current_length as u32); pooled_indices.push(batch_id); } Batch { input_ids: batched_input_ids, token_type_ids: batched_token_type_ids, position_ids: batched_position_ids, cumulative_seq_lengths, max_length: length, pooled_indices, raw_indices: vec![], } } #[instrument(skip(self))] pub async fn warmup( &self, max_input_length: usize, max_batch_tokens: usize, max_batch_requests: Option<usize>, ) -> Result<(), BackendError> { if is_hpu() { return self .warmup_hpu(max_input_length, max_batch_tokens, max_batch_requests) .await; } let mut input_ids = Vec::with_capacity(max_batch_tokens); let mut token_type_ids = Vec::with_capacity(max_batch_tokens); let mut position_ids = Vec::with_capacity(max_batch_tokens); let mut cumulative_seq_lengths = vec![0]; let mut pooled_indices = Vec::new(); let mut i = 0_u32; let mut remaining = max_batch_tokens; let mut cumulative_length = 0; let mut max_length = 0; while remaining > 0 { let request_length = min(remaining, max_input_length); cumulative_length += request_length; max_length = max(max_length, request_length as u32); input_ids.extend(vec![0; request_length]); token_type_ids.extend(vec![0; request_length]); position_ids.extend((0..request_length as u32).collect::<Vec<u32>>()); cumulative_seq_lengths.push(cumulative_length as u32); pooled_indices.push(i); i += 1; remaining = remaining.saturating_sub(max_input_length); if let Some(max_batch_requests) = &max_batch_requests { if i as usize == *max_batch_requests { break; } } } let batch = Batch { input_ids, token_type_ids, position_ids, cumulative_seq_lengths, max_length, pooled_indices, raw_indices: vec![], }; match &self.model_type { ModelType::Classifier => self.predict(batch).await.map(|_| ()), ModelType::Embedding(_) => self.embed(batch).await.map(|_| ()), } } #[instrument(skip(self))] pub async fn health(&self) -> Result<(), BackendError> { if *self.health_receiver.borrow() { // The backend is healthy. Only do a basic health check by calling the // the underlying health method. let (sender, receiver) = oneshot::channel(); self.backend_sender .send(BackendCommand::Health(Span::current(), sender)) .await .expect("No backend receiver. This is a bug."); receiver.await.expect( "Backend blocking task dropped the sender without sending a response. This is a bug.", ) } else { // The backend is un-healthy or only just started. Do a more advanced health check // by calling the model forward on a test batch let batch = Batch { input_ids: vec![0], token_type_ids: vec![0], position_ids: vec![0], cumulative_seq_lengths: vec![0, 1], max_length: 1, pooled_indices: vec![0], raw_indices: vec![], }; match &self.model_type { ModelType::Classifier => self.predict(batch).await.map(|_| ()), ModelType::Embedding(_) => self.embed(batch).await.map(|_| ()), } } } #[instrument(skip(self))] pub fn health_watcher(&self) -> watch::Receiver<bool> { self.health_receiver.clone() } #[instrument(skip_all)] pub async fn embed(&self, batch: Batch) -> Result<(Embeddings, Duration), BackendError> { let (sender, receiver) = oneshot::channel(); self.backend_sender .try_send(BackendCommand::Embed(batch, Span::current(), sender)) .expect("No backend receiver. This is a bug."); receiver.await.expect( "Backend blocking task dropped the sender without send a response. This is a bug.", ) } #[instrument(skip_all)] pub async fn predict(&self, batch: Batch) -> Result<(Predictions, Duration), BackendError> { let (sender, receiver) = oneshot::channel(); self.backend_sender .try_send(BackendCommand::Predict(batch, Span::current(), sender)) .expect("No backend receiver. This is a bug."); receiver.await.expect( "Backend blocking task dropped the sender without send a response. This is a bug.", ) } } #[allow(unused)] async fn init_backend( model_path: PathBuf, api_repo: Option<ApiRepo>, dtype: DType, model_type: ModelType, uds_path: String, otlp_endpoint: Option<String>, otlp_service_name: String, ) -> Result<Box<dyn CoreBackend + Send>, BackendError> { let mut backend_start_failed = false; if cfg!(feature = "ort") { #[cfg(feature = "ort")] { if let Some(api_repo) = api_repo.as_ref() { let start = std::time::Instant::now(); let model_files = download_onnx(api_repo) .await .map_err(|err| BackendError::WeightsNotFound(err.to_string()))?; match model_files.is_empty() { true => { tracing::error!("Model ONNX files not found in the repository. You can easily create ONNX files using the following scripts: https://gist.github.com/tomaarsen/4b00b0e3be8884efa64cfab9230b161f, or use this Space: https://huggingface.co/spaces/sentence-transformers/backend-export") } false => { tracing::info!("Model ONNX weights downloaded in {:?}", start.elapsed()) } } } let backend = OrtBackend::new(&model_path, dtype.to_string(), model_type.clone()); match backend { Ok(b) => return Ok(Box::new(b)), Err(err) => { tracing::error!("Could not start ORT backend: {err}"); backend_start_failed = true; } } } } if let Some(api_repo) = api_repo.as_ref() { if cfg!(feature = "python") || cfg!(feature = "candle") { let start = std::time::Instant::now(); if download_safetensors(api_repo).await.is_err() { tracing::warn!("safetensors weights not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower."); tracing::info!("Downloading `pytorch_model.bin`"); api_repo .get("pytorch_model.bin") .await .map_err(|err| BackendError::WeightsNotFound(err.to_string()))?; } tracing::info!("Model weights downloaded in {:?}", start.elapsed()); } } if cfg!(feature = "candle") { #[cfg(feature = "candle")] { let backend = CandleBackend::new(&model_path, dtype.to_string(), model_type.clone()); match backend { Ok(b) => return Ok(Box::new(b)), Err(err) => { tracing::error!("Could not start Candle backend: {err}"); backend_start_failed = true; } } } } if cfg!(feature = "python") { #[cfg(feature = "python")] { let backend = std::thread::spawn(move || { PythonBackend::new( model_path.to_str().unwrap().to_string(), dtype.to_string(), model_type, uds_path, otlp_endpoint, otlp_service_name, ) }) .join() .expect("Python Backend management thread failed"); match backend { Ok(b) => return Ok(Box::new(b)), Err(err) => { tracing::error!("Could not start Python backend: {err}"); backend_start_failed = true; } } } } if backend_start_failed { Err(BackendError::Start( "Could not start a suitable backend".to_string(), )) } else { Err(BackendError::NoBackend) } } #[derive(Debug)] struct BackendThread(Option<JoinHandle<()>>); impl BackendThread { fn new( backend: Box<dyn CoreBackend + Send>, mut backend_receiver: mpsc::Receiver<BackendCommand>, health_sender: watch::Sender<bool>, ) -> Self { let handle = std::thread::spawn(move || { while let Some(cmd) = backend_receiver.blocking_recv() { let start = Instant::now(); let mut healthy = false; match cmd { BackendCommand::Health(span, sender) => { let _span = span.entered(); let _ = sender.send(backend.health().map(|_| healthy = true)); } BackendCommand::Embed(batch, span, sender) => { let _span = span.entered(); let _ = sender.send(backend.embed(batch).map(|e| { healthy = true; (e, start.elapsed()) })); } BackendCommand::Predict(batch, span, sender) => { let _span = span.entered(); let _ = sender.send(backend.predict(batch).map(|e| { healthy = true; (e, start.elapsed()) })); } }; let _ = health_sender.send(healthy); } }); Self(Some(handle)) } } impl Drop for BackendThread { fn drop(&mut self) { self.0.take().unwrap().join().unwrap(); } } enum BackendCommand { Health(Span, oneshot::Sender<Result<(), BackendError>>), Embed( Batch, Span, oneshot::Sender<Result<(Embeddings, Duration), BackendError>>, ), Predict( Batch, Span, #[allow(clippy::type_complexity)] oneshot::Sender<Result<(Predictions, Duration), BackendError>>, ), } async fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> { // Single file tracing::info!("Downloading `model.safetensors`"); match api.get("model.safetensors").await { Ok(p) => return Ok(vec![p]), Err(err) => tracing::warn!("Could not download `model.safetensors`: {}", err), }; // Sharded weights // Download and parse index file tracing::info!("Downloading `model.safetensors.index.json`"); let index_file = api.get("model.safetensors.index.json").await?; let index_file_string: String = std::fs::read_to_string(index_file).expect("model.safetensors.index.json is corrupted"); let json: serde_json::Value = serde_json::from_str(&index_file_string) .expect("model.safetensors.index.json is corrupted"); let weight_map = match json.get("weight_map") { Some(serde_json::Value::Object(map)) => map, _ => panic!("model.safetensors.index.json is corrupted"), }; let mut safetensors_filenames = std::collections::HashSet::new(); for value in weight_map.values() { if let Some(file) = value.as_str() { safetensors_filenames.insert(file.to_string()); } } // Download weight files let mut safetensors_files = Vec::new(); for n in safetensors_filenames { tracing::info!("Downloading `{}`", n); safetensors_files.push(api.get(&n).await?); } Ok(safetensors_files) } #[cfg(feature = "ort")] async fn download_onnx(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> { let mut model_files: Vec<PathBuf> = Vec::new(); tracing::info!("Downloading `model.onnx`"); match api.get("model.onnx").await { Ok(p) => model_files.push(p), Err(err) => { tracing::warn!("Could not download `model.onnx`: {err}"); tracing::info!("Downloading `onnx/model.onnx`"); match api.get("onnx/model.onnx").await { Ok(p) => model_files.push(p.parent().unwrap().to_path_buf()), Err(err) => tracing::warn!("Could not download `onnx/model.onnx`: {err}"), }; } }; tracing::info!("Downloading `model.onnx_data`"); match api.get("model.onnx_data").await { Ok(p) => model_files.push(p), Err(err) => { tracing::warn!("Could not download `model.onnx_data`: {err}"); tracing::info!("Downloading `onnx/model.onnx_data`"); match api.get("onnx/model.onnx_data").await { Ok(p) => model_files.push(p.parent().unwrap().to_path_buf()), Err(err) => tracing::warn!("Could not download `onnx/model.onnx_data`: {err}"), } } } Ok(model_files) }