fn relative_position_bucket()

in backends/candle/src/models/mpnet.rs [330:361]


    /// Maps signed relative positions to bucket indices for the relative attention bias.
    ///
    /// `self.relative_attention_num_buckets` is split evenly between the two directions:
    /// elements with a positive `relative_position` are shifted into the upper half.
    /// Within a half, magnitudes smaller than `max_exact` (half of the per-direction
    /// bucket count, rounded down) each get their own bucket; larger magnitudes are
    /// spread logarithmically relative to `max_distance` and capped at the last bucket
    /// of that half.
    ///
    /// NOTE(review): this is intended to match Hugging Face's MPNet
    /// `relative_position_bucket`, which uses floor division for `max_exact`.
    ///
    /// # Arguments
    /// * `relative_position` - integer offsets (any dtype convertible to `F32`).
    /// * `max_distance` - distance used to scale the logarithmic buckets.
    ///
    /// # Returns
    /// An `I64` tensor of bucket indices with the same shape as `relative_position`.
    ///
    /// # Errors
    /// Propagates any candle tensor-operation error.
    fn relative_position_bucket(
        &self,
        relative_position: &Tensor,
        max_distance: i64,
    ) -> Result<Tensor> {
        let device = relative_position.device();

        // Half of the configured buckets are used for each direction of the offset.
        let num_buckets = (self.relative_attention_num_buckets / 2) as f64;
        // Floor division: an odd per-direction bucket count must not yield a
        // fractional boundary between the exact and logarithmic ranges.
        let max_exact = (num_buckets / 2.0).floor();
        let max_distance_log = (max_distance as f64 / max_exact).ln();
        let scale = (num_buckets - max_exact) / max_distance_log;

        let n = relative_position.to_dtype(DType::F32)?.neg()?;

        // `n < 0` means a positive relative position: those use the upper half.
        let direction_offset = (n.lt(0.0)?.to_dtype(DType::F32)? * num_buckets)?
            .to_dtype(DType::I64)?;
        let n = n.abs()?;

        let is_small = n.lt(max_exact)?;

        // The logarithmic branch is only selected where `n >= max_exact`, so clamping
        // the magnitude up to `max_exact` first changes no selected value but keeps
        // `log(0) = -inf` from ever being cast to an integer.
        let lower = Tensor::full(max_exact as f32, n.shape(), device)?;
        let log_val = (n.maximum(&lower)? / max_exact)?.log()?;
        let val_if_large = (max_exact + (log_val * scale)?)?;

        let upper = Tensor::full((num_buckets - 1.0) as f32, val_if_large.shape(), device)?;
        let val_if_large = val_if_large.minimum(&upper)?.to_dtype(DType::I64)?;

        let bucket = is_small.where_cond(&n.to_dtype(DType::I64)?, &val_if_large)?;
        direction_offset.add(&bucket)
    }