in crates/ratchet-models/src/moondream/vision_encoder.rs [18:64]
fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor> {
    // Multi-head self-attention over the patch sequence.
    //
    // The fused QKV projection produces a tensor that is logically
    // [b, n, 3, nh, hd]; a chain of view/permute steps rearranges it to
    // [3, b, nh, n, hd] so that Q, K and V can each be sliced out as a
    // contiguous [b, nh, n, hd] block.
    let head_dim = self.dim / self.n_heads;
    let [b, n, c]: [usize; 3] = input.shape().try_into()?;
    // Element count of one of the Q/K/V partitions once fully flattened.
    let part_len = b * self.n_heads * n * head_dim;

    // Fused projection: logically [b, n, 3, nh, hd].
    let projected = self.qkv.schedule(input.clone())?;

    // [b, n, 3, nh*hd] -> [b, 3, n, nh*hd]
    let swapped = projected
        .view(shape![b, n, 3, self.n_heads * head_dim])?
        .permute(&[0, 2, 1, 3])?;
    // [b, 3, n*nh*hd] -> [3, b, n*nh*hd]
    let fronted = swapped
        .view(shape![b, 3, n * self.n_heads * head_dim])?
        .permute(&[1, 0, 2])?;
    // [3*b, n, nh, hd] -> [3*b, nh, n, hd], then flatten to [3, b*nh*n*hd]
    // so each of Q/K/V occupies one row.
    let packed = fronted
        .view(shape![3 * b, n, self.n_heads, head_dim])?
        .permute(&[0, 2, 1, 3])?
        .view(shape![3, part_len])?;

    // Carve partition `idx` (0 = Q, 1 = K, 2 = V) out as [b, nh, n, hd].
    let partition = |idx: usize| -> anyhow::Result<Tensor> {
        let part = packed
            .clone()
            .slice(&[idx..idx + 1, 0..part_len])?
            .view(shape![b, self.n_heads, n, head_dim])?;
        Ok(part)
    };
    let q = partition(0)?;
    let k = partition(1)?;
    let v = partition(2)?;

    // Scaled dot-product attention: softmax(q k^T * scale) v,
    // with the score computation widened via `full()` and the
    // probabilities cast back to v's dtype before the second matmul.
    let scores = q
        .full()?
        .matmul(k.permute(&[0, 1, 3, 2])?.full()?, false, false)?
        .mul(self.scale_factor.clone())?;
    let probs = scores.softmax(3)?.cast(v.dt())?;
    let context = probs.matmul(v, false, false)?;

    // Merge heads: [b, nh, n, hd] -> [b, n, nh*hd] = [b, n, c],
    // then apply the output projection.
    let merged = context.permute(&[0, 2, 1, 3])?.view(shape![b, n, c])?;
    self.proj.schedule(merged)
}