fn schedule()

in crates/ratchet-models/src/phi2/attn.rs [101:153]


    fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor> {
        // Multi-head attention for Phi-2: project Q/K/V, split heads, apply
        // rotary embeddings, optionally read/append the KV cache, then do
        // scaled-dot-product attention and the output projection.
        let PhiAttnInput { input, mask, cache } = input;
        let [batch_size, seq_len, n_state]: [usize; 3] = input.shape().try_into()?;

        // The three projections consume the same hidden states.
        let q_proj = self.q.schedule(input.clone())?;
        let k_proj = self.k.schedule(input.clone())?;
        let v_proj = self.v.schedule(input)?;

        let head_dim = n_state / self.n_heads as usize;

        //TODO:
        //if self.qk_layer_norm { ... }

        // Split heads: [B, S, H*D] -> view [B, S, H, D] -> permute to [B, H, S, D].
        let q_view = shape![batch_size as _, seq_len, self.n_heads as _, head_dim];
        let kv_view = shape![batch_size as _, seq_len, self.n_kv_heads as _, head_dim];
        let query_states = q_proj.view(q_view)?.permute(&[0, 2, 1, 3])?;
        let key_states = k_proj.view(kv_view.clone())?.permute(&[0, 2, 1, 3])?;
        let value_states = v_proj.view(kv_view)?.permute(&[0, 2, 1, 3])?;

        // Rotary position offset = number of tokens already in the cache.
        let offset = cache.as_ref().map_or(0, |kv| kv.entries);
        let query_states = self.rope.schedule(RotaryInput {
            input: query_states,
            offset,
        })?;
        let key_states = self.rope.schedule(RotaryInput {
            input: key_states,
            offset,
        })?;

        // With a cache, append the new K/V along the sequence axis (dim 2) and
        // attend over the full cached sequence; otherwise use them directly.
        let (key_states, value_states) = match cache {
            Some(kv) => (
                kv.k_cache.cache(key_states, 2, offset)?,
                kv.v_cache.cache(value_states, 2, offset)?,
            ),
            None => (key_states, value_states),
        };

        //TODO: can we just use the built in transposed matmul?
        let q_full = query_states.full()?;
        let k_transposed = key_states.permute(&[0, 1, 3, 2])?.full()?;
        let mut attn_weights = q_full
            .matmul(k_transposed, false, false)?
            .mul(self.softmax_scale.clone())?;

        // The (optional) additive mask implements causality/padding.
        if let Some(mask) = mask {
            attn_weights = attn_weights.add(mask)?;
        }

        // Softmax over keys, cast back to the value dtype, weight the values,
        // then merge heads: [B, H, S, D] -> [B, S, H, D] -> [B, S, H*D].
        let probs = attn_weights.softmax(3)?.cast(value_states.dt())?;
        let context = probs
            .matmul(value_states, false, false)?
            .permute(&[0, 2, 1, 3])?
            .view(shape![batch_size as _, seq_len, n_state])?;
        self.o.schedule(context)
    }