fn schedule()

in crates/ratchet-models/src/moondream/vision_encoder.rs [18:64]


    /// Schedules multi-head self-attention over `input` and returns the result of `self.proj`.
    ///
    /// `input` must be rank 3; its dimensions are bound to `[b, n, c]`. The output of
    /// `self.qkv` is rearranged into query, key and value tensors of shape
    /// `[b, n_heads, n, head_dim]`, where `head_dim = self.dim / self.n_heads`.
    ///
    /// # Errors
    ///
    /// Propagates any error raised by the shape conversion or by the tensor operations.
    fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor> {
        let n_heads = self.n_heads;
        let head_dim = self.dim / n_heads;
        let [b, n, c]: [usize; 3] = input.shape().try_into()?;
        // Number of elements in each of q, k and v.
        let per_tensor = b * n_heads * n * head_dim;

        // Goal: the 5-d permutation (2, 0, 3, 1, 4) of [b, n, 3, nh, hd]. It is assembled
        // from three permutes of rank <= 4, each swapping one pair of adjacent axes:
        //   0, 1, 2, 3, 4   (output of `self.qkv`)
        //   0, 2, 1, 3, 4
        //   2, 0, 1, 3, 4
        //   2, 0, 3, 1, 4
        let fused = self.qkv.schedule(input.clone())?;

        // [b, n, 3, nh*hd] -> [b, 3, n, nh*hd]
        let fused = fused
            .view(shape![b, n, 3, n_heads * head_dim])?
            .permute(&[0, 2, 1, 3])?;

        // [b, 3, n*nh*hd] -> [3, b, n*nh*hd]
        let fused = fused
            .view(shape![b, 3, n * n_heads * head_dim])?
            .permute(&[1, 0, 2])?;

        // [3*b, n, nh, hd] -> [3*b, nh, n, hd], then flattened to one row per q / k / v.
        let fused = fused
            .view(shape![3 * b, n, n_heads, head_dim])?
            .permute(&[0, 2, 1, 3])?
            .view(shape![3, per_tensor])?;

        // Row `idx` of the flattened tensor, reshaped to [b, nh, n, hd].
        let take = |idx: usize| -> anyhow::Result<Tensor> {
            let row = fused.clone().slice(&[idx..idx + 1, 0..per_tensor])?;
            let row = row.view(shape![b, n_heads, n, head_dim])?;
            Ok(row)
        };
        let q = take(0)?;
        let k = take(1)?;
        let v = take(2)?;

        // Scaled dot-product attention: softmax(q @ k^T * scale) @ v.
        // NOTE(review): `full()` presumably promotes to full precision before the matmul —
        // confirm against the tensor API.
        let q_full = q.full()?;
        let k_t_full = k.permute(&[0, 1, 3, 2])?.full()?;
        let scores = q_full
            .matmul(k_t_full, false, false)?
            .mul(self.scale_factor.clone())?;
        // Normalise over the last axis, then cast to the dtype reported by `v`.
        let probs = scores.softmax(3)?.cast(v.dt())?;
        let context = probs.matmul(v, false, false)?;

        // [b, nh, n, hd] -> [b, n, nh, hd] -> [b, n, c]
        let merged = context.permute(&[0, 2, 1, 3])?.view(shape![b, n, c])?;
        self.proj.schedule(merged)
    }