crates/ratchet-models/src/whisper/mha.rs

use half::f16;
use num::traits::real::Real;
use ratchet::{rvec, shape, Tensor};
use ratchet_nn::{KVEntry, Linear, Module};

/// Multi-head attention block used by both the Whisper encoder and decoder.
#[derive(Debug)]
pub struct MultiHeadAttention {
    q: Linear,
    k: Linear,
    v: Linear,
    o: Linear,
    n_heads: usize,
    /// Precomputed scale `(n_state / n_heads)^(-1/4)`, stored as a single-element
    /// tensor in the activation dtype. It is applied to both `q` and `k`, so the
    /// attention logits end up scaled by `1 / sqrt(head_dim)`.
    dk: Tensor,
}

impl MultiHeadAttention {
    pub fn new(q: Linear, k: Linear, v: Linear, o: Linear, n_heads: usize) -> MultiHeadAttention {
        // The model dimension is read off the q projection weight.
        let n_state = q.w.shape()[1];
        let dk = match q.w.dt().activation_dt() {
            ratchet::DType::F16 => {
                let dk = f16::from_f32((n_state / n_heads) as f32);
                Tensor::from_data(
                    [dk.powf(f16::from_f32(-0.25))],
                    shape![1],
                    q.w.device().clone(),
                )
            }
            ratchet::DType::F32 => {
                let dk = (n_state / n_heads) as f32;
                Tensor::from_data([dk.powf(-0.25)], shape![1], q.w.device().clone())
            }
            _ => unimplemented!(),
        };

        MultiHeadAttention {
            q,
            k,
            v,
            o,
            n_heads,
            dk,
        }
    }
}

#[derive(Debug, derive_new::new)]
pub struct MHAInputs {
    /// Input sequence, shape `[bs, n_ctx, n_state]`.
    x: Tensor,
    /// Cross-attention source (encoder output). `None` means self-attention.
    xa: Option<Tensor>,
    /// Optional additive attention mask.
    mask: Option<Tensor>,
    /// KV cache entry used during incremental decoding.
    cache: Option<KVEntry>,
    is_causal: bool,
}

impl Module for MultiHeadAttention {
    type Input = MHAInputs;

    fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor> {
        let MHAInputs {
            x,
            xa,
            mask,
            cache,
            is_causal,
        } = input;

        // Queries always come from `x`; keys and values come from the
        // cross-attention source `xa` when present, otherwise from `x`.
        let q = self.q.schedule(x.clone())?;

        let to_project = xa.unwrap_or(x);
        let k = self.k.schedule(to_project.clone())?;
        let v = self.v.schedule(to_project)?;

        // Append the new K/V along the sequence axis and attend over the cached tensors.
        let (k, v) = if let Some(kv) = cache {
            let prev_entries = kv.entries;
            let k_cache = kv.k_cache.cache(k, 1, prev_entries)?;
            let v_cache = kv.v_cache.cache(v, 1, prev_entries)?;
            (k_cache, v_cache)
        } else {
            (k, v)
        };

        self.qkv_attention(q, k, v, mask, is_causal)
    }
}

impl MultiHeadAttention {
    fn qkv_attention(
        &self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        mask: Option<Tensor>,
        is_causal: bool,
    ) -> anyhow::Result<Tensor> {
        let [bs, n_ctx, n_state]: [usize; 3] = q.shape().try_into()?;
        let [k0, k1, _]: [usize; 3] = k.shape().try_into()?;
        let [v0, v1, _]: [usize; 3] = v.shape().try_into()?;

        let q_dt = q.dt();

        // Split the model dimension into heads: [bs, seq, n_state] -> [bs, seq, n_heads, hdim].
        let hdim = n_state / self.n_heads;
        let qs = shape![bs, n_ctx, self.n_heads, hdim];
        let ks = shape![k0, k1, self.n_heads, hdim];
        let vs = shape![v0, v1, self.n_heads, hdim];

        // q -> [bs, n_heads, n_ctx, hdim], k -> [bs, n_heads, hdim, k_len] (pre-transposed
        // for the matmul), v -> [bs, n_heads, k_len, hdim]. Both q and k are scaled by dk.
        let q = q.view(qs)?.permute(&[0, 2, 1, 3])?.mul(self.dk.clone())?;
        let k = k.view(ks)?.permute(&[0, 2, 3, 1])?.mul(self.dk.clone())?;
        let v = v.view(vs)?.permute(&[0, 2, 1, 3])?;

        // Attention logits: [bs, n_heads, n_ctx, k_len].
        let mut qk = q.matmul(k, false, false)?;

        if let Some(m) = mask {
            // For causal self-attention, crop the precomputed mask to the current context length.
            let prepared_mask = if is_causal {
                m.slice(&[0..n_ctx, 0..n_ctx])?
            } else {
                m.clone()
            };
            qk = qk.add(prepared_mask)?;
        }
        // Upcast the logits, take the softmax over the key dimension, then cast
        // back to the activation dtype.
        qk = qk.full()?;

        let w = qk.softmax(3)?.cast(q_dt)?;
        // Weighted sum of values, merge the heads back to [bs, n_ctx, n_state],
        // then apply the output projection.
        let s = shape![bs, n_ctx, n_state];
        let wv = w.matmul(v, false, false)?.permute(&[0, 2, 1, 3])?.view(s)?;
        self.o.schedule(wv)
    }
}
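
// A minimal sketch, not part of the original file: the scale stored in `dk` is
// (n_state / n_heads)^(-1/4). Because it is applied to BOTH q and k above, the
// product q·kᵀ ends up scaled by 1/sqrt(head_dim), i.e. standard scaled
// dot-product attention. The test below only checks that arithmetic; the
// `n_state`/`n_heads` values are illustrative, not read from a real checkpoint.
#[cfg(test)]
mod dk_scale_sketch {
    use half::f16;

    #[test]
    fn dk_scale_matches_inv_sqrt_head_dim() {
        // Hypothetical dimensions for illustration only.
        let (n_state, n_heads) = (384usize, 6usize);
        let hdim = (n_state / n_heads) as f32;

        // Scale used on the F32 path: hdim^(-0.25).
        let scale = hdim.powf(-0.25);
        // Applying it to q and k independently squares it on the logits.
        assert!((scale * scale - 1.0 / hdim.sqrt()).abs() < 1e-6);

        // The F16 path computes the same quantity starting from a half-precision value.
        let scale_from_f16 = f16::from_f32(hdim).to_f32().powf(-0.25);
        assert!((scale_from_f16 - scale).abs() < 1e-3);
    }
}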