fn multistep_uni_c_bh_update()

in candle-transformers/src/models/stable_diffusion/uni_pc.rs [485:598]


    /// One UniC corrector update (the "B(h)" form, Bh1/Bh2 variants) of the
    /// UniPC multistep scheduler.
    ///
    /// Refines the predictor result `last_sample` using the freshly evaluated
    /// `model_output` at the current `timestep` plus the history of previous
    /// model outputs.
    ///
    /// * `model_output` - model prediction at the current timestep (`m_t`).
    /// * `model_outputs` - history buffer of prior predictions; the most
    ///   recent entry (`m0`) must be `Some`, or an error is returned.
    /// * `last_sample` - sample from the preceding predictor step (`x`).
    /// * `sample` - current sample; only its device is used here.
    /// * `timestep` - current discrete timestep.
    ///
    /// Returns the corrected sample `x_t`, or an error when the required
    /// model-output history is missing or a tensor op fails.
    fn multistep_uni_c_bh_update(
        &self,
        model_output: &Tensor,
        model_outputs: &[Option<Tensor>],
        last_sample: &Tensor,
        sample: &Tensor,
        timestep: usize,
    ) -> Result<Tensor> {
        let step_index = self.step_index(timestep);
        // The corrector needs the most recent stored model output m0.
        let Some(m0) = model_outputs.last().into_iter().flatten().next() else {
            return Err(Error::Msg(
                "Expected model output for corrector update".to_string(),
            ));
        };
        let model_t = model_output;
        let (x, _xt) = (last_sample, sample);

        // t0 is the previous timestep, tt the current one; ns exposes the
        // noise-schedule quantities alpha_t, sigma_t and lambda_t.
        let (t0, tt, ns) = (
            self.timestep(step_index - 1),
            timestep,
            &self.schedule,
        );
        let (sigma_t, sigma_s0) = (ns.sigma_t(tt), ns.sigma_t(t0));
        let (alpha_t, _alpha_s0) = (ns.alpha_t(tt), ns.alpha_t(t0));
        let (lambda_t, lambda_s0) = (ns.lambda_t(tt), ns.lambda_t(t0));

        // Step size in lambda (half-log-SNR) space.
        let h = lambda_t - lambda_s0;
        let device = sample.device();

        // Build the normalized step ratios r_k and the scaled differences
        // D1_k = (m_k - m0) / r_k from the model-output history.
        let (mut rks, mut d1s) = (vec![], vec![]);
        for i in 1..self.state.order() {
            let ti = self.timestep(step_index.saturating_sub(i + 1));
            let Some(mi) = model_outputs
                .get(model_outputs.len().saturating_sub(i + 1))
                .into_iter()
                .flatten()
                .next()
            else {
                return Err(Error::Msg(
                    "Expected model output for corrector update".to_string(),
                ));
            };
            let (alpha_si, sigma_si) = (ns.alpha_t(ti), ns.sigma_t(ti));
            // lambda_si = ln(alpha) - ln(sigma), same quantity as ns.lambda_t.
            let lambda_si = alpha_si.ln() - sigma_si.ln();
            let rk = (lambda_si - lambda_s0) / h;
            rks.push(rk);
            d1s.push(((mi - m0)? / rk)?);
        }
        rks.push(1.0);
        let rks = Tensor::new(rks, device)?;
        let (mut r, mut b) = (vec![], vec![]);

        let hh = h.neg();
        // phi_1(hh) = exp(hh) - 1; h_phi_k follows the phi-function recursion.
        let h_phi_1 = hh.exp_m1();
        let mut h_phi_k = h_phi_1 / hh - 1.;
        let mut factorial_i = 1.;

        let b_h = match self.config.solver_type {
            SolverType::Bh1 => hh,
            SolverType::Bh2 => hh.exp_m1(),
        };

        // Assemble the Vandermonde-like matrix R (rows rks^(i-1)) and the
        // coefficient vector b used to solve for the corrector weights.
        for i in 1..self.state.order() + 1 {
            r.push(rks.powf(i as f64 - 1.)?);
            b.push(h_phi_k * factorial_i / b_h);
            // factorial_i accumulates (i+1)!: 1, 2, 6, 24, ...  Using `*=`
            // (not plain assignment) is essential for solver orders >= 3.
            factorial_i *= i as f64 + 1.;
            h_phi_k = h_phi_k / hh - 1. / factorial_i;
        }

        let (r, b) = (Tensor::stack(&r, 0)?, Tensor::new(b, device)?);
        let d1s = match d1s.len() {
            0 => None,
            _ => Some(Tensor::stack(&d1s, 1)?),
        };
        // Corrector weights rhos_c: closed form 0.5 at order 1, otherwise
        // obtained from the linear system via the inverse of R.
        let rhos_c = match self.state.order() {
            1 => Tensor::new(&[0.5f64], m0.device())?.to_dtype(m0.dtype())?,
            _ => {
                let inverse = linalg::inverse(&r)?;
                b.broadcast_mul(&inverse)?.sum(1)?.to_dtype(m0.dtype())?
            }
        };

        // x_t_ = (sigma_t / sigma_s0) * x - alpha_t * phi_1 * m0
        let x_t_ = ((sigma_t / sigma_s0 * x)? - (alpha_t * h_phi_1 * m0)?)?;
        // Contract the history differences with all but the last corrector
        // weight: corr_res = sum_k rhos_c[k] * D1_k. Falls back to zeros when
        // there is no history yet (first corrector step).
        let corr_res = d1s
            .map(|d1s| {
                use linalg::{Permutation, TensordotFixedPosition, TensordotGeneral};
                let output_shape = x_t_.shape().clone();
                TensordotGeneral {
                    lhs_permutation: Permutation { dims: vec![0] },
                    rhs_permutation: Permutation {
                        dims: vec![1, 0, 2, 3, 4],
                    },
                    tensordot_fixed_position: TensordotFixedPosition {
                        len_uncontracted_lhs: 1,
                        len_uncontracted_rhs: output_shape.dims().iter().product::<usize>(),
                        len_contracted_axes: d1s.dim(1)?,
                        output_shape,
                    },
                    output_permutation: Permutation {
                        dims: vec![0, 1, 2, 3],
                    },
                }
                .eval(&rhos_c.i(..rhos_c.dims()[0] - 1)?, &d1s)
            })
            .unwrap_or_else(|| Tensor::zeros_like(m0))?;

        // D1_t uses the fresh model evaluation at the current step; the last
        // corrector weight applies to it.
        let d1_t = (model_t - m0)?;
        let x_t = (x_t_
            - (alpha_t
                * b_h
                * (corr_res + rhos_c.i(rhos_c.dims()[0] - 1)?.broadcast_mul(&d1_t)?)?)?)?;

        Ok(x_t)
    }