in crates/ratchet-models/src/phi3/attn.rs [88:159]
/// Schedules the self-attention computation for one input batch.
///
/// Steps, in order:
/// 1. Run the fused `qkv` projection and slice its last axis into query, key and value.
/// 2. View each part as `[batch, seq, heads, hdim]` and permute it to `[batch, heads, seq, hdim]`.
/// 3. Apply `rope` to query and key, using the cache's `entries` as offset (0 without a cache).
/// 4. When a cache is supplied, route key and value through `k_cache` / `v_cache`.
/// 5. Compute the scaled scores, add the optional mask, softmax over axis 3, multiply by values.
/// 6. Permute back, view as `[batch, seq, n_state]` and pass the result to the `o` projection.
///
/// # Errors
/// Propagates any failure from the shape conversion or from the scheduled tensor operations.
fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor> {
    let PhiAttnInput { input, mask, cache } = input;
    let [batch_size, q_len, n_state]: [usize; 3] = input.shape().try_into()?;

    // Per-head width plus the boundaries of the q / k / v segments inside the fused projection.
    let n_heads = self.n_heads as usize;
    let n_kv_heads = self.n_kv_heads as usize;
    let hdim = n_state / n_heads;
    let q_width = n_heads * hdim;
    let kv_width = n_kv_heads * hdim;
    let k_end = q_width + kv_width;
    let v_end = k_end + kv_width;

    // One fused projection; each segment is a slice over the last axis only.
    let qkv = self.qkv.schedule(input)?;
    let segment =
        |lo: usize, hi: usize| qkv.clone().slice(&[0..batch_size, 0..q_len, lo..hi]);
    let q = segment(0, q_width)?;
    let k = segment(q_width, k_end)?;
    let v = segment(k_end, v_end)?;

    // Split the feature axis into heads, then move the head axis in front of the sequence axis.
    // NOTE(review): query uses `n_heads` while key/value use `n_kv_heads`, and nothing in this
    // function repeats the kv heads; presumably the two counts match or `matmul` broadcasts.
    // Confirm against the model configuration.
    let q_shape = shape![batch_size, q_len, n_heads, hdim];
    let kv_shape = shape![batch_size, q_len, n_kv_heads, hdim];
    let q = q.view(q_shape)?.permute(&[0, 2, 1, 3])?;
    let k = k.view(kv_shape.clone())?.permute(&[0, 2, 1, 3])?;
    let v = v.view(kv_shape)?.permute(&[0, 2, 1, 3])?;

    // Rotary embedding runs on the `full()` tensors (presumably a full-precision cast; confirm)
    // and the result is cast back to the dtype the query had before the embedding.
    let offset = cache.as_ref().map_or(0, |kv| kv.entries);
    let q_dt = q.dt();
    let q_rope_in = RotaryInput {
        input: q.full()?,
        offset,
    };
    let q = self.rope.schedule(q_rope_in)?.cast(q_dt)?;
    let k_rope_in = RotaryInput {
        input: k.full()?,
        offset,
    };
    let k = self.rope.schedule(k_rope_in)?.cast(q_dt)?;

    // With a cache, key/value are replaced by what the cache ops return (called with dim 2 and
    // the same offset); without one, the freshly computed tensors are used directly.
    let (k, v) = match cache {
        Some(kv) => (
            kv.k_cache.cache(k, 2, offset)?,
            kv.v_cache.cache(v, 2, offset)?,
        ),
        None => (k, v),
    };

    // Scores: q x k (flags `false, true`; presumably "transpose the right operand", confirm),
    // scaled by `softmax_scale` and cast back to the query dtype.
    let scores = q.full()?.matmul(k.full()?, false, true)?;
    let scores = scores.mul(self.softmax_scale.clone())?.cast(q_dt)?;

    // The optional mask is cast to the score dtype and added element-wise.
    let scores = match mask {
        Some(m) => {
            let score_dt = scores.dt();
            scores.add(m.cast(score_dt)?)?
        }
        None => scores,
    };

    // Softmax over the last axis, cast to the value dtype, then weight the values.
    let probs = scores.full()?.softmax(3)?.cast(v.dt())?;
    let context = probs.matmul(v, false, false)?.permute(&[0, 2, 1, 3])?;

    // Merge the heads back into a single feature axis before the output projection.
    let context = context.view(shape![batch_size, q_len, n_state])?;
    self.o.schedule(context)
}