backends/candle/src/models/flash_mistral.rs (358 lines of code) (raw):

use crate::flash_attn::flash_attn_varlen; use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm}; use crate::models::{MistralConfig, Model}; use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{Embedding, Module, VarBuilder}; use candle_rotary::apply_rotary_inplace; use text_embeddings_backend_core::{Batch, ModelType, Pool}; struct MistralAttention { qkv_linear: Linear, o_proj: Linear, window_size_left: Option<usize>, num_attention_heads: usize, num_key_value_heads: usize, attention_head_size: usize, softmax_scale: f32, span: tracing::Span, } impl MistralAttention { pub fn load(vb: VarBuilder, config: &MistralConfig) -> Result<Self> { let window_size_left = config.sliding_window; let num_attention_heads = config.num_attention_heads; let attention_head_size = config.hidden_size / config.num_attention_heads; let num_key_value_heads = config.num_key_value_heads; let hidden_size = config.hidden_size; let query_weight = vb.pp("q_proj").get((hidden_size, hidden_size), "weight")?; let key_weight = vb.pp("k_proj").get( (num_key_value_heads * attention_head_size, hidden_size), "weight", )?; let value_weight = vb.pp("v_proj").get( (num_key_value_heads * attention_head_size, hidden_size), "weight", )?; let qkv_weight = Tensor::cat(&[&query_weight, &key_weight, &value_weight], 0)?; let qkv_linear = Linear::new(qkv_weight, None, None); let o_proj_weight = vb.pp("o_proj").get((hidden_size, hidden_size), "weight")?; let o_proj = Linear::new(o_proj_weight, None, None); let softmax_scale = (1. / (attention_head_size as f64).sqrt()) as f32; Ok(Self { qkv_linear, o_proj, window_size_left, num_attention_heads, num_key_value_heads, attention_head_size, softmax_scale, span: tracing::span!(tracing::Level::TRACE, "attention"), }) } pub fn forward( &self, hidden_states: &Tensor, cu_seqlens: &Tensor, cos: &Tensor, sin: &Tensor, max_s: usize, ) -> Result<Tensor> { let _enter = self.span.enter(); let qkv = self.qkv_linear.forward(hidden_states)?; // Reshape to [tokens, heads, head_size] let mut new_qkv_shape = qkv.dims().to_vec(); new_qkv_shape.pop(); new_qkv_shape.push(self.num_attention_heads + 2 * self.num_key_value_heads); new_qkv_shape.push(self.attention_head_size); let qkv = qkv.reshape(new_qkv_shape)?; // Split qkv tensor let q = qkv.narrow(1, 0, self.num_attention_heads)?; let k = qkv.narrow(1, self.num_attention_heads, self.num_key_value_heads)?; let v = qkv.narrow( 1, self.num_attention_heads + self.num_key_value_heads, self.num_key_value_heads, )?; apply_rotary_inplace(&q, &k, &cos, &sin, true)?; let attention = flash_attn_varlen( &q, &k, &v, None, cu_seqlens, cu_seqlens, max_s, max_s, self.softmax_scale, true, self.window_size_left, None, )?; let attention = attention.flatten_from(candle::D::Minus2)?; self.o_proj.forward(&attention) } } struct MistralMLP { gate_up_proj: Linear, down_proj: Linear, act: HiddenAct, intermediate_size: usize, span: tracing::Span, } impl MistralMLP { pub fn load(vb: VarBuilder, config: &MistralConfig) -> Result<Self> { let intermediate_size = config.intermediate_size; let gate_proj_weight = vb .pp("gate_proj") .get((intermediate_size, config.hidden_size), "weight")?; let up_proj_weight = vb .pp("up_proj") .get((intermediate_size, config.hidden_size), "weight")?; let gate_up_proj_weight = Tensor::cat(&[&gate_proj_weight, &up_proj_weight], 0)?; let gate_up_proj = Linear::new(gate_up_proj_weight, None, None); let down_proj_weight = vb .pp("down_proj") .get((config.hidden_size, intermediate_size), "weight")?; let down_proj = Linear::new(down_proj_weight, None, None); Ok(Self { gate_up_proj, down_proj, intermediate_size, act: config.hidden_act.clone(), span: tracing::span!(tracing::Level::TRACE, "mlp"), }) } pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { let _enter = self.span.enter(); let gate_up_states = self.gate_up_proj.forward(hidden_states)?; let gate_states = gate_up_states.narrow(1, 0, self.intermediate_size)?; let up_states = gate_up_states.narrow(1, self.intermediate_size, self.intermediate_size)?; let gate_states = self.act.forward(&gate_states)?; let r = self.down_proj.forward(&(gate_states * up_states)?); r } } struct MistralLayer { attention: MistralAttention, mlp: MistralMLP, input_layer_norm: RMSNorm, post_attention_layer_norm: RMSNorm, span: tracing::Span, } impl MistralLayer { pub fn load(vb: VarBuilder, config: &MistralConfig) -> Result<Self> { let attention = MistralAttention::load(vb.pp("self_attn"), config)?; let mlp = MistralMLP::load(vb.pp("mlp"), config)?; let input_layer_norm = RMSNorm::load( vb.pp("input_layernorm"), config.hidden_size, config.rms_norm_eps, )?; let post_attention_layer_norm = RMSNorm::load( vb.pp("post_attention_layernorm"), config.hidden_size, config.rms_norm_eps, )?; Ok(Self { attention, mlp, input_layer_norm, post_attention_layer_norm, span: tracing::span!(tracing::Level::TRACE, "layer"), }) } pub fn forward( &self, hidden_states: &Tensor, residual: Option<&Tensor>, cu_seqlens: &Tensor, cos: &Tensor, sin: &Tensor, max_s: usize, ) -> Result<(Tensor, Tensor)> { let _enter = self.span.enter(); let (normed_hidden_states, res) = self.input_layer_norm.forward(hidden_states, residual)?; let attn_output = self.attention .forward(&normed_hidden_states, cu_seqlens, cos, sin, max_s)?; let (normed_attn_res_output, attn_res) = self .post_attention_layer_norm .forward(&attn_output, Some(&res))?; let mlp_output = self.mlp.forward(&normed_attn_res_output)?; Ok((mlp_output, attn_res)) } } pub struct FlashMistralModel { embeddings: Embedding, layers: Vec<MistralLayer>, norm: RMSNorm, cos_cache: Tensor, sin_cache: Tensor, pool: Pool, pub device: Device, span: tracing::Span, } impl FlashMistralModel { pub fn load(vb: VarBuilder, config: &MistralConfig, model_type: ModelType) -> Result<Self> { match vb.device() { Device::Cuda(_) => {} _ => candle::bail!("FlashMistral requires Cuda"), } if vb.dtype() != DType::F16 { candle::bail!("FlashMistral requires DType::F16") } let pool = match model_type { ModelType::Classifier => { candle::bail!("`classifier` model type is not supported for Mistral") } ModelType::Embedding(pool) => pool, }; let embeddings = Embedding::new( vb.pp("embed_tokens") .get((config.vocab_size, config.hidden_size), "weight")?, config.hidden_size, ); let layers = (0..config.num_hidden_layers) .map(|index| MistralLayer::load(vb.pp(format!("layers.{index}")), config)) .collect::<Result<Vec<_>>>()?; let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?; let inv_freqs = get_inv_freqs( layers[0].attention.attention_head_size, config.rope_theta, vb.device(), None, )?; let (cos_cache, sin_cache) = get_cos_sin( config.max_position_embeddings, &inv_freqs, vb.dtype(), false, )?; Ok(Self { embeddings, layers, norm, cos_cache, sin_cache, pool, device: vb.device().clone(), span: tracing::span!(tracing::Level::TRACE, "model"), }) } pub fn forward(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> { let _enter = self.span.enter(); let batch_size = batch.cumulative_seq_lengths.len() - 1; let shape = batch.input_ids.len(); // Create Cuda tensors let input_ids = Tensor::from_vec(batch.input_ids, shape, &self.device)?; let position_ids = Tensor::from_vec(batch.position_ids, shape, &self.device)?; let cu_seqlens = Tensor::from_vec( batch.cumulative_seq_lengths.clone(), batch_size + 1, &self.device, )?; let mut hidden_states = self.embeddings.forward(&input_ids)?; let cos = self.cos_cache.index_select(&position_ids, 0)?; let sin = self.sin_cache.index_select(&position_ids, 0)?; let mut residual = None; for layer in &self.layers { let (h, r) = layer.forward( &hidden_states, residual.as_ref(), &cu_seqlens, &cos, &sin, batch.max_length as usize, )?; hidden_states = h; residual = Some(r); } let (outputs, _) = self.norm.forward(&hidden_states, residual.as_ref())?; let has_pooling_requests = !batch.pooled_indices.is_empty(); let has_raw_requests = !batch.raw_indices.is_empty(); let pooled_embeddings = if has_pooling_requests { match self.pool { // CLS and LastToken pooling Pool::Cls | Pool::LastToken => { if batch_size > 1 { // Get token indices form cu_seqlens let mut indices = match self.pool { Pool::Cls => cu_seqlens.narrow(0, 0, batch_size)?, Pool::LastToken => { let end = cu_seqlens.narrow(0, 1, batch_size)?; (&end - &end.ones_like()?)? } _ => unreachable!(), }; // If raw_indices is empty, we don't need to do anything with // the pooled_indices if has_raw_requests { // We need the pooled indices to select the correct cls indices let pooled_indices = Tensor::from_vec( batch.pooled_indices.clone(), batch.pooled_indices.len(), &self.device, )?; // Only select indices that requires pooling indices = indices.index_select(&pooled_indices, 0)? } // Select tokens Some(outputs.index_select(&indices, 0)?) } else { Some( match self.pool { Pool::Cls => outputs.i(0)?, Pool::LastToken => { outputs.i(batch.cumulative_seq_lengths[1] as usize - 1)? } _ => unreachable!(), } .unsqueeze(0)?, ) } } // Mean pooling Pool::Mean => { if batch_size > 1 { // for each request that requires pooling let results: Result<Vec<Tensor>> = batch .pooled_indices .into_iter() .map(|i| { let i = i as usize; let start = batch.cumulative_seq_lengths[i]; let len = batch.cumulative_seq_lengths[i + 1] - start; // Mean let embeddings = outputs.narrow(0, start as usize, len as usize)?; embeddings.sum_keepdim(0)? / (len as f64) }) .collect(); // Concatenate all results Some(Tensor::cat(&results?, 0)?) } else { Some((outputs.sum_keepdim(0)? / (batch.max_length as f64))?) } } Pool::Splade => { unreachable!(); } } } else { None }; let raw_embeddings = if has_raw_requests { if batch_size > 1 && has_pooling_requests { // Create indexing vector for the embeddings let mut final_indices: Vec<u32> = Vec::with_capacity(shape); for i in batch.raw_indices.into_iter() { let i = i as usize; // Get start/end token index of this specific member of the batch let start = batch.cumulative_seq_lengths[i]; let end = batch.cumulative_seq_lengths[i + 1]; for j in start..end { // Add indices for the tokens of this specific member of the batch final_indices.push(j); } } let final_indices_length = final_indices.len(); let final_indices = Tensor::from_vec(final_indices, final_indices_length, &self.device)?; // Select the tokens with final indices Some(outputs.index_select(&final_indices, 0)?) } else { Some(outputs) } } else { None }; Ok((pooled_embeddings, raw_embeddings)) } } impl Model for FlashMistralModel { fn is_padded(&self) -> bool { false } fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> { self.forward(batch) } }