fn forward()

in backends/candle/src/models/modernbert.rs [605:803]


    fn forward(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
        let _enter = self.span.enter();

        let batch_size = batch.len();
        let max_length = batch.max_length as usize;

        let shape = (batch_size, max_length);

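        // Multi-sequence batches are flattened into contiguous buffers and padded up to
        // max_length; a single sequence is used as-is, with no padding and no attention mask.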
        let (input_ids, input_lengths, position_ids, attention_mask) = if batch_size > 1 {
            let elems = batch_size * max_length;

            let mut input_ids = Vec::with_capacity(elems);
            let mut position_ids = Vec::with_capacity(elems);
            let mut attention_mask = Vec::with_capacity(elems);
            let mut input_lengths = Vec::with_capacity(batch_size);

            let mut masking = false;

            for i in 0..batch_size {
                let start = batch.cumulative_seq_lengths[i] as usize;
                let end = batch.cumulative_seq_lengths[i + 1] as usize;
                let seq_length = (end - start) as u32;
                input_lengths.push(seq_length as f32);

                for j in start..end {
                    input_ids.push(batch.input_ids[j]);
                    position_ids.push(batch.position_ids[j]);
                    attention_mask.push(1.0_f32);
                }

                let padding = batch.max_length - seq_length;
                if padding > 0 {
                    masking = true;
                    for _ in 0..padding {
                        input_ids.push(self.pad_token_id);
                        position_ids.push(0);
                        attention_mask.push(0.0_f32);
                    }
                }
            }

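            // Only materialize an attention mask tensor if at least one sequence was padded.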
            let attention_mask = match masking {
                true => {
                    let attention_mask = Tensor::from_vec(
                        attention_mask,
                        (batch_size, max_length, 1),
                        &self.device,
                    )?
                    .to_dtype(self.dtype)?;

                    Some(attention_mask)
                }
                false => None,
            };

            (input_ids, input_lengths, position_ids, attention_mask)
        } else {
            (
                batch.input_ids,
                vec![max_length as f32],
                batch.position_ids,
                None,
            )
        };

        let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?;
        let position_ids = Tensor::from_vec(position_ids, batch_size * max_length, &self.device)?;
        let mut input_lengths =
            Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?;

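        // Expand the padding mask into the global and local attention masks used by the encoder.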
        let global_attention_mask = self
            .get_global_attention_mask(attention_mask.as_ref(), &shape)?
            .to_dtype(self.dtype)?;
        let local_attention_mask = self
            .get_local_attention_mask(&global_attention_mask)?
            .to_dtype(self.dtype)?;

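        // Turn the {0, 1} masks into additive biases: 0 where attention is allowed and a
        // large negative value where it is masked out.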
        let min_value = match self.dtype {
            DType::F32 => f32::MIN as f64,
            _ => -65504.0, // f16 minimum value
        };

        let global_attention_mask = ((1.0 - global_attention_mask)? * min_value)?;
        let local_attention_mask = ((1.0 - local_attention_mask)? * min_value)?;

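        // Build cos/sin rotary caches for the global and local attention layers and gather
        // the entries for each position id, shaped (batch_size, 1, max_length, rotary_dim).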
        let global_rotary_cache =
            get_cos_sin(max_length, &self.global_inv_freqs, self.dtype, true)?;
        let local_rotary_cache = get_cos_sin(max_length, &self.local_inv_freqs, self.dtype, true)?;

        let global_rotary_cache = (
            global_rotary_cache
                .0
                .index_select(&position_ids, 0)?
                .reshape((batch_size, 1, max_length, self.rotary_dim))?,
            global_rotary_cache
                .1
                .index_select(&position_ids, 0)?
                .reshape((batch_size, 1, max_length, self.rotary_dim))?,
        );

        let local_rotary_cache = (
            local_rotary_cache
                .0
                .index_select(&position_ids, 0)?
                .reshape((batch_size, 1, max_length, self.rotary_dim))?,
            local_rotary_cache
                .1
                .index_select(&position_ids, 0)?
                .reshape((batch_size, 1, max_length, self.rotary_dim))?,
        );

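        // Token embeddings, encoder stack, then the final normalization layer.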
        let hidden_states = self.embeddings.forward(&input_ids)?;

        let hidden_states = self.encoder.forward(
            &hidden_states,
            &global_attention_mask,
            &local_attention_mask,
            &global_rotary_cache,
            &local_rotary_cache,
        )?;
        let outputs = self.final_norm.forward(&hidden_states, None)?;

        let has_pooling_requests = !batch.pooled_indices.is_empty();
        let has_raw_requests = !batch.raw_indices.is_empty();

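        // Pooled embeddings: CLS pooling keeps the first token, mean pooling masks out the
        // padding and divides the sum by each sequence length.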
        let pooled_embeddings = if has_pooling_requests {
            let pooled_indices_length = batch.pooled_indices.len();
            let mut outputs = outputs.clone();

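            // If the batch also contains raw requests, keep only the rows that need pooling.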
            let pooled_indices = if has_raw_requests {
                let pooled_indices =
                    Tensor::from_vec(batch.pooled_indices, pooled_indices_length, &self.device)?;

                outputs = outputs.index_select(&pooled_indices, 0)?;
                Some(pooled_indices)
            } else {
                None
            };

            let pooled_embeddings = match self.pool {
                Pool::Cls => outputs.i((.., 0))?,
                Pool::LastToken | Pool::Splade => unreachable!(),
                Pool::Mean => {
                    if let Some(ref attention_mask) = attention_mask {
                        let mut attention_mask = attention_mask.clone();

                        if let Some(pooled_indices) = pooled_indices {
                            attention_mask = attention_mask.index_select(&pooled_indices, 0)?;
                            input_lengths = input_lengths.index_select(&pooled_indices, 0)?;
                        };

                        outputs = outputs.broadcast_mul(&attention_mask)?;
                    }

                    (outputs.sum(1)?.broadcast_div(&input_lengths))?
                }
            };
            Some(pooled_embeddings)
        } else {
            None
        };

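        // Raw (per-token) embeddings, flattened to (batch_size * max_length, hidden_size).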
        let raw_embeddings = if has_raw_requests {
            let (b, l, h) = outputs.shape().dims3()?;
            let outputs = outputs.reshape((b * l, h))?;

            // Padding tokens only need to be removed if batch_size > 1 and some members of
            // the batch require pooling, or if batch_size > 1 and the members of the batch
            // have different lengths (i.e. an attention mask was built)
            if (attention_mask.is_some() || has_pooling_requests) && batch_size > 1 {
                let mut final_indices: Vec<u32> = Vec::with_capacity(batch_size * max_length);

                for i in batch.raw_indices.into_iter() {
                    let start = i * batch.max_length;
                    let i = i as usize;
                    let length =
                        batch.cumulative_seq_lengths[i + 1] - batch.cumulative_seq_lengths[i];

                    for j in start..start + length {
                        // Add indices for the tokens of this specific member of the batch
                        final_indices.push(j);
                    }
                }

                let final_indices_length = final_indices.len();
                let final_indices =
                    Tensor::from_vec(final_indices, final_indices_length, &self.device)?;

                // Select the tokens with final indices
                Some(outputs.index_select(&final_indices, 0)?)
            } else {
                Some(outputs)
            }
        } else {
            None
        };

        Ok((pooled_embeddings, raw_embeddings))
    }
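
As a rough illustration of the additive-mask step above (computing min_value and then (1.0 - mask) * min_value), the standalone sketch below mirrors the same tensor ops on a toy input. It is not code from modernbert.rs: the helper name additive_bias, the candle_core crate path, the CPU device, and the example values are assumptions made here for illustration only.

    use candle_core::{DType, Device, Result, Tensor};

    /// Convert a {0, 1} padding mask of shape (batch, len, 1) into an additive attention
    /// bias: 0.0 for real tokens, a large negative value for padded positions.
    fn additive_bias(mask: Vec<f32>, batch: usize, len: usize, dtype: DType) -> Result<Tensor> {
        let mask = Tensor::from_vec(mask, (batch, len, 1), &Device::Cpu)?.to_dtype(dtype)?;
        let min_value = match dtype {
            DType::F32 => f32::MIN as f64,
            _ => -65504.0, // f16 minimum value, as in the snippet above
        };
        // 0.0 where mask == 1.0 (attend), min_value where mask == 0.0 (padding)
        (1.0 - mask)? * min_value
    }

    fn main() -> Result<()> {
        // One sequence of length 3, padded to max_length = 4.
        let bias = additive_bias(vec![1.0, 1.0, 1.0, 0.0], 1, 4, DType::F32)?;
        println!("{:?}", bias.to_vec3::<f32>()?);
        Ok(())
    }

For a sequence of length 3 padded to 4, this prints 0.0 for the three real tokens and f32::MIN for the padded position, i.e. the additive mask that is handed to the encoder above.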