fn forward()

Defined in candle-transformers/src/models/glm4.rs [254:349]

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: &Option<Tensor>,
        rotary_emb: &RotaryEmbedding,
    ) -> Result<Tensor> {
        let mixed_x_layer = xs.apply(&self.query_key_value)?;
        if !self.multi_query_attention {
            candle::bail!("only multi_query_attention=true is supported")
        }
        let hpa = self.hidden_size_per_attention_head;
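        // The fused projection above packs Q, K and V along the last dimension: queries take
        // `num_attention_heads_per_partition * hpa` features, while keys and values each take
        // `num_multi_query_groups_per_partition * hpa`. Split them apart with `narrow`.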
        let query_layer =
            mixed_x_layer.narrow(D::Minus1, 0, self.num_attention_heads_per_partition * hpa)?;
        let key_layer = mixed_x_layer.narrow(
            D::Minus1,
            self.num_attention_heads_per_partition * hpa,
            self.num_multi_query_groups_per_partition * hpa,
        )?;
        let value_layer = mixed_x_layer.narrow(
            D::Minus1,
            self.num_attention_heads_per_partition * hpa
                + self.num_multi_query_groups_per_partition * hpa,
            self.num_multi_query_groups_per_partition * hpa,
        )?;
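        // Un-flatten the projection outputs into per-head form: the last dimension is split into
        // (heads, head_dim), with the full head count for queries and only the KV-group count
        // for keys and values; the first two dimensions are left unchanged.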
        let query_layer = query_layer.reshape((
            query_layer.dim(0)?,
            query_layer.dim(1)?,
            self.num_attention_heads_per_partition,
            hpa,
        ))?;
        let key_layer = key_layer.reshape((
            key_layer.dim(0)?,
            key_layer.dim(1)?,
            self.num_multi_query_groups_per_partition,
            hpa,
        ))?;
        let value_layer = value_layer.reshape((
            value_layer.dim(0)?,
            value_layer.dim(1)?,
            self.num_multi_query_groups_per_partition,
            hpa,
        ))?;

        // Rotary embeddings, offset by the number of positions already stored in the KV cache
        // so that new tokens continue from the previous position index.
        let seqlen_offset = match &self.kv_cache {
            None => 0,
            Some((prev_k, _)) => prev_k.dim(0)?,
        };
        let query_layer = rotary_emb.apply(&query_layer, seqlen_offset)?;
        let key_layer = rotary_emb.apply(&key_layer, seqlen_offset)?;

        // KV cache: prepend previously cached keys/values along the sequence dimension (dim 0),
        // then store the full sequence back into the cache.
        let (key_layer, value_layer) = match &self.kv_cache {
            None => (key_layer, value_layer),
            Some((prev_k, prev_v)) => {
                let k = Tensor::cat(&[prev_k, &key_layer], 0)?;
                let v = Tensor::cat(&[prev_v, &value_layer], 0)?;
                (k, v)
            }
        };
        self.kv_cache = Some((key_layer.clone(), value_layer.clone()));
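        // The cache now holds every position seen so far; its dim 0 size becomes the
        // `seqlen_offset` used for the rotary embeddings on the next call.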

        // Repeat KV: each of the `num_multi_query_groups_per_partition` key/value heads is
        // shared by `ratio` query heads (grouped/multi-query attention), so expand the keys
        // and values to match the query head count before computing attention.
        let ratio =
            self.num_attention_heads_per_partition / self.num_multi_query_groups_per_partition;
        let key_layer = {
            let (d0, d1, d2, d3) = key_layer.dims4()?;
            key_layer
                .unsqueeze(D::Minus2)?
                .expand((d0, d1, d2, ratio, d3))?
                .reshape((
                    d0,
                    d1,
                    self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head,
                ))?
        };
        let value_layer = {
            let (d0, d1, d2, d3) = value_layer.dims4()?;
            value_layer
                .unsqueeze(D::Minus2)?
                .expand((d0, d1, d2, ratio, d3))?
                .reshape((
                    d0,
                    d1,
                    self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head,
                ))?
        };
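        // Keys and values now have the same head count as the queries, so the core attention
        // can be applied directly, followed by the output projection.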

        let context_layer =
            self.core_attention
                .forward(&query_layer, &key_layer, &value_layer, attention_mask)?;
        let output = context_layer.apply(&self.dense)?;
        Ok(output)
    }
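
The grouped-query "repeat KV" expansion above can be exercised on its own. The following is a
minimal sketch, assuming the external crate name candle_core (referenced simply as candle inside
candle-transformers) and purely illustrative sizes; repeat_kv_sketch is a hypothetical helper,
not part of the model code.

    use candle_core::{DType, Device, Result, Tensor, D};

    // Hypothetical standalone demo of the expand/reshape trick from the "Repeat KV" block above.
    fn repeat_kv_sketch() -> Result<Tensor> {
        // Illustrative sizes: 2 KV groups shared by 16 query heads, i.e. a ratio of 8.
        let (seq_len, batch, n_kv_groups, head_dim): (usize, usize, usize, usize) = (5, 1, 2, 64);
        let ratio: usize = 8;
        // Key/value tensor with the head dimension already un-flattened.
        let kv = Tensor::zeros((seq_len, batch, n_kv_groups, head_dim), DType::F32, &Device::Cpu)?;
        let (d0, d1, d2, d3) = kv.dims4()?;
        // Add a broadcast axis, expand it `ratio` times, then fold it into the head axis,
        // yielding a (d0, d1, d2 * ratio, d3) tensor that matches the query head count.
        kv.unsqueeze(D::Minus2)?
            .expand((d0, d1, d2, ratio, d3))?
            .reshape((d0, d1, d2 * ratio, d3))
    }

With these sizes, repeat_kv_sketch()?.dims() should report [5, 1, 16, 64].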