fn disentangled_attention_bias()

in candle-transformers/src/models/debertav2.rs [544:700]


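    /// Computes the DeBERTa-v2 disentangled attention bias: the sum of the
    /// content-to-position ("c2p") and position-to-content ("p2c") relative-position
    /// scores enabled in `pos_att_type`.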
    fn disentangled_attention_bias(
        &self,
        query_layer: Tensor,
        key_layer: Tensor,
        relative_pos: Option<&Tensor>,
        rel_embeddings: Tensor,
        scale_factor: usize,
    ) -> Result<Tensor> {
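        // If no relative positions were provided, build them from the query and key
        // sequence lengths.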
        let mut relative_pos = match relative_pos {
            Some(pos) => pos.clone(),
            None => build_relative_position(
                query_layer.dim(D::Minus2)?,
                key_layer.dim(D::Minus2)?,
                &self.device,
                Some(self.position_buckets),
                Some(self.max_relative_positions),
            )?,
        };

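        // Bring `relative_pos` to a 4-D (batch, head, query, key) layout, inserting
        // singleton dimensions where needed.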
        relative_pos = match relative_pos.dims().len() {
            2 => relative_pos.unsqueeze(0)?.unsqueeze(0)?,
            3 => relative_pos.unsqueeze(1)?,
            // A 4-D tensor already has the expected layout and is used as-is.
            4 => relative_pos,
            other => {
                bail!("Relative position ids must be of dim 2 or 3 or 4. Got dim of size {other}")
            }
        };

        let att_span = self.pos_ebd_size;

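        // Keep the first 2 * att_span relative-position embeddings and add a batch dimension.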
        let rel_embeddings = rel_embeddings
            .narrow(0, 0, (att_span * 2) as usize)?
            .unsqueeze(0)?;

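        // Project the relative-position embeddings into per-head key/query layouts and
        // tile them along the leading (batch x head) axis to match the content layers.
        // With `share_att_key` the content projections are reused; otherwise the
        // dedicated `pos_key_proj` / `pos_query_proj` are used for the enabled terms.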
        let mut pos_query_layer: Option<Tensor> = None;
        let mut pos_key_layer: Option<Tensor> = None;

        let repeat_with = query_layer.dim(0)? / self.num_attention_heads;
        if self.share_att_key {
            pos_query_layer = Some(
                self.transpose_for_scores(&self.query_proj.forward(&rel_embeddings)?)?
                    .repeat(repeat_with)?,
            );

            pos_key_layer = Some(
                self.transpose_for_scores(&self.key_proj.forward(&rel_embeddings)?)?
                    .repeat(repeat_with)?,
            )
        } else {
            if self.config.pos_att_type.iter().any(|s| s == "c2p") {
                pos_key_layer = Some(
                    self.transpose_for_scores(
                        &self
                            .pos_key_proj
                            .as_ref()
                            .context(
                                "Need pos_key_proj when share_att_key is false or not specified",
                            )?
                            .forward(&rel_embeddings)?,
                    )?
                    .repeat(repeat_with)?,
                )
            }
            if self.config.pos_att_type.iter().any(|s| s == "p2c") {
                pos_query_layer = Some(
                    self.transpose_for_scores(
                        &self
                            .pos_query_proj
                            .as_ref()
                            .context("Need a pos_query_proj when share_att_key is false or not specified")?
                            .forward(&rel_embeddings)?,
                    )?
                    .repeat(repeat_with)?,
                )
            }
        }

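        // Accumulate the enabled bias terms, starting from a zero tensor that
        // broadcasts over the attention scores.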
        let mut score = Tensor::new(&[0f32], &self.device)?;

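        // Content-to-position term: the query content attends to relative-position keys.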
        if self.config.pos_att_type.iter().any(|s| s == "c2p") {
            let pos_key_layer = pos_key_layer.context("c2p without pos_key_layer")?;

            let scale = Tensor::new(
                &[(pos_key_layer.dim(D::Minus1)? * scale_factor) as f32],
                &self.device,
            )?
            .sqrt()?;

            let mut c2p_att = query_layer.matmul(&pos_key_layer.t()?)?;

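            // Shift the relative positions by att_span and clamp to [0, 2 * att_span)
            // so they can serve as gather indices along the position axis.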
            let c2p_pos = relative_pos
                .broadcast_add(&Tensor::new(&[att_span as i64], &self.device)?)?
                .clamp(0f32, (att_span * 2 - 1) as f32)?;

            c2p_att = c2p_att.gather(
                &c2p_pos
                    .squeeze(0)?
                    .expand(&[
                        query_layer.dim(0)?,
                        query_layer.dim(1)?,
                        relative_pos.dim(D::Minus1)?,
                    ])?
                    .contiguous()?,
                D::Minus1,
            )?;

            score = score.broadcast_add(
                &c2p_att.broadcast_div(scale.to_dtype(c2p_att.dtype())?.as_ref())?,
            )?;
        }

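        // Position-to-content term: relative-position queries attend to the key content.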
        if self.config.pos_att_type.iter().any(|s| s == "p2c") {
            let pos_query_layer = pos_query_layer.context("p2c without pos_query_layer")?;

            let scale = Tensor::new(
                &[(pos_query_layer.dim(D::Minus1)? * scale_factor) as f32],
                &self.device,
            )?
            .sqrt()?;

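            // When the query and key lengths differ, rebuild the relative positions over
            // the key length; otherwise reuse the ones computed above.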
            let r_pos = {
                if key_layer.dim(D::Minus2)? != query_layer.dim(D::Minus2)? {
                    build_relative_position(
                        key_layer.dim(D::Minus2)?,
                        key_layer.dim(D::Minus2)?,
                        &self.device,
                        Some(self.position_buckets),
                        Some(self.max_relative_positions),
                    )?
                    .unsqueeze(0)?
                } else {
                    relative_pos
                }
            };

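            // p2c uses the mirrored positions: negate, shift by att_span, and clamp into
            // [0, 2 * att_span) to form gather indices.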
            let p2c_pos = r_pos
                .to_dtype(DType::F32)?
                .neg()?
                .broadcast_add(&Tensor::new(&[att_span as f32], &self.device)?)?
                .clamp(0f32, (att_span * 2 - 1) as f32)?;

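            // Score the keys against the position queries, gather at the mirrored
            // positions, and transpose so the last two dims align with (query, key).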
            let p2c_att = key_layer
                .matmul(&pos_query_layer.t()?)?
                .gather(
                    &p2c_pos
                        .squeeze(0)?
                        .expand(&[
                            query_layer.dim(0)?,
                            key_layer.dim(D::Minus2)?,
                            key_layer.dim(D::Minus2)?,
                        ])?
                        .contiguous()?
                        .to_dtype(DType::U32)?,
                    D::Minus1,
                )?
                .t()?;

            score =
                score.broadcast_add(&p2c_att.broadcast_div(&scale.to_dtype(p2c_att.dtype())?)?)?;
        }

        Ok(score)
    }