in backends/candle/src/models/flash_modernbert.rs [161:188]
/// Runs one encoder layer: an attention sub-block followed by an MLP sub-block,
/// each wrapped in a residual connection.
///
/// `cu_seqlens`, `cos`, `sin` and `max_s` are not interpreted here; they are
/// passed straight through to `self.attn.forward`.
/// NOTE(review): presumably these are the cumulative sequence lengths, rotary
/// tables and max sequence length of a packed batch — confirm against the
/// attention implementation.
///
/// # Errors
///
/// Propagates any error returned by the normalization, attention, MLP or
/// tensor addition calls.
fn forward(
    &self,
    hidden_states: &Tensor,
    cu_seqlens: &Tensor,
    cos: &Tensor,
    sin: &Tensor,
    max_s: usize,
) -> Result<Tensor> {
    let _enter = self.span.enter();

    // Attention sub-block: the attention norm is optional, so when it is absent
    // the raw input is fed to attention unchanged.
    let attn_input = match &self.attn_norm {
        Some(norm) => norm.forward(hidden_states, None)?,
        None => hidden_states.clone(),
    };
    let attn_out = self
        .attn
        .forward(&attn_input, cu_seqlens, cos, sin, max_s)?;

    // First residual: the un-normalized input plus the attention output.
    let post_attn = hidden_states.add(&attn_out)?;

    // MLP sub-block: normalize, project, then add the second residual. The
    // normalized tensor is scoped so it is released before the final add.
    let mlp_out = {
        let mlp_input = self.mlp_norm.forward(&post_attn, None)?;
        self.mlp.forward(&mlp_input)?
    };
    post_attn.add(&mlp_out)
}