in candle-transformers/src/models/debertav2.rs [544:700]
fn disentangled_attention_bias(
&self,
query_layer: Tensor,
key_layer: Tensor,
relative_pos: Option<&Tensor>,
rel_embeddings: Tensor,
scale_factor: usize,
) -> Result<Tensor> {
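        // Disentangled attention bias as in DeBERTa: a content-to-position (c2p) and a
        // position-to-content (p2c) term computed from relative position embeddings.
        // If the caller did not provide relative positions, build them here.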
        let mut relative_pos = match relative_pos {
            Some(pos) => pos.clone(),
            None => build_relative_position(
                query_layer.dim(D::Minus2)?,
                key_layer.dim(D::Minus2)?,
                &self.device,
                Some(self.position_buckets),
                Some(self.max_relative_positions),
            )?,
        };
        relative_pos = match relative_pos.dims().len() {
            2 => relative_pos.unsqueeze(0)?.unsqueeze(0)?,
            3 => relative_pos.unsqueeze(1)?,
            4 => relative_pos,
            other => {
                bail!("Relative position ids must be of dim 2, 3 or 4, got a tensor of dim {other}")
            }
        };
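        // Only 2 * att_span rows of the relative embedding table are addressable after
        // clamping; slice the table down and add a leading batch dimension.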
let att_span = self.pos_ebd_size;
let rel_embeddings = rel_embeddings
.narrow(0, 0, (att_span * 2) as usize)?
.unsqueeze(0)?;
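        // The positional projections are repeated across the batch (query_layer is flattened
        // to (batch * num_heads, seq_len, head_dim)) so their leading dimension lines up with
        // the content layers. When share_att_key is set, the content query/key projections
        // are reused for the relative embeddings.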
let mut pos_query_layer: Option<Tensor> = None;
let mut pos_key_layer: Option<Tensor> = None;
let repeat_with = query_layer.dim(0)? / self.num_attention_heads;
if self.share_att_key {
pos_query_layer = Some(
self.transpose_for_scores(&self.query_proj.forward(&rel_embeddings)?)?
.repeat(repeat_with)?,
);
pos_key_layer = Some(
self.transpose_for_scores(&self.key_proj.forward(&rel_embeddings)?)?
.repeat(repeat_with)?,
)
} else {
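            // Without a shared attention key, only build the positional projections that the
            // configured attention types need: c2p uses a key projection, p2c a query projection.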
if self.config.pos_att_type.iter().any(|s| s == "c2p") {
pos_key_layer = Some(
self.transpose_for_scores(
&self
.pos_key_proj
.as_ref()
.context(
"Need pos_key_proj when share_att_key is false or not specified",
)?
.forward(&rel_embeddings)?,
)?
.repeat(repeat_with)?,
)
}
            if self.config.pos_att_type.iter().any(|s| s == "p2c") {
                pos_query_layer = Some(
                    self.transpose_for_scores(
                        &self
                            .pos_query_proj
                            .as_ref()
                            .context(
                                "Need a pos_query_proj when share_att_key is false or not specified",
                            )?
                            .forward(&rel_embeddings)?,
                    )?
                    .repeat(repeat_with)?,
                )
            }
}
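        // Accumulate whichever bias terms are enabled in pos_att_type into a zero score.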
        let mut score = Tensor::new(&[0f32], &self.device)?;
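        // Content-to-position (c2p): content queries attend to relative-position keys. The
        // relative distances are shifted by att_span and clamped to [0, 2 * att_span - 1] so
        // they can be used as gather indices over the last dimension of the scores.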
if self.config.pos_att_type.iter().any(|s| s == "c2p") {
let pos_key_layer = pos_key_layer.context("c2p without pos_key_layer")?;
let scale = Tensor::new(
&[(pos_key_layer.dim(D::Minus1)? * scale_factor) as f32],
&self.device,
)?
.sqrt()?;
let mut c2p_att = query_layer.matmul(&pos_key_layer.t()?)?;
let c2p_pos = relative_pos
.broadcast_add(&Tensor::new(&[att_span as i64], &self.device)?)?
                .clamp(0f32, (att_span * 2 - 1) as f32)?;
c2p_att = c2p_att.gather(
&c2p_pos
.squeeze(0)?
.expand(&[
query_layer.dim(0)?,
query_layer.dim(1)?,
relative_pos.dim(D::Minus1)?,
])?
.contiguous()?,
D::Minus1,
)?;
score = score.broadcast_add(
&c2p_att.broadcast_div(scale.to_dtype(c2p_att.dtype())?.as_ref())?,
)?;
}
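        // Position-to-content (p2c): content keys attend to relative-position queries. The
        // distances are negated (key-to-query distances mirror query-to-key ones) before the
        // same shift-and-clamp, and the gathered scores are transposed back to (query, key).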
if self.config.pos_att_type.iter().any(|s| s == "p2c") {
            let pos_query_layer = pos_query_layer.context("p2c without pos_query_layer")?;
let scale = Tensor::new(
&[(pos_query_layer.dim(D::Minus1)? * scale_factor) as f32],
&self.device,
)?
.sqrt()?;
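            // When the query and key lengths differ, rebuild the relative positions over the
            // key length; otherwise reuse the ones computed above.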
let r_pos = {
if key_layer.dim(D::Minus2)? != query_layer.dim(D::Minus2)? {
build_relative_position(
key_layer.dim(D::Minus2)?,
key_layer.dim(D::Minus2)?,
&self.device,
Some(self.position_buckets),
Some(self.max_relative_positions),
)?
.unsqueeze(0)?
} else {
relative_pos
}
};
let p2c_pos = r_pos
.to_dtype(DType::F32)?
.neg()?
.broadcast_add(&Tensor::new(&[att_span as f32], &self.device)?)?
.clamp(0f32, (att_span * 2 - 1) as f32)?;
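            // p2c_pos went through f32 arithmetic, so cast it back to U32 before using it as
            // gather indices.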
let p2c_att = key_layer
.matmul(&pos_query_layer.t()?)?
.gather(
&p2c_pos
.squeeze(0)?
.expand(&[
query_layer.dim(0)?,
key_layer.dim(D::Minus2)?,
key_layer.dim(D::Minus2)?,
])?
.contiguous()?
.to_dtype(DType::U32)?,
D::Minus1,
)?
.t()?;
score =
score.broadcast_add(&p2c_att.broadcast_div(&scale.to_dtype(p2c_att.dtype())?)?)?;
}
Ok(score)
}