fn get_global_attention_mask()

in backends/candle/src/models/modernbert.rs [553:575]
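
Builds the broadcast attention mask used by the global-attention layers: the
per-token padding mask (or an all-ones mask when none is provided) is expanded
to (bs, num_attention_heads, seq_len, seq_len).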


    fn get_global_attention_mask(
        &self,
        attention_mask: Option<&Tensor>,
        input_shape: &(usize, usize),
    ) -> Result<Tensor> {
        // Start from the padding mask, shaped (bs, seq_len, 1); when no mask is
        // provided, every position is attendable.
        let extended_attention_mask = if let Some(attention_mask) = attention_mask {
            attention_mask.squeeze(2)?
        } else {
            Tensor::ones(*input_shape, DType::F32, &self.device)?
        }
        // (bs, seq_len) -> (bs, 1, 1, seq_len)
        .unsqueeze(1)?
        .unsqueeze(1)?;

        // Broadcast across attention heads and query positions:
        // (bs, 1, 1, seq_len) -> (bs, num_attention_heads, seq_len, seq_len)
        let (bs, seq_len) = *input_shape;
        let extended_attention_mask = extended_attention_mask.broadcast_as((
            bs,
            self.num_attention_heads,
            seq_len,
            seq_len,
        ))?;

        Ok(extended_attention_mask)
    }
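
The returned tensor is a broadcast copy of the 0/1 padding mask, not yet an
additive bias. The sketch below is not part of modernbert.rs; it only shows how
a caller might fold such a mask into attention scores before the softmax. The
helper name apply_attention_mask, the scores argument, and the -1e9 fill value
are illustrative assumptions, and the candle / candle_nn imports are assumed to
match the surrounding crate.

    use candle::{D, Result, Tensor};
    use candle_nn::ops::softmax;

    // Hypothetical helper (not present in the source): assumes `mask` holds
    // 1.0 for attendable positions and 0.0 for padding, matching the output
    // of get_global_attention_mask above.
    fn apply_attention_mask(scores: &Tensor, mask: &Tensor) -> Result<Tensor> {
        // (1 - mask) * -1e9: 0.0 where attention is allowed, a large negative
        // bias where it is masked, so those positions vanish after the softmax.
        let bias = mask.affine(-1.0, 1.0)?.affine(-1e9, 0.0)?;
        let masked = scores.broadcast_add(&bias)?;
        softmax(&masked, D::Minus1)
    }

Using a large finite negative value rather than f32::NEG_INFINITY helps keep a
fully masked row from turning into NaNs after the softmax.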