in backends/candle/src/models/modernbert.rs [605:803]
fn forward(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
    let _enter = self.span.enter();

    let batch_size = batch.len();
    let max_length = batch.max_length as usize;

    let shape = (batch_size, max_length);
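    // For more than one sequence, flatten the ragged batch into (batch_size, max_length):
    // pad each sequence to max_length while tracking its real length and a 0/1 mask.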
    let (input_ids, input_lengths, position_ids, attention_mask) = if batch_size > 1 {
        let elems = batch_size * max_length;

        let mut input_ids = Vec::with_capacity(elems);
        let mut position_ids = Vec::with_capacity(elems);
        let mut attention_mask = Vec::with_capacity(elems);
        let mut input_lengths = Vec::with_capacity(batch_size);
        let mut masking = false;

        for i in 0..batch_size {
            let start = batch.cumulative_seq_lengths[i] as usize;
            let end = batch.cumulative_seq_lengths[i + 1] as usize;
            let seq_length = (end - start) as u32;
            input_lengths.push(seq_length as f32);

            for j in start..end {
                input_ids.push(batch.input_ids[j]);
                position_ids.push(batch.position_ids[j]);
                attention_mask.push(1.0_f32);
            }
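            // Pad the rest of this row up to max_length: padded slots get the pad token,
            // position 0 and a zeroed attention mask entry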
            let padding = batch.max_length - seq_length;
            if padding > 0 {
                masking = true;
                for _ in 0..padding {
                    input_ids.push(self.pad_token_id);
                    position_ids.push(0);
                    attention_mask.push(0.0_f32);
                }
            }
        }
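        // Only materialize the attention mask tensor if at least one sequence was padded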
        let attention_mask = match masking {
            true => {
                let attention_mask = Tensor::from_vec(
                    attention_mask,
                    (batch_size, max_length, 1),
                    &self.device,
                )?
                .to_dtype(self.dtype)?;

                Some(attention_mask)
            }
            false => None,
        };

        (input_ids, input_lengths, position_ids, attention_mask)
    } else {
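        // Single sequence: nothing is padded, so no attention mask is needed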
        (
            batch.input_ids,
            vec![max_length as f32],
            batch.position_ids,
            None,
        )
    };
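    // Upload the flattened inputs to the device: input_ids as (batch_size, max_length),
    // position_ids kept flat for the rotary gather below, lengths as (batch_size, 1)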
    let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?;
    let position_ids = Tensor::from_vec(position_ids, batch_size * max_length, &self.device)?;
    let mut input_lengths =
        Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?;
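    // Build the padding-aware global attention mask and derive the local
    // (sliding-window) mask from it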
    let global_attention_mask = self
        .get_global_attention_mask(attention_mask.as_ref(), &shape)?
        .to_dtype(self.dtype)?;
    let local_attention_mask = self
        .get_local_attention_mask(&global_attention_mask)?
        .to_dtype(self.dtype)?;
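    // Convert the 0/1 masks into additive biases: attended positions become 0 and
    // masked positions a large negative value that vanishes after softmax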
    let min_value = match self.dtype {
        DType::F32 => f32::MIN as f64,
        _ => -65504.0, // f16 minimum value
    };

    let global_attention_mask = ((1.0 - global_attention_mask)? * min_value)?;
    let local_attention_mask = ((1.0 - local_attention_mask)? * min_value)?;
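    // Precompute RoPE cos/sin caches for the global and local attention layers
    // (separate inverse frequencies), then gather them by position id and reshape
    // to (batch_size, 1, max_length, rotary_dim)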
    let global_rotary_cache =
        get_cos_sin(max_length, &self.global_inv_freqs, self.dtype, true)?;
    let local_rotary_cache = get_cos_sin(max_length, &self.local_inv_freqs, self.dtype, true)?;

    let global_rotary_cache = (
        global_rotary_cache
            .0
            .index_select(&position_ids, 0)?
            .reshape((batch_size, 1, max_length, self.rotary_dim))?,
        global_rotary_cache
            .1
            .index_select(&position_ids, 0)?
            .reshape((batch_size, 1, max_length, self.rotary_dim))?,
    );
    let local_rotary_cache = (
        local_rotary_cache
            .0
            .index_select(&position_ids, 0)?
            .reshape((batch_size, 1, max_length, self.rotary_dim))?,
        local_rotary_cache
            .1
            .index_select(&position_ids, 0)?
            .reshape((batch_size, 1, max_length, self.rotary_dim))?,
    );
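    // Run the model: token embeddings, encoder layers, then the final layer norm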
    let hidden_states = self.embeddings.forward(&input_ids)?;
    let hidden_states = self.encoder.forward(
        &hidden_states,
        &global_attention_mask,
        &local_attention_mask,
        &global_rotary_cache,
        &local_rotary_cache,
    )?;
    let outputs = self.final_norm.forward(&hidden_states, None)?;
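    // A batch can mix requests that want pooled embeddings with requests that want
    // raw per-token embeddings; the two groups are handled separately below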
    let has_pooling_requests = !batch.pooled_indices.is_empty();
    let has_raw_requests = !batch.raw_indices.is_empty();

    let pooled_embeddings = if has_pooling_requests {
        let pooled_indices_length = batch.pooled_indices.len();
        let mut outputs = outputs.clone();

        let pooled_indices = if has_raw_requests {
            let pooled_indices =
                Tensor::from_vec(batch.pooled_indices, pooled_indices_length, &self.device)?;

            outputs = outputs.index_select(&pooled_indices, 0)?;
            Some(pooled_indices)
        } else {
            None
        };
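        // Cls pooling takes the first token's embedding; Mean masks out padding
        // (when present) and divides by each sequence's real length. The remaining
        // pooling variants are not supported by this model.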
        let pooled_embeddings = match self.pool {
            Pool::Cls => outputs.i((.., 0))?,
            Pool::LastToken | Pool::Splade => unreachable!(),
            Pool::Mean => {
                if let Some(ref attention_mask) = attention_mask {
                    let mut attention_mask = attention_mask.clone();

                    if let Some(pooled_indices) = pooled_indices {
                        attention_mask = attention_mask.index_select(&pooled_indices, 0)?;
                        input_lengths = input_lengths.index_select(&pooled_indices, 0)?;
                    };

                    outputs = outputs.broadcast_mul(&attention_mask)?;
                }

                (outputs.sum(1)?.broadcast_div(&input_lengths))?
            }
        };

        Some(pooled_embeddings)
    } else {
        None
    };
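    // Raw requests return token-level embeddings: flatten to (batch * length, hidden)
    // and, when necessary, drop the rows that correspond to padding tokens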
    let raw_embeddings = if has_raw_requests {
        let (b, l, h) = outputs.shape().dims3()?;
        let outputs = outputs.reshape((b * l, h))?;

        // We need to remove the padding tokens only if batch_size > 1 and some members
        // of the batch require pooling, or if batch_size > 1 and the members of the
        // batch have different lengths
        if (attention_mask.is_some() || has_pooling_requests) && batch_size > 1 {
            let mut final_indices: Vec<u32> = Vec::with_capacity(batch_size * max_length);

            for i in batch.raw_indices.into_iter() {
                let start = i * batch.max_length;
                let i = i as usize;
                let length =
                    batch.cumulative_seq_lengths[i + 1] - batch.cumulative_seq_lengths[i];

                for j in start..start + length {
                    // Add indices for the tokens of this specific member of the batch
                    final_indices.push(j);
                }
            }

            let final_indices_length = final_indices.len();
            let final_indices =
                Tensor::from_vec(final_indices, final_indices_length, &self.device)?;

            // Select the tokens with final indices
            Some(outputs.index_select(&final_indices, 0)?)
        } else {
            Some(outputs)
        }
    } else {
        None
    };

    Ok((pooled_embeddings, raw_embeddings))
}