in crates/ratchet-models/src/phi2/attn.rs [101:153]
fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor> {
    let PhiAttnInput { input, mask, cache } = input;
    let [batch_size, seq_len, n_state]: [usize; 3] = input.shape().try_into()?;

    // Project the hidden states into query, key and value.
    let q = self.q.schedule(input.clone())?;
    let k = self.k.schedule(input.clone())?;
    let v = self.v.schedule(input)?;

    let h_dim = n_state / self.n_heads as usize;
    //TODO:
    //if self.qk_layer_norm { ... }

    // Split heads: [batch, seq, n_state] -> [batch, heads, seq, h_dim].
    let q_shape = shape![batch_size as _, seq_len, self.n_heads as _, h_dim];
    let kv_shape = shape![batch_size as _, seq_len, self.n_kv_heads as _, h_dim];
    let query_states = q.view(q_shape)?.permute(&[0, 2, 1, 3])?;
    let key_states = k.view(kv_shape.clone())?.permute(&[0, 2, 1, 3])?;
    let value_states = v.view(kv_shape)?.permute(&[0, 2, 1, 3])?;

    // The number of tokens already in the KV cache is the rotary position offset.
    let offset = cache.as_ref().map(|kv| kv.entries).unwrap_or(0);
    let query_states = self.rope.schedule(RotaryInput {
        input: query_states,
        offset,
    })?;
    let key_states = self.rope.schedule(RotaryInput {
        input: key_states,
        offset,
    })?;

    // Append the new keys and values to the cache along the sequence dim (2),
    // so attention runs over all cached positions plus the new ones.
    let (key_states, value_states) = if let Some(kv) = cache {
        let k_cache = kv.k_cache.cache(key_states, 2, offset)?;
        let v_cache = kv.v_cache.cache(value_states, 2, offset)?;
        (k_cache, v_cache)
    } else {
        (key_states, value_states)
    };

    // Attention scores: (Q Kᵀ) * softmax_scale, with the mask added before the softmax.
    //TODO: can we just use the built in transposed matmul?
    let mut attn_weights = query_states
        .full()?
        .matmul(key_states.permute(&[0, 1, 3, 2])?.full()?, false, false)?
        .mul(self.softmax_scale.clone())?;

    if let Some(m) = mask {
        attn_weights = attn_weights.add(m)?;
    }

    // Softmax over the key axis, cast back to the value dtype, then weight the values.
    let w = attn_weights.softmax(3)?.cast(value_states.dt())?;
    let wv = w
        .matmul(value_states, false, false)?
        .permute(&[0, 2, 1, 3])?;

    // Merge heads back to [batch, seq, n_state] and apply the output projection.
    let wv = wv.view(shape![batch_size as _, seq_len, n_state])?;
    self.o.schedule(wv)
}
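
// A plain-Rust sketch (not the Ratchet API, and `sdpa_single_head` is a hypothetical
// helper) of the arithmetic the graph above encodes for one head of one batch element:
// scores = (Q Kᵀ) * scale, apply a causal mask, softmax over the key axis, then
// weights · V. Shapes here are [seq, h_dim]; `schedule` runs the same computation per
// [batch, head] slice, with cached keys/values extending the key axis by `offset`.
fn sdpa_single_head(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let h_dim = q[0].len();
    let scale = 1.0 / (h_dim as f32).sqrt();

    q.iter()
        .enumerate()
        .map(|(qi, q_row)| {
            // Scaled dot products against every key; positions after `qi`
            // get -inf so they vanish under the softmax (causal mask).
            let mut scores: Vec<f32> = k
                .iter()
                .enumerate()
                .map(|(ki, k_row)| {
                    let dot: f32 = q_row.iter().zip(k_row).map(|(a, b)| a * b).sum();
                    if ki > qi { f32::NEG_INFINITY } else { dot * scale }
                })
                .collect();

            // Softmax over the key axis (dim 3 in the batched graph above).
            let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let denom: f32 = scores.iter().map(|s| (s - max).exp()).sum();
            for s in scores.iter_mut() {
                *s = (*s - max).exp() / denom;
            }

            // weights · V for this query position.
            (0..v[0].len())
                .map(|d| scores.iter().zip(v).map(|(w, v_row)| w * v_row[d]).sum::<f32>())
                .collect()
        })
        .collect()
}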