in backends/candle/src/models/mpnet.rs [330:361]
/// Maps signed relative positions to bucket ids for the relative attention bias.
///
/// The available buckets are split into two halves of
/// `self.relative_attention_num_buckets / 2` each, selected by sign:
/// entries where `relative_position > 0` are offset into the upper half.
/// Within a half:
/// - distances smaller than `max_exact` (a quarter of the total, floor-divided) each get
///   their own bucket;
/// - larger distances are spaced logarithmically. Distances at or beyond `max_distance`
///   all land in the last bucket of the half.
///
/// Returns an `I64` tensor with the same shape as `relative_position`.
///
/// NOTE(review): assumes `max_exact > 0` (at least 4 buckets) and
/// `max_distance > max_exact`; otherwise the log scale below degenerates.
fn relative_position_bucket(
    &self,
    relative_position: &Tensor,
    max_distance: i64,
) -> Result<Tensor> {
    let device = relative_position.device();

    // Both halvings use integer (floor) division, matching the reference bucketing.
    // With an odd half-count, a float division would put the small/large split at a
    // fractional distance and shift every log bucket.
    let half_buckets = self.relative_attention_num_buckets / 2;
    let num_buckets = half_buckets as f64;
    let max_exact = (half_buckets / 2) as f64;

    // Slope of the logarithmic region: maps ln(n / max_exact) in [0, ln(max_distance / max_exact)]
    // onto buckets [max_exact, num_buckets].
    let scale = (num_buckets - max_exact) / (max_distance as f64 / max_exact).ln();

    // Work in f32 so that abs/log are available. `n` is the negated relative position.
    let n = relative_position.to_dtype(DType::F32)?.neg()?;

    // n < 0 (i.e. relative_position > 0) selects the upper half of the buckets.
    // Starting from this I64 tensor, rather than `zeros_like(relative_position)`, keeps the
    // result I64 regardless of the input dtype.
    let sign_offset =
        (n.lt(0.0)?.to_dtype(DType::F32)? * num_buckets)?.to_dtype(DType::I64)?;

    let n = n.abs()?;
    let is_small = n.lt(max_exact)?;

    // For n == 0 the log is -inf. Those entries are in the `is_small` region
    // (given max_exact > 0), so `where_cond` below discards them.
    let log_val = (&n / max_exact)?.log()?;
    let val_if_large = (max_exact + (log_val * scale)?)?;

    // Clamp into the last bucket of the half before truncating to integers.
    let last_bucket = Tensor::full((num_buckets - 1.0) as f32, val_if_large.shape(), device)?;
    let val_if_large = val_if_large
        .minimum(&last_bucket)?
        .to_dtype(DType::I64)?;

    let bucket_in_half = is_small.where_cond(&n.to_dtype(DType::I64)?, &val_if_large)?;
    sign_offset.add(&bucket_in_half)
}