backends/candle/src/models/distilbert.rs (589 lines of code) (raw):

use crate::layers::{get_cublas_lt_wrapper, HiddenAct, LayerNorm, Linear}; use crate::models::Model; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{Embedding, VarBuilder}; use serde::Deserialize; use std::collections::HashMap; use text_embeddings_backend_core::{Batch, ModelType, Pool}; #[derive(Debug, Clone, PartialEq, Deserialize)] pub struct DistilBertConfig { pub vocab_size: usize, pub dim: usize, pub n_layers: usize, pub n_heads: usize, pub hidden_dim: usize, pub activation: HiddenAct, pub max_position_embeddings: usize, pub pad_token_id: usize, pub model_type: Option<String>, pub classifier_dropout: Option<f64>, pub id2label: Option<HashMap<String, String>>, } #[derive(Debug)] pub struct DistilBertEmbeddings { word_embeddings: Embedding, position_embeddings: Embedding, layer_norm: LayerNorm, span: tracing::Span, } impl DistilBertEmbeddings { pub fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result<Self> { Ok(Self { word_embeddings: Embedding::new( vb.pp("word_embeddings") .get((config.vocab_size, config.dim), "weight")?, config.dim, ), position_embeddings: Embedding::new( vb.pp("position_embeddings") .get((config.max_position_embeddings, config.dim), "weight")?, config.dim, ), layer_norm: LayerNorm::load(vb.pp("LayerNorm"), config.dim, 1e-12f32)?, span: tracing::span!(tracing::Level::TRACE, "embeddings"), }) } pub fn forward(&self, input_ids: &Tensor, position_ids: &Tensor) -> Result<Tensor> { let _enter = self.span.enter(); let input_embeddings = self.word_embeddings.forward(input_ids)?; let position_embeddings = self.position_embeddings.forward(position_ids)?; let embeddings = self .layer_norm .forward(&input_embeddings, Some(&position_embeddings))?; Ok(embeddings) } } #[derive(Debug)] struct DistilBertAttention { qkv_linear: Linear, dense: Linear, num_attention_heads: usize, attention_head_size: usize, softmax_scale: f64, span: tracing::Span, } impl DistilBertAttention { pub fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result<Self> { let attention_head_size = config.dim / config.n_heads; let all_head_size = config.n_heads * attention_head_size; let hidden_size = config.dim; let query_weight = vb.pp("q_lin").get((all_head_size, hidden_size), "weight")?; let query_bias = vb.pp("q_lin").get(all_head_size, "bias")?; let key_weight = vb.pp("k_lin").get((all_head_size, hidden_size), "weight")?; let key_bias = vb.pp("k_lin").get(all_head_size, "bias")?; let value_weight = vb.pp("v_lin").get((all_head_size, hidden_size), "weight")?; let value_bias = vb.pp("v_lin").get(all_head_size, "bias")?; let qkv_weight = Tensor::cat(&[&query_weight, &key_weight, &value_weight], 0)?; let qkv_bias = Tensor::cat(&[&query_bias, &key_bias, &value_bias], 0)?; let qkv_linear = Linear::new(qkv_weight, Some(qkv_bias), None); let dense_weight = vb.pp("out_lin").get((hidden_size, hidden_size), "weight")?; let dense_bias = vb.pp("out_lin").get(hidden_size, "bias")?; let dense = Linear::new(dense_weight, Some(dense_bias), None); let softmax_scale = 1. / (attention_head_size as f64).sqrt(); Ok(Self { qkv_linear, dense, num_attention_heads: config.n_heads, attention_head_size, softmax_scale, span: tracing::span!(tracing::Level::TRACE, "attention"), }) } fn forward(&self, hidden_states: &Tensor, attention_bias: Option<&Tensor>) -> Result<Tensor> { let _enter = self.span.enter(); let device = hidden_states.device(); let qkv = self.qkv_linear.forward(hidden_states)?; let mut new_qkv_shape = qkv.dims().to_vec(); new_qkv_shape.pop(); new_qkv_shape.push(self.num_attention_heads * 3); new_qkv_shape.push(self.attention_head_size); let qkv = qkv.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?; let qkv = qkv.chunk(3, 1)?; let query_layer = &qkv[0].contiguous()?; let key_layer = &qkv[1].contiguous()?; let value_layer = &qkv[2]; #[allow(unused_variables)] let context_layer = if let (Device::Cuda(_), Some(cublaslt)) = (device, get_cublas_lt_wrapper()) { #[cfg(feature = "cuda")] { // cuBLASLt batch matmul implementation requires inputs to be dims3 let (batch_size, _, seq_len, _) = key_layer.shape().dims4()?; let key_layer = key_layer.flatten(0, 1)?; let query_layer = query_layer.flatten(0, 1)?; let value_layer = value_layer.flatten(0, 1)?; let attention_bias = attention_bias.map(|mask| mask.flatten(0, 1)).transpose()?; // If attention_bias is set, we fuse the add by giving it as the output matrix // and setting beta to 1.0 let beta = match attention_bias.is_some() { true => Some(1.0), false => None, }; // Batch matrix multiplication // Fuse softmax scale and attention_bias add let attention_scores = cublaslt.batch_matmul( &key_layer, &query_layer, attention_bias.as_ref(), Some(self.softmax_scale as f32), beta, None, None, )?; let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?; let context_layer = cublaslt.batch_matmul( &value_layer.t()?.contiguous()?, &attention_probs, // We save one allocation Some(&query_layer), None, None, None, None, )?; // Reshape to dims4 context_layer.reshape(( batch_size, self.num_attention_heads, seq_len, self.attention_head_size, )) } #[cfg(not(feature = "cuda"))] { candle::bail!("`cuda` feature is not enabled") } } else { let attention_scores = query_layer.matmul(&key_layer.t()?)?; let mut attention_scores = (attention_scores * self.softmax_scale)?; if let Some(attention_bias) = attention_bias { attention_scores = attention_scores.add(attention_bias)?; } let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?; attention_probs.matmul(&value_layer.contiguous()?) }?; let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?; let hidden_states = self.dense.forward(&context_layer)?; Ok(hidden_states) } } #[derive(Debug)] pub struct DistilBertMLP { lin1: Linear, lin2: Linear, span: tracing::Span, } impl DistilBertMLP { pub fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result<Self> { let lin1_weight = vb .pp("lin1") .get((config.hidden_dim, config.dim), "weight")?; let lin1_bias = vb.pp("lin1").get(config.hidden_dim, "bias")?; let lin1 = Linear::new( lin1_weight, Some(lin1_bias), Some(config.activation.clone()), ); let lin2_weight = vb .pp("lin2") .get((config.dim, config.hidden_dim), "weight")?; let lin2_bias = vb.pp("lin2").get(config.dim, "bias")?; let lin2 = Linear::new(lin2_weight, Some(lin2_bias), None); Ok(Self { lin1, lin2, span: tracing::span!(tracing::Level::TRACE, "mlp"), }) } pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { let _enter = self.span.enter(); let hidden_states = self.lin1.forward(hidden_states)?; self.lin2.forward(&hidden_states) } } #[derive(Debug)] struct DistilBertBlock { attention: DistilBertAttention, mlp: DistilBertMLP, post_attention_layer_norm: LayerNorm, output_layer_norm: LayerNorm, span: tracing::Span, } impl DistilBertBlock { pub fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result<Self> { let attention = DistilBertAttention::load(vb.pp("attention"), config)?; let mlp = DistilBertMLP::load(vb.pp("ffn"), config)?; let post_attention_layer_norm = LayerNorm::load(vb.pp("sa_layer_norm"), config.dim, 1e-12f32)?; let output_layer_norm = LayerNorm::load(vb.pp("output_layer_norm"), config.dim, 1e-12f32)?; Ok(Self { attention, mlp, post_attention_layer_norm, output_layer_norm, span: tracing::span!(tracing::Level::TRACE, "layer"), }) } pub fn forward( &self, hidden_states: &Tensor, attention_bias: Option<&Tensor>, ) -> Result<Tensor> { let _enter = self.span.enter(); let attn_output = self.attention.forward(hidden_states, attention_bias)?; let hidden_states = self .post_attention_layer_norm .forward(hidden_states, Some(&attn_output))?; let mlp_out = self.mlp.forward(&hidden_states)?; self.output_layer_norm .forward(&hidden_states, Some(&mlp_out)) } } #[derive(Debug)] struct DistilBertEncoder { layers: Vec<DistilBertBlock>, span: tracing::Span, } impl DistilBertEncoder { pub fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result<Self> { let layers = (0..config.n_layers) .map(|index| DistilBertBlock::load(vb.pp(format!("layer.{index}")), config)) .collect::<Result<Vec<_>>>()?; let span = tracing::span!(tracing::Level::TRACE, "encoder"); Ok(DistilBertEncoder { layers, span }) } fn forward(&self, hidden_states: &Tensor, attention_bias: Option<&Tensor>) -> Result<Tensor> { let _enter = self.span.enter(); let mut hidden_states = hidden_states.clone(); // Use a loop rather than a fold as it's easier to modify when adding debug/... for layer in self.layers.iter() { hidden_states = layer.forward(&hidden_states, attention_bias)?; } Ok(hidden_states) } } pub trait ClassificationHead { fn forward(&self, hidden_states: &Tensor) -> Result<Tensor>; } pub struct DistilBertClassificationHead { pre_classifier: Linear, classifier: Linear, span: tracing::Span, } impl DistilBertClassificationHead { pub(crate) fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result<Self> { let n_classes = match &config.id2label { None => candle::bail!("`id2label` must be set for classifier models"), Some(id2label) => id2label.len(), }; let pre_classifier_weight = vb .pp("pre_classifier") .get((config.dim, config.dim), "weight")?; let pre_classifier_bias = vb.pp("pre_classifier").get(config.dim, "bias")?; let pre_classifier = Linear::new(pre_classifier_weight, Some(pre_classifier_bias), None); let classifier_weight = vb.pp("classifier").get((n_classes, config.dim), "weight")?; let classifier_bias = vb.pp("classifier").get(n_classes, "bias")?; let classifier = Linear::new(classifier_weight, Some(classifier_bias), None); Ok(Self { pre_classifier, classifier, span: tracing::span!(tracing::Level::TRACE, "classifier"), }) } } impl ClassificationHead for DistilBertClassificationHead { fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { let _enter = self.span.enter(); let hidden_states = hidden_states.unsqueeze(1)?; let hidden_states = self.pre_classifier.forward(&hidden_states)?; let hidden_states = hidden_states.relu()?; let hidden_states = self.classifier.forward(&hidden_states)?; let hidden_states = hidden_states.squeeze(1)?; Ok(hidden_states) } } #[derive(Debug)] pub struct DistilBertSpladeHead { vocab_transform: Linear, vocab_projector: Linear, vocab_layer_norm: LayerNorm, span: tracing::Span, } impl DistilBertSpladeHead { pub(crate) fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result<Self> { let vocab_transform_weight = vb .pp("vocab_transform") .get((config.dim, config.dim), "weight")?; let vocab_transform_bias = vb.pp("vocab_transform").get(config.dim, "bias")?; let vocab_transform = Linear::new( vocab_transform_weight, Some(vocab_transform_bias), Some(config.activation.clone()), ); // When `pytorch_model.bin` originally contains `vocab_projector.weight` but the tensor // content shares the memory with the content on `distilbert.embeddings.word_embeddings.weight`, // e.g. a subset of the original tensor, when converting the file from BIN to Safentensors // the latter tensor that shares the memory with the previous will be removed let vocab_projector_weight = if vb.contains_tensor("vocab_projector.weight") { vb.pp("vocab_projector") .get((config.vocab_size, config.dim), "weight")? } else { vb.pp("distilbert.embeddings.word_embeddings") .get((config.vocab_size, config.dim), "weight")? }; let vocab_projector_bias = vb.pp("vocab_projector").get(config.vocab_size, "bias")?; let vocab_projector = Linear::new( vocab_projector_weight, Some(vocab_projector_bias), Some(HiddenAct::Relu), ); let vocab_layer_norm = LayerNorm::load(vb.pp("vocab_layer_norm"), config.dim, 1e-12f32)?; Ok(Self { vocab_transform, vocab_projector, vocab_layer_norm, span: tracing::span!(tracing::Level::TRACE, "splade"), }) } pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { let _enter = self.span.enter(); let hidden_states = self.vocab_transform.forward(hidden_states)?; let hidden_states = self.vocab_layer_norm.forward(&hidden_states, None)?; let hidden_states = self.vocab_projector.forward(&hidden_states)?; (1.0 + hidden_states)?.log() } } pub struct DistilBertModel { embeddings: DistilBertEmbeddings, encoder: DistilBertEncoder, pool: Pool, classifier: Option<Box<dyn ClassificationHead + Send>>, splade: Option<DistilBertSpladeHead>, num_attention_heads: usize, device: Device, dtype: DType, span: tracing::Span, } impl DistilBertModel { pub fn load(vb: VarBuilder, config: &DistilBertConfig, model_type: ModelType) -> Result<Self> { let (pool, classifier) = match model_type { // Classifier models always use CLS pooling ModelType::Classifier => { let pool = Pool::Cls; let classifier: Box<dyn ClassificationHead + Send> = Box::new(DistilBertClassificationHead::load(vb.clone(), config)?); (pool, Some(classifier)) } ModelType::Embedding(pool) => { if pool == Pool::LastToken { candle::bail!("`last_token` is not supported for DistilBert"); } (pool, None) } }; let (embeddings, encoder) = match ( DistilBertEmbeddings::load(vb.pp("embeddings"), config), DistilBertEncoder::load(vb.pp("encoder"), config), ) { (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), (Err(err), _) | (_, Err(err)) => { if let (Ok(embeddings), Ok(encoder)) = ( DistilBertEmbeddings::load(vb.pp("distilbert.embeddings"), config), DistilBertEncoder::load(vb.pp("distilbert.transformer"), config), ) { (embeddings, encoder) } else if let (Ok(embeddings), Ok(encoder)) = ( DistilBertEmbeddings::load(vb.pp("embeddings"), config), DistilBertEncoder::load(vb.pp("transformer"), config), ) { (embeddings, encoder) } else { return Err(err); } } }; let splade = if pool == Pool::Splade { Some(DistilBertSpladeHead::load(vb.clone(), config)?) } else { None }; Ok(Self { embeddings, encoder, pool, classifier, splade, num_attention_heads: config.n_heads, device: vb.device().clone(), dtype: vb.dtype(), span: tracing::span!(tracing::Level::TRACE, "model"), }) } pub fn forward(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> { let _enter = self.span.enter(); let batch_size = batch.len(); let max_length = batch.max_length as usize; let shape = (batch_size, max_length); let (input_ids, position_ids, input_lengths, attention_bias, attention_mask) = if batch_size > 1 { // Prepare padded batch let elems = batch_size * max_length; let mut input_ids = Vec::with_capacity(elems); let mut position_ids = Vec::with_capacity(elems); let mut attention_mask = Vec::with_capacity(elems); let mut attention_bias = Vec::with_capacity(elems); let mut input_lengths = Vec::with_capacity(batch_size); // Bool to know if we need to use the attention mask let mut masking = false; for i in 0..batch_size { let start = batch.cumulative_seq_lengths[i] as usize; let end = batch.cumulative_seq_lengths[i + 1] as usize; let seq_length = (end - start) as u32; input_lengths.push(seq_length as f32); // Copy values for j in start..end { input_ids.push(batch.input_ids[j]); position_ids.push(batch.position_ids[j]); attention_mask.push(1.0_f32); attention_bias.push(0.0); } // Add padding if needed let padding = batch.max_length - seq_length; if padding > 0 { // Set bool to use attention mask masking = true; for _ in 0..padding { input_ids.push(0); position_ids.push(0); attention_mask.push(0.0_f32); attention_bias.push(f32::NEG_INFINITY); } } } let (attention_bias, attention_mask) = match masking { true => { // We only need the mask if we use mean pooling // For CLS pooling, the bias is enough let attention_mask = if self.pool == Pool::Mean { let attention_mask = Tensor::from_vec( attention_mask, (batch_size, max_length, 1), &self.device, )? .to_dtype(self.dtype)?; Some(attention_mask) } else { None }; let attention_bias = Tensor::from_vec( attention_bias, (batch_size, 1, 1, max_length), &self.device, )? .to_dtype(self.dtype)?; // Broadcast once instead of at every layer let attention_bias = attention_bias .broadcast_as(( batch_size, self.num_attention_heads, max_length, max_length, ))? .contiguous()?; (Some(attention_bias), attention_mask) } false => (None, None), }; ( input_ids, position_ids, input_lengths, attention_bias, attention_mask, ) } else { ( batch.input_ids, batch.position_ids, vec![batch.max_length as f32], None, None, ) }; // Create CPU tensors let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?; let position_ids = Tensor::from_vec(position_ids, shape, &self.device)?; let input_lengths = Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?; let embedding_output = self.embeddings.forward(&input_ids, &position_ids)?; let outputs = self .encoder .forward(&embedding_output, attention_bias.as_ref())?; let has_pooling_requests = !batch.pooled_indices.is_empty(); let has_raw_requests = !batch.raw_indices.is_empty(); let pooled_embeddings = if has_pooling_requests { let pooled_indices_length = batch.pooled_indices.len(); let mut outputs = outputs.clone(); // Only use pooled_indices if at least one member of the batch ask for raw embeddings let pooled_indices = if has_raw_requests { let pooled_indices = Tensor::from_vec(batch.pooled_indices, pooled_indices_length, &self.device)?; // Select values in the batch outputs = outputs.index_select(&pooled_indices, 0)?; Some(pooled_indices) } else { None }; let pooled_embeddings = match self.pool { // CLS pooling Pool::Cls => outputs.i((.., 0))?, // Last token pooling is not supported for this model Pool::LastToken => unreachable!(), // Mean pooling Pool::Mean => { if let Some(ref attention_mask) = attention_mask { let mut attention_mask = attention_mask.clone(); if let Some(pooled_indices) = pooled_indices { // Select values in the batch attention_mask = attention_mask.index_select(&pooled_indices, 0)?; }; // Mask padded values outputs = outputs.broadcast_mul(&attention_mask)?; } (outputs.sum(1)?.broadcast_div(&input_lengths))? } Pool::Splade => { // Unwrap is safe here let splade_head = self.splade.as_ref().unwrap(); let mut relu_log = splade_head.forward(&outputs)?; if let Some(ref attention_mask) = attention_mask { let mut attention_mask = attention_mask.clone(); if let Some(pooled_indices) = pooled_indices { // Select values in the batch attention_mask = attention_mask.index_select(&pooled_indices, 0)?; }; // Mask padded values relu_log = relu_log.broadcast_mul(&attention_mask)?; } relu_log.max(1)? } }; Some(pooled_embeddings) } else { None }; let raw_embeddings = if has_raw_requests { // Reshape outputs let (b, l, h) = outputs.shape().dims3()?; let outputs = outputs.reshape((b * l, h))?; // We need to remove the padding tokens only if batch_size > 1 and there are some // member of the batch that require pooling // or if batch_size > 1 and the members of the batch have different lengths if (attention_mask.is_some() || has_pooling_requests) && batch_size > 1 { let mut final_indices: Vec<u32> = Vec::with_capacity(batch_size * max_length); for i in batch.raw_indices.into_iter() { let start = i * batch.max_length; let i = i as usize; let length = batch.cumulative_seq_lengths[i + 1] - batch.cumulative_seq_lengths[i]; for j in start..start + length { // Add indices for the tokens of this specific member of the batch final_indices.push(j); } } let final_indices_length = final_indices.len(); let final_indices = Tensor::from_vec(final_indices, final_indices_length, &self.device)?; // Select the tokens with final indices Some(outputs.index_select(&final_indices, 0)?) } else { Some(outputs) } } else { None }; Ok((pooled_embeddings, raw_embeddings)) } } impl Model for DistilBertModel { fn is_padded(&self) -> bool { true } fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> { self.forward(batch) } fn predict(&self, batch: Batch) -> Result<Tensor> { match &self.classifier { None => candle::bail!("`predict` is not implemented for this model"), Some(classifier) => { let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?; let pooled_embeddings = pooled_embeddings.expect("pooled_embeddings is empty. This is a bug."); classifier.forward(&pooled_embeddings) } } } }