in backends/candle/src/models/modernbert.rs [180:270]
fn forward(
    &self,
    hidden_states: &Tensor,
    attention_mask: &Tensor,
    rotary_cache: &(Tensor, Tensor),
) -> Result<Tensor> {
    let _enter = self.span.enter();
    let device = hidden_states.device();
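    // Fused QKV projection: a single linear layer produces query, key and value in one matmul.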
    let qkv = self.wqkv.forward(hidden_states)?;
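    // Reshape [batch, seq, 3 * hidden] into [batch, seq, 3 * heads, head_size], then move the head dim before seq.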
    let mut new_qkv_shape = qkv.dims().to_vec();
    new_qkv_shape.pop();
    new_qkv_shape.push(self.num_attention_heads * 3);
    new_qkv_shape.push(self.attention_head_size);
    let qkv = qkv.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
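    // Split along the fused head dimension into query, key and value, each [batch, heads, seq, head_size].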
    let qkv = qkv.chunk(3, 1)?;
    let query_layer = &qkv[0].contiguous()?;
    let key_layer = &qkv[1].contiguous()?;
    let value_layer = &qkv[2];
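    // Apply rotary position embeddings to query and key using the precomputed cos/sin cache.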
    let query_layer = apply_rotary(
        query_layer,
        &rotary_cache.0,
        &rotary_cache.1,
        self.attention_head_size,
    )?;
    let key_layer = apply_rotary(
        key_layer,
        &rotary_cache.0,
        &rotary_cache.1,
        self.attention_head_size,
    )?;
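    // On CUDA with cuBLASLt available, use fused batched matmuls; otherwise fall back to plain Candle ops.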
    #[allow(unused_variables)]
    let context_layer =
        if let (Device::Cuda(_), Some(cublaslt)) = (device, get_cublas_lt_wrapper()) {
            #[cfg(feature = "cuda")]
            {
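                // cuBLASLt batch_matmul operates on 3D tensors, so fold the batch and head dims together.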
                let (batch_size, _, seq_len, _) = key_layer.shape().dims4()?;
                let key_layer = key_layer.flatten(0, 1)?;
                let query_layer = query_layer.flatten(0, 1)?;
                let value_layer = value_layer.flatten(0, 1)?;
                let attention_mask = attention_mask.flatten(0, 1)?;
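                // Scaled attention scores, with the attention mask folded in by the same fused call.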
                let attention_scores = cublaslt.batch_matmul(
                    &key_layer,
                    &query_layer,
                    Some(attention_mask.as_ref()),
                    Some(self.softmax_scale as f32),
                    None,
                    None,
                    None,
                )?;
                let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
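                // Weight the values by the attention probabilities; passing the query tensor as the output buffer saves one allocation.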
                let context_layer = cublaslt.batch_matmul(
                    &value_layer.t()?.contiguous()?,
                    &attention_probs,
                    Some(&query_layer),
                    None,
                    None,
                    None,
                    None,
                )?;
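                // Restore the [batch, heads, seq, head_size] layout.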
                context_layer.reshape((
                    batch_size,
                    self.num_attention_heads,
                    seq_len,
                    self.attention_head_size,
                ))
            }
            #[cfg(not(feature = "cuda"))]
            {
                candle::bail!("`cuda` feature is not enabled")
            }
        } else {
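            // Fallback path: QK^T, scale, add the mask, softmax, then weight the values.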
            let attn_weights = query_layer.matmul(&key_layer.t()?)?;
            let attn_weights = (attn_weights * self.softmax_scale)?;
            let attn_weights = attn_weights.add(attention_mask)?;
            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
            attn_weights.matmul(&value_layer.contiguous()?)
        }?;
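    // Merge the heads back into the hidden size and apply the output projection.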
    let hidden_states = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?;
    let hidden_states = self.wo.forward(&hidden_states)?;
    Ok(hidden_states)
}