fn forward()

in backends/candle/src/models/modernbert.rs [180:270]

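Attention forward pass of the ModernBERT model in the Candle backend. It projects the hidden states through a fused QKV linear layer, applies rotary position embeddings to the query and key heads, and computes scaled dot-product attention either via fused cublasLt batched matmuls (CUDA builds with cublasLt available) or via a plain matmul/softmax fallback, before merging the heads and applying the output projection.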

    fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: &Tensor,
        rotary_cache: &(Tensor, Tensor),
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        let device = hidden_states.device();

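        // Fused QKV projection: a single linear layer produces queries, keys and values.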
        let qkv = self.wqkv.forward(hidden_states)?;

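        // Reshape [batch, seq_len, 3 * hidden] to [batch, seq_len, 3 * num_heads, head_size],
        // then transpose to [batch, 3 * num_heads, seq_len, head_size].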
        let mut new_qkv_shape = qkv.dims().to_vec();
        new_qkv_shape.pop();
        new_qkv_shape.push(self.num_attention_heads * 3);
        new_qkv_shape.push(self.attention_head_size);
        let qkv = qkv.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;

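        // Split along dim 1 into query, key and value, each [batch, num_heads, seq_len, head_size].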
        let qkv = qkv.chunk(3, 1)?;
        let query_layer = &qkv[0].contiguous()?;
        let key_layer = &qkv[1].contiguous()?;
        let value_layer = &qkv[2];

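        // Apply rotary position embeddings (RoPE) to the query and key heads
        // using the precomputed cos/sin cache.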
        let query_layer = apply_rotary(
            query_layer,
            &rotary_cache.0,
            &rotary_cache.1,
            self.attention_head_size,
        )?;
        let key_layer = apply_rotary(
            key_layer,
            &rotary_cache.0,
            &rotary_cache.1,
            self.attention_head_size,
        )?;

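        // Dispatch on the device: CUDA builds with cublasLt use fused batched matmuls,
        // everything else falls back to plain matmul/softmax attention. The `allow` is
        // needed because `cublaslt` goes unused when the `cuda` feature is disabled.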
        #[allow(unused_variables)]
        let context_layer =
            if let (Device::Cuda(_), Some(cublaslt)) = (device, get_cublas_lt_wrapper()) {
                #[cfg(feature = "cuda")]
                {
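                    // The cublasLt batched matmul expects rank-3 tensors,
                    // so collapse the batch and head dims.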
                    let (batch_size, _, seq_len, _) = key_layer.shape().dims4()?;
                    let key_layer = key_layer.flatten(0, 1)?;
                    let query_layer = query_layer.flatten(0, 1)?;
                    let value_layer = value_layer.flatten(0, 1)?;
                    let attention_mask = attention_mask.flatten(0, 1)?;

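                    // Attention scores: Q·Kᵀ scaled by softmax_scale with the
                    // attention mask added, fused into a single cublasLt call.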
                    let attention_scores = cublaslt.batch_matmul(
                        &key_layer,
                        &query_layer,
                        Some(attention_mask.as_ref()),
                        Some(self.softmax_scale as f32),
                        None,
                        None,
                        None,
                    )?;
                    let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;

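                    // Context: probs · V. `query_layer` is passed as the output
                    // tensor so its buffer can be reused, saving an allocation.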
                    let context_layer = cublaslt.batch_matmul(
                        &value_layer.t()?.contiguous()?,
                        &attention_probs,
                        Some(&query_layer),
                        None,
                        None,
                        None,
                        None,
                    )?;

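                    // Restore the [batch, num_heads, seq_len, head_size] layout.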
                    context_layer.reshape((
                        batch_size,
                        self.num_attention_heads,
                        seq_len,
                        self.attention_head_size,
                    ))
                }
                #[cfg(not(feature = "cuda"))]
                {
                    candle::bail!("`cuda` feature is not enabled")
                }
            } else {
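                // Fallback: standard scaled dot-product attention
                // (CPU, Metal, or CUDA without cublasLt).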
                let attn_weights = query_layer.matmul(&key_layer.t()?)?;
                let attn_weights = (attn_weights * self.softmax_scale)?;
                let attn_weights = attn_weights.add(attention_mask)?;
                let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
                attn_weights.matmul(&value_layer.contiguous()?)
            }?;

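        // Merge the heads back into [batch, seq_len, num_heads * head_size]
        // and apply the output projection.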
        let hidden_states = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?;
        let hidden_states = self.wo.forward(&hidden_states)?;

        Ok(hidden_states)
    }
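
For reference, here is a minimal standalone sketch (not part of modernbert.rs) of the math performed by the fallback branch above, assuming the `candle` and `candle_nn` crates as used in this backend. The `sdpa` helper name, the dummy shapes, the all-zeros mask and the 1/sqrt(head_size) scale are illustrative choices, not values taken from the model.

use candle::{DType, Device, Result, Tensor};

// scores = (Q Kᵀ) * scale + mask; probs = softmax(scores); context = probs · V
fn sdpa(q: &Tensor, k: &Tensor, v: &Tensor, mask: &Tensor, scale: f64) -> Result<Tensor> {
    let scores = (q.matmul(&k.t()?)? * scale)?;
    let scores = scores.broadcast_add(mask)?;
    let probs = candle_nn::ops::softmax_last_dim(&scores)?;
    probs.matmul(v)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Shapes as produced by the QKV split above: [batch, num_heads, seq_len, head_size].
    let (b, h, s, d) = (2, 12, 8, 64);
    let q = Tensor::randn(0f32, 1f32, (b, h, s, d), &dev)?;
    let k = Tensor::randn(0f32, 1f32, (b, h, s, d), &dev)?;
    let v = Tensor::randn(0f32, 1f32, (b, h, s, d), &dev)?;
    // All-zeros mask: nothing is masked; the real forward adds large negative
    // values at masked positions before the softmax.
    let mask = Tensor::zeros((b, 1, s, s), DType::F32, &dev)?;
    let ctx = sdpa(&q, &k, &v, &mask, 1.0 / (d as f64).sqrt())?;
    assert_eq!(ctx.dims(), &[b, h, s, d]);
    Ok(())
}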