fn schedule()

in crates/ratchet-models/src/phi3/attn.rs [88:159]


    /// Schedules the attention computation for one call and returns the
    /// output-projected tensor.
    ///
    /// Pipeline, in the order the ops are issued:
    /// 1. Run the fused `qkv` projection and cut its last dimension into the
    ///    query, key and value column ranges.
    /// 2. View each part as `[batch, seq, heads, head_dim]` and permute to
    ///    `[batch, heads, seq, head_dim]`.
    /// 3. Apply `rope` to queries and keys, using the number of entries already
    ///    held by the cache (or 0 without a cache) as the offset.
    /// 4. When a cache is supplied, route keys and values through it.
    /// 5. Scaled `Q x K`, optional additive mask, softmax over dim 3, multiply
    ///    by the values, restore `[batch, seq, n_state]` and apply `o`.
    ///
    /// # Errors
    /// Fails if the input shape cannot be converted into exactly three
    /// dimensions, or if any scheduled tensor operation reports an error.
    fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor> {
        let PhiAttnInput {
            input: hidden,
            mask,
            cache,
        } = input;
        let [batch, seq_len, width]: [usize; 3] = hidden.shape().try_into()?;

        // Column boundaries of the fused projection: [ Q | K | V ].
        let head_dim = width / self.n_heads as usize;
        let kv_width = self.n_kv_heads as usize * head_dim;
        let q_end = self.n_heads as usize * head_dim;
        let k_end = q_end + kv_width;
        let v_end = k_end + kv_width;

        let fused = self.qkv.schedule(hidden)?;
        // Every slice keeps the full batch and sequence extent; only the
        // column range along the last dimension differs.
        let columns = |cols: std::ops::Range<usize>| {
            fused.clone().slice(&[0..batch, 0..seq_len, cols])
        };
        let q = columns(0..q_end)?;
        let k = columns(q_end..k_end)?;
        let v = columns(k_end..v_end)?;

        let q_heads = shape![batch as _, seq_len, self.n_heads as _, head_dim];
        let kv_heads = shape![batch as _, seq_len, self.n_kv_heads as _, head_dim];

        // Split the feature axis into heads, then move heads ahead of sequence.
        let q = q.view(q_heads)?.permute(&[0, 2, 1, 3])?;
        let k = k.view(kv_heads.clone())?.permute(&[0, 2, 1, 3])?;
        let v = v.view(kv_heads)?.permute(&[0, 2, 1, 3])?;

        // Number of entries already in the cache; 0 when there is no cache.
        let past_len = cache.as_ref().map_or(0, |kv| kv.entries);
        let q_dt = q.dt();

        // Rotary embedding is applied on the `full()` tensor and the result is
        // cast back to the query dtype captured above.
        let rotate = |t: Tensor| -> anyhow::Result<Tensor> {
            let rotated = self.rope.schedule(RotaryInput {
                input: t.full()?,
                offset: past_len,
            })?;
            Ok(rotated.cast(q_dt)?)
        };
        let q = rotate(q)?;
        let k = rotate(k)?;

        // With a cache, keys/values are routed through it (arguments `2` and
        // the offset — presumably the sequence axis and write position).
        let (k, v) = match cache {
            Some(kv) => (
                kv.k_cache.cache(k, 2, past_len)?,
                kv.v_cache.cache(v, 2, past_len)?,
            ),
            None => (k, v),
        };

        // Scaled attention scores; the `true` flag presumably transposes the
        // right-hand operand (Q x K^T).
        let q_full = q.full()?;
        let k_full = k.full()?;
        let scores = q_full
            .matmul(k_full, false, true)?
            .mul(self.softmax_scale.clone())?
            .cast(q_dt)?;

        // The mask, when present, is added after casting it to the scores' dtype.
        let scores = match mask {
            Some(m) => {
                let scores_dt = scores.dt();
                scores.add(m.cast(scores_dt)?)?
            }
            None => scores,
        };

        // Softmax on the `full()` scores, cast to the value dtype before the
        // weighted sum.
        let probs = scores.full()?.softmax(3)?.cast(v.dt())?;
        let context = probs.matmul(v, false, false)?;
        // Undo the head/sequence swap and merge heads back into one axis.
        let merged = context
            .permute(&[0, 2, 1, 3])?
            .view(shape![batch as _, seq_len, width])?;
        self.o.schedule(merged)
    }