in crates/ratchet-models/src/whisper/mha.rs [86:127]
/// Runs multi-head attention over the supplied `q`, `k` and `v` tensors and passes
/// the result to the output layer `self.o`.
///
/// Every input must be rank 3. `q` is read as `[bs, n_ctx, n_state]`; the last axis
/// of each input is split into `self.n_heads` heads of `n_state / self.n_heads`
/// elements. Both `q` and `k` are multiplied by `self.dk` (presumably the attention
/// scale factor; confirm against the constructor).
///
/// * `mask` - optional tensor added to the `q · k` scores before the softmax.
/// * `is_causal` - when `true`, `mask` is sliced to `[0..n_ctx, 0..n_ctx]` before
///   being added; otherwise it is added unchanged.
///
/// # Errors
///
/// Propagates any failure from the shape conversions (e.g. an input whose shape
/// cannot be read as three dimensions) and from the underlying tensor operations.
fn qkv_attention(
    &self,
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Option<Tensor>,
    is_causal: bool,
) -> anyhow::Result<Tensor> {
    let [bs, n_ctx, n_state]: [usize; 3] = q.shape().try_into()?;
    let [k0, k1, _]: [usize; 3] = k.shape().try_into()?;
    let [v0, v1, _]: [usize; 3] = v.shape().try_into()?;

    // Remember the incoming dtype so the softmax output can be cast back to it.
    let q_dt = q.dt();

    // Per-head width; the same value is used for q, k and v.
    let hdim = n_state / self.n_heads;

    let qs = shape![bs, n_ctx, self.n_heads, hdim];
    let ks = shape![k0, k1, self.n_heads, hdim];
    let vs = shape![v0, v1, self.n_heads, hdim];

    // q, v -> [batch, heads, seq, hdim]; k -> [batch, heads, hdim, seq] so that the
    // matmul below pairs q's last axis with k's head axis.
    let q = q.view(qs)?.permute(&[0, 2, 1, 3])?.mul(self.dk.clone())?;
    let k = k.view(ks)?.permute(&[0, 2, 3, 1])?.mul(self.dk.clone())?;
    let v = v.view(vs)?.permute(&[0, 2, 1, 3])?;

    let mut qk = q.matmul(k, false, false)?;

    if let Some(m) = mask {
        // `m` is owned here and not used afterwards, so the non-causal branch moves
        // it rather than cloning.
        let prepared_mask = if is_causal {
            m.slice(&[0..n_ctx, 0..n_ctx])?
        } else {
            m
        };
        qk = qk.add(prepared_mask)?;
    }

    // NOTE(review): `full()` presumably upcasts the scores before the softmax; the
    // result is cast back to q's original dtype right after.
    qk = qk.full()?;
    let w = qk.softmax(3)?.cast(q_dt)?;

    // Merge the heads back: [batch, heads, seq, hdim] -> [batch, seq, n_state].
    let s = shape![bs, n_ctx, n_state];
    let wv = w.matmul(v, false, false)?.permute(&[0, 2, 1, 3])?.view(s)?;

    self.o.schedule(wv)
}