candle-transformers/src/models/pixtral/vision_model.rs (333 lines of code) (raw):

use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{linear_b, rms_norm, Linear, RmsNorm, VarBuilder}; fn default_act() -> candle_nn::Activation { candle_nn::Activation::Silu } fn default_hidden_size() -> usize { 1024 } fn default_intermediate_size() -> usize { 4096 } fn default_num_channels() -> usize { 3 } fn default_num_hidden_layers() -> usize { 24 } fn default_num_attention_heads() -> usize { 16 } #[derive(serde::Deserialize, Debug, Clone)] pub struct Config { #[serde(default = "default_hidden_size")] pub hidden_size: usize, #[serde(default = "default_num_channels")] pub num_channels: usize, pub image_size: usize, pub patch_size: usize, pub rope_theta: f64, #[serde(default = "default_intermediate_size")] pub intermediate_size: usize, #[serde(default = "default_num_hidden_layers")] pub num_hidden_layers: usize, pub head_dim: Option<usize>, #[serde(default = "default_num_attention_heads")] pub num_attention_heads: usize, #[serde(default = "default_act")] pub hidden_act: candle_nn::Activation, } impl Config { pub fn pixtral_12b_2409() -> Self { Self { hidden_size: 1024, num_channels: 3, image_size: 1024, patch_size: 16, rope_theta: 10000.0, intermediate_size: 4096, num_hidden_layers: 24, num_attention_heads: 16, head_dim: None, // Default hidden_act: candle_nn::Activation::Silu, } } fn head_dim(&self) -> usize { self.head_dim .unwrap_or(self.hidden_size / self.num_attention_heads) } } #[derive(Debug, Clone)] struct Attention { q_proj: Linear, k_proj: Linear, v_proj: Linear, o_proj: Linear, scale: f64, num_heads: usize, head_dim: usize, } impl Attention { fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { let h = cfg.hidden_size; let num_heads = cfg.num_attention_heads; let head_dim = cfg.head_dim(); let q_proj = linear_b(h, h, false, vb.pp("q_proj"))?; let k_proj = linear_b(h, h, false, vb.pp("k_proj"))?; let v_proj = linear_b(h, h, false, vb.pp("v_proj"))?; let o_proj = linear_b(h, h, false, vb.pp("o_proj"))?; let scale = (head_dim as f64).powf(-0.5); Ok(Self { q_proj, k_proj, v_proj, o_proj, scale, num_heads, head_dim, }) } fn forward( &self, xs: &Tensor, emb: &RotaryEmbedding, subsampled_positions: Option<&Tensor>, attention_mask: Option<&Tensor>, ) -> Result<Tensor> { let (b, patches, _) = xs.dims3()?; let query_states = xs.apply(&self.q_proj)?; let key_states = xs.apply(&self.k_proj)?; let value_states = xs.apply(&self.v_proj)?; let shape = (b, patches, self.num_heads, self.head_dim); let query_states = query_states.reshape(shape)?.transpose(1, 2)?.contiguous()?; let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?; let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?; let (query_states, key_states) = emb.apply_rotary_emb_qkv(&query_states, &key_states, subsampled_positions)?; let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?; let attn_weights = match attention_mask { None => attn_weights, Some(mask) => attn_weights.broadcast_add(mask)?, }; let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; attn_weights .matmul(&value_states)? .transpose(1, 2)? .reshape((b, patches, ()))? .apply(&self.o_proj) } } #[derive(Debug, Clone)] struct Mlp { gate_proj: Linear, up_proj: Linear, down_proj: Linear, act_fn: candle_nn::Activation, } impl Mlp { fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { let (h, i) = (cfg.hidden_size, cfg.intermediate_size); let gate_proj = linear_b(h, i, false, vb.pp("gate_proj"))?; let up_proj = linear_b(h, i, false, vb.pp("up_proj"))?; let down_proj = linear_b(i, h, false, vb.pp("down_proj"))?; Ok(Self { gate_proj, up_proj, down_proj, act_fn: cfg.hidden_act, }) } } impl Module for Mlp { fn forward(&self, xs: &Tensor) -> Result<Tensor> { (xs.apply(&self.gate_proj)?.apply(&self.act_fn)? * xs.apply(&self.up_proj))? .apply(&self.down_proj) } } #[derive(Debug, Clone)] struct AttentionLayer { attention_norm: RmsNorm, feed_forward: Mlp, attention: Attention, ffn_norm: RmsNorm, } impl AttentionLayer { fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { let attention_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("attention_norm"))?; let feed_forward = Mlp::new(cfg, vb.pp("feed_forward"))?; let attention = Attention::new(cfg, vb.pp("attention"))?; let ffn_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("ffn_norm"))?; Ok(Self { attention_norm, feed_forward, attention, ffn_norm, }) } fn forward( &self, xs: &Tensor, emb: &RotaryEmbedding, subsampled_positions: Option<&Tensor>, attention_mask: Option<&Tensor>, ) -> Result<Tensor> { let residual = xs; let xs = self.attention.forward( &xs.apply(&self.attention_norm)?, emb, subsampled_positions, attention_mask, )?; let xs = (residual + xs)?; let residual = &xs; let xs = xs.apply(&self.ffn_norm)?.apply(&self.feed_forward)?; xs + residual } } #[derive(Debug, Clone)] struct Transformer { layers: Vec<AttentionLayer>, } impl Transformer { fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { let mut layers = Vec::with_capacity(cfg.num_hidden_layers); let vb = vb.pp("layers"); for layer_idx in 0..cfg.num_hidden_layers { let layer = AttentionLayer::new(cfg, vb.pp(layer_idx))?; layers.push(layer) } Ok(Self { layers }) } fn forward( &self, xs: &Tensor, emb: &RotaryEmbedding, subsampled_positions: Option<&Tensor>, attention_mask: Option<&Tensor>, ) -> Result<Tensor> { let mut xs = xs.clone(); for layer in self.layers.iter() { xs = layer.forward(&xs, emb, subsampled_positions, attention_mask)? } Ok(xs) } } #[derive(Debug, Clone)] struct RotaryEmbedding { cos: Tensor, sin: Tensor, } impl RotaryEmbedding { fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { let dtype = vb.dtype(); let dev = vb.device(); let dim = cfg.head_dim(); let rope_theta = cfg.rope_theta as f32; let max_patches_per_side = cfg.image_size / cfg.patch_size; let freqs: Vec<_> = (0..dim) .step_by(2) .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32)) .collect(); let freqs_h = freqs.iter().step_by(2).copied().collect::<Vec<_>>(); let freqs_h = Tensor::new(freqs_h, dev)?; let freqs_w = freqs.iter().skip(1).step_by(2).copied().collect::<Vec<_>>(); let freqs_w = Tensor::new(freqs_w, dev)?; let h = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?; let w = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?; let freqs_h = h.unsqueeze(1)?.matmul(&freqs_h.unsqueeze(0)?)?; let freqs_w = w.unsqueeze(1)?.matmul(&freqs_w.unsqueeze(0)?)?; let inv_freq = Tensor::cat( &[ freqs_h.unsqueeze(1)?.repeat((1, max_patches_per_side, 1))?, freqs_w.unsqueeze(0)?.repeat((max_patches_per_side, 1, 1))?, ], D::Minus1, )? .reshape(((), dim / 2))?; let cos = inv_freq.cos()?.to_dtype(dtype)?; let sin = inv_freq.sin()?.to_dtype(dtype)?; Ok(Self { cos, sin }) } fn apply_rotary_emb_qkv( &self, q: &Tensor, k: &Tensor, subsampled_positions: Option<&Tensor>, ) -> Result<(Tensor, Tensor)> { let (_b_sz, _h, _seq_len, _n_embd) = q.dims4()?; let (cos, sin) = match subsampled_positions { None => (&self.cos, &self.sin), Some(pos) => ( &self.cos.index_select(pos, 0)?, &self.sin.index_select(pos, 0)?, ), }; let q_embed = candle_nn::rotary_emb::rope(q, cos, sin)?; let k_embed = candle_nn::rotary_emb::rope(k, cos, sin)?; Ok((q_embed, k_embed)) } } #[derive(Debug, Clone)] pub struct Model { patch_conv: candle_nn::Conv2d, ln_pre: RmsNorm, transformer: Transformer, patch_positional_embedding: RotaryEmbedding, max_image_width: u32, } impl Model { pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { let conv2d_cfg = candle_nn::Conv2dConfig { stride: cfg.patch_size, ..Default::default() }; let patch_conv = candle_nn::conv2d_no_bias( cfg.num_channels, cfg.hidden_size, cfg.patch_size, conv2d_cfg, vb.pp("patch_conv"), )?; let ln_pre = candle_nn::rms_norm(cfg.hidden_size, 1e-5, vb.pp("ln_pre"))?; let transformer = Transformer::new(cfg, vb.pp("transformer"))?; let patch_positional_embedding = RotaryEmbedding::new(cfg, vb.pp("patch_positional_embedding"))?; let max_image_width = (cfg.image_size / cfg.patch_size) as u32; Ok(Self { patch_conv, ln_pre, transformer, patch_positional_embedding, max_image_width, }) } pub fn position_ids_in_meshgrid( &self, num_patches_h: usize, num_patches_w: usize, device: &Device, ) -> Result<Tensor> { let idx = Tensor::arange(0, num_patches_h as u32, device)?; let idy = Tensor::arange(0, num_patches_w as u32, device)?; let mesh = Tensor::meshgrid(&[idx, idy], false)?; let ids = (&mesh[0] * (self.max_image_width as f64) + &mesh[1])?.flatten_all()?; Ok(ids) } } impl Module for Model { fn forward(&self, xs: &Tensor) -> Result<Tensor> { let patch_embeds = xs.apply(&self.patch_conv)?; let subsampled_positions = Some(self.position_ids_in_meshgrid( patch_embeds.dim(2)?, patch_embeds.dim(3)?, patch_embeds.device(), )?); let patch_embeds = patch_embeds.flatten_from(2)?.t()?.apply(&self.ln_pre)?; self.transformer.forward( &patch_embeds, &self.patch_positional_embedding, subsampled_positions.as_ref(), None, ) } }