fn multistep_uni_p_bh_update()

in candle-transformers/src/models/stable_diffusion/uni_pc.rs [383:483]


    /// One multistep UniP (predictor) update using the B(h) parameterization.
    ///
    /// Given the current `sample` x_{s0} at `timestep` and the cached model
    /// outputs from previous steps, this predicts the sample at the next
    /// timestep in the schedule. It builds the linear system R * rhos = b from
    /// the lambda-ratios of the cached outputs, solves for the correction
    /// weights `rhos_p`, and applies the correction via a tensordot against the
    /// stacked difference terms D1.
    ///
    /// Returns an error if the expected cached model outputs are missing.
    fn multistep_uni_p_bh_update(&self, sample: &Tensor, timestep: usize) -> Result<Tensor> {
        let step_index = self.step_index(timestep);
        let ns = &self.schedule;
        let model_outputs = self.state.model_outputs();
        // m0 is the most recent model output; the predictor cannot run without it.
        let Some(m0) = &model_outputs[model_outputs.len() - 1] else {
            return Err(Error::Msg(
                "Expected model output for predictor update".to_string(),
            ));
        };

        // t0 is the current timestep, tt the next one in the sampling schedule.
        let (t0, tt) = (timestep, self.timestep(self.step_index(timestep) + 1));
        let (sigma_t, sigma_s0) = (ns.sigma_t(tt), ns.sigma_t(t0));
        let (alpha_t, _alpha_s0) = (ns.alpha_t(tt), ns.alpha_t(t0));
        let (lambda_t, lambda_s0) = (ns.lambda_t(tt), ns.lambda_t(t0));

        // h is the step size in lambda (half-log-SNR) space.
        let h = lambda_t - lambda_s0;
        let device = sample.device();

        // Collect the lambda-ratio coefficients rk and the scaled differences
        // D1_i = (m_i - m_0) / rk for each cached previous model output.
        let (mut rks, mut d1s) = (vec![], vec![]);
        for i in 1..self.state.order() {
            let ti = self.timestep(step_index.saturating_sub(i + 1));
            let Some(mi) = model_outputs
                .get(model_outputs.len().saturating_sub(i + 1))
                .into_iter()
                .flatten()
                .next()
            else {
                return Err(Error::Msg(
                    "Expected model output for predictor update".to_string(),
                ));
            };
            let (alpha_si, sigma_si) = (ns.alpha_t(ti), ns.sigma_t(ti));
            let lambda_si = alpha_si.ln() - sigma_si.ln();
            let rk = (lambda_si - lambda_s0) / h;
            rks.push(rk);
            d1s.push(((mi - m0)? / rk)?);
        }
        rks.push(1.0);
        let rks = Tensor::new(rks, device)?;
        let (mut r, mut b) = (vec![], vec![]);

        let hh = h.neg();
        // phi_1(h) = (e^h - 1) / h, expressed via exp_m1 for numerical stability.
        let h_phi_1 = hh.exp_m1();
        let mut h_phi_k = h_phi_1 / hh - 1.;
        let mut factorial_i = 1.;

        let b_h = match self.config.solver_type {
            SolverType::Bh1 => hh,
            SolverType::Bh2 => hh.exp_m1(),
        };

        // Build R (Vandermonde-like rows of rks powers) and b (phi-function
        // coefficients) for the UniP linear system.
        for i in 1..self.state.order() + 1 {
            r.push(rks.powf(i as f64 - 1.)?);
            b.push(h_phi_k * factorial_i / b_h);
            // Accumulate the factorial: 1, 2, 6, 24, ... (i.e. (i+1)!).
            // A plain assignment here would silently break orders >= 3.
            factorial_i *= i as f64 + 1.;
            h_phi_k = h_phi_k / hh - 1. / factorial_i;
        }

        let (r, b) = (Tensor::stack(&r, 0)?, Tensor::new(b, device)?);
        let (d1s, rhos_p) = match d1s.len() {
            0 => (None, None),
            _ => {
                let rhos_p = match self.state.order() {
                    // Order 2 has the closed-form solution rhos_p = [0.5].
                    2 => Tensor::new(&[0.5f64], m0.device())?.to_dtype(m0.dtype())?,
                    _ => {
                        // Solve R[:-1, :-1] * rhos_p = b[:-1] via an explicit
                        // inverse followed by a matrix-vector product
                        // (broadcast-multiply + row sum).
                        let ((r1, r2), b1) = (r.dims2()?, b.dims1()?);
                        let inverse = linalg::inverse(&r.i((..(r1 - 1), ..(r2 - 1)))?)?;
                        let b = b.i(..(b1 - 1))?;
                        b.broadcast_mul(&inverse)?.sum(1)?.to_dtype(m0.dtype())?
                    }
                };

                (Some(Tensor::stack(&d1s, 1)?), Some(rhos_p))
            }
        };

        // First-order part of the update: scale the sample and subtract the
        // phi_1-weighted current model output.
        let x_t_ = ((sigma_t / sigma_s0 * sample)? - (alpha_t * h_phi_1 * m0)?)?;
        if let (Some(d1s), Some(rhos_p)) = (d1s, rhos_p) {
            use linalg::{Permutation, TensordotFixedPosition, TensordotGeneral};
            // pred_res = sum_i rhos_p[i] * D1_i, computed as a tensordot that
            // contracts rhos_p against the stacking axis (dim 1) of d1s.
            let output_shape = m0.shape().clone();
            let pred_res = TensordotGeneral {
                lhs_permutation: Permutation { dims: vec![0] },
                rhs_permutation: Permutation {
                    dims: vec![1, 0, 2, 3, 4],
                },
                tensordot_fixed_position: TensordotFixedPosition {
                    len_uncontracted_lhs: 1,
                    len_uncontracted_rhs: output_shape.dims().iter().product::<usize>(),
                    len_contracted_axes: d1s.dim(1)?,
                    output_shape,
                },
                output_permutation: Permutation {
                    dims: vec![0, 1, 2, 3],
                },
            }
            .eval(&rhos_p, &d1s)?;
            // Higher-order correction weighted by B(h).
            x_t_ - (alpha_t * b_h * pred_res)?
        } else {
            // No history available yet: fall back to the first-order update.
            Ok(x_t_)
        }
    }