fn qkv_attention()

in crates/ratchet-models/src/whisper/mha.rs [86:127]

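The scaled-dot-product core of Whisper's multi-head attention: Q, K and V are reshaped into per-head layouts, scaled, multiplied into QK^T, optionally masked (slicing the causal mask to the current context), softmaxed in full precision, and the attention-weighted values are merged back into n_state and passed through the output projection.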

    fn qkv_attention(
        &self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        mask: Option<Tensor>,
        is_causal: bool,
    ) -> anyhow::Result<Tensor> {
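        // Q is [batch, n_ctx, n_state]; K/V may carry a different context length (cross-attention).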
        let [bs, n_ctx, n_state]: [usize; 3] = q.shape().try_into()?;
        let [k0, k1, _]: [usize; 3] = k.shape().try_into()?;
        let [v0, v1, _]: [usize; 3] = v.shape().try_into()?;
        let q_dt = q.dt();

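        // Per-head dimension: n_state split evenly across the heads.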
        let hdim = n_state / self.n_heads;

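        // Target shapes that split the hidden dim into [n_heads, hdim].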
        let qs = shape![bs, n_ctx, self.n_heads, hdim];
        let ks = shape![k0, k1, self.n_heads, hdim];
        let vs = shape![v0, v1, self.n_heads, hdim];

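        // Bring heads forward: Q/V become [bs, n_heads, ctx, hdim], while K is laid out
        // as [bs, n_heads, hdim, ctx] so the matmul below yields QK^T. self.dk scales both
        // Q and K (hdim^-0.25 each in Whisper, so the product carries the full 1/sqrt(hdim)).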
        let q = q.view(qs)?.permute(&[0, 2, 1, 3])?.mul(self.dk.clone())?;
        let k = k.view(ks)?.permute(&[0, 2, 3, 1])?.mul(self.dk.clone())?;
        let v = v.view(vs)?.permute(&[0, 2, 1, 3])?;

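        // Attention scores per head: [bs, n_heads, n_ctx, kv_ctx].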
        let mut qk = q.matmul(k, false, false)?;

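        // For causal self-attention, slice the mask down to the live context length.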
        if let Some(m) = mask {
            let prepared_mask = if is_causal {
                m.slice(&[0..n_ctx, 0..n_ctx])?
            } else {
                m.clone()
            };
            qk = qk.add(prepared_mask)?;
        }
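        // Upcast the scores to full precision for a numerically stable softmax.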
        qk = qk.full()?;

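        // Softmax over the key dimension, then cast back to the incoming dtype.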
        let w = qk.softmax(3)?.cast(q_dt)?;

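        // Reorder to [bs, n_ctx, n_heads, hdim] and collapse the heads back into n_state.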
        let s = shape![bs, n_ctx, n_state];
        let wv = w.matmul(v, false, false)?.permute(&[0, 2, 1, 3])?.view(s)?;

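        // Schedule the output projection.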
        self.o.schedule(wv)
    }
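
For orientation, here is a minimal single-head CPU sketch of the same math on plain slices, independent of Ratchet's Tensor API. The name attention_ref, the row-major layout, and the hdim^-0.25 scale (assumed to match self.dk) are illustrative, not part of mha.rs; the mask, the per-(batch, head) batching, and the output projection are omitted.

    // Single-head reference: q is [n_ctx, hdim]; k and v are [kv_ctx, hdim], row-major.
    fn attention_ref(
        q: &[f32],
        k: &[f32],
        v: &[f32],
        n_ctx: usize,
        kv_ctx: usize,
        hdim: usize,
    ) -> Vec<f32> {
        // Scale both Q and K, as above, so the product carries the full 1/sqrt(hdim).
        let scale = (hdim as f32).powf(-0.25);
        let mut out = vec![0.0f32; n_ctx * hdim];
        for i in 0..n_ctx {
            // scores[j] = (scale * q_i) . (scale * k_j)
            let mut scores = vec![0.0f32; kv_ctx];
            for j in 0..kv_ctx {
                let mut dot = 0.0f32;
                for d in 0..hdim {
                    dot += (q[i * hdim + d] * scale) * (k[j * hdim + d] * scale);
                }
                scores[j] = dot;
            }
            // Numerically stable softmax over the key dimension.
            let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let denom: f32 = scores.iter().map(|s| (s - max).exp()).sum();
            // out_i = sum_j w_ij * v_j
            for j in 0..kv_ctx {
                let w = (scores[j] - max).exp() / denom;
                for d in 0..hdim {
                    out[i * hdim + d] += w * v[j * hdim + d];
                }
            }
        }
        out
    }

In qkv_attention this computation runs per (batch, head) pair over the permuted views, with the optional mask added to the scores before the softmax.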