in candle-transformers/src/models/stable_diffusion/uni_pc.rs [485:598]
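/// UniC corrector update in the B(h) formulation of UniPC (Zhao et al., "UniPC:
/// A Unified Predictor-Corrector Framework for Fast Sampling of Diffusion
/// Models"). Refines `last_sample`, the result of the previous predictor step,
/// using `model_output` evaluated at the current timestep together with up to
/// `order - 1` stored outputs from `model_outputs`.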
fn multistep_uni_c_bh_update(
    &self,
    model_output: &Tensor,
    model_outputs: &[Option<Tensor>],
    last_sample: &Tensor,
    sample: &Tensor,
    timestep: usize,
) -> Result<Tensor> {
    let step_index = self.step_index(timestep);
    // m0: the most recent stored model output, i.e. the one at the previous timestep s0.
    let Some(m0) = model_outputs.last().and_then(|m| m.as_ref()) else {
        return Err(Error::Msg(
            "Expected model output for corrector update".to_string(),
        ));
    };
    // model_t: the fresh model output evaluated at the current timestep.
    let model_t = model_output;
    let (x, _xt) = (last_sample, sample);
    // t0: the previous timestep s0; tt: the current timestep; ns: the noise schedule.
    let (t0, tt, ns) = (self.timestep(step_index - 1), timestep, &self.schedule);
    let (sigma_t, sigma_s0) = (ns.sigma_t(tt), ns.sigma_t(t0));
    let (alpha_t, _alpha_s0) = (ns.alpha_t(tt), ns.alpha_t(t0));
    let (lambda_t, lambda_s0) = (ns.lambda_t(tt), ns.lambda_t(t0));
    // Step size in log-SNR space, with lambda(t) = ln(alpha_t) - ln(sigma_t).
    let h = lambda_t - lambda_s0;
    let device = sample.device();
    // Gather the ratios r_k = (lambda_{s_k} - lambda_{s_0}) / h and the scaled
    // differences D1_k = (m_k - m0) / r_k from up to `order - 1` earlier outputs.
    let (mut rks, mut d1s) = (vec![], vec![]);
    for i in 1..self.state.order() {
        let ti = self.timestep(step_index.saturating_sub(i + 1));
        let Some(mi) = model_outputs
            .get(model_outputs.len().saturating_sub(i + 1))
            .and_then(|m| m.as_ref())
        else {
            return Err(Error::Msg(
                "Expected model output for corrector update".to_string(),
            ));
        };
        let (alpha_si, sigma_si) = (ns.alpha_t(ti), ns.sigma_t(ti));
        let lambda_si = alpha_si.ln() - sigma_si.ln();
        let rk = (lambda_si - lambda_s0) / h;
        rks.push(rk);
        d1s.push(((mi - m0)? / rk)?);
    }
    rks.push(1.0);
    let rks = Tensor::new(rks, device)?;
    let (mut r, mut b) = (vec![], vec![]);
    let hh = h.neg();
    // h_phi_1 = hh * phi_1(hh) = e^hh - 1; h_phi_k tracks hh * phi_{k+1}(hh)
    // through the recurrence phi_{k+1}(h) = (phi_k(h) - 1/k!) / h.
    let h_phi_1 = hh.exp_m1();
    let mut h_phi_k = h_phi_1 / hh - 1.;
    let mut factorial_i = 1.;
    let b_h = match self.config.solver_type {
        SolverType::Bh1 => hh,
        SolverType::Bh2 => hh.exp_m1(),
    };
    // Assemble the linear system R * rhos_c = b: row i of R is rks^(i - 1) and
    // b[i - 1] = hh * phi_{i+1}(hh) * i! / B(h).
    for i in 1..self.state.order() + 1 {
        r.push(rks.powf(i as f64 - 1.)?);
        b.push(h_phi_k * factorial_i / b_h);
        // The factorial accumulates multiplicatively: (i + 1)! after this line.
        factorial_i *= i as f64 + 1.;
        h_phi_k = h_phi_k / hh - 1. / factorial_i;
    }
    let (r, b) = (Tensor::stack(&r, 0)?, Tensor::new(b, device)?);
    // Stack the differences along a new axis 1, giving (batch, k, ...); there are
    // none at order 1.
    let d1s = match d1s.len() {
        0 => None,
        _ => Some(Tensor::stack(&d1s, 1)?),
    };
    // Solve R * rhos_c = b: order 1 has the closed-form solution [0.5]; otherwise
    // rhos_c = R^{-1} b, computed as a broadcast multiply followed by a row sum.
    let rhos_c = match self.state.order() {
        1 => Tensor::new(&[0.5f64], m0.device())?.to_dtype(m0.dtype())?,
        _ => {
            let inverse = linalg::inverse(&r)?;
            b.broadcast_mul(&inverse)?.sum(1)?.to_dtype(m0.dtype())?
        }
    };
    // Base update in data-prediction form:
    // x_t_ = (sigma_t / sigma_s0) * x - alpha_t * (e^{-h} - 1) * m0.
    let x_t_ = ((sigma_t / sigma_s0 * x)? - (alpha_t * h_phi_1 * m0)?)?;
    // corr_res = sum_k rhos_c[k] * D1_k over all but the last coefficient, i.e.
    // an einsum of the form "k,bkc...->bc..." expressed as a tensordot on axis k.
    let corr_res = d1s
        .map(|d1s| {
            use linalg::{Permutation, TensordotFixedPosition, TensordotGeneral};
            let output_shape = x_t_.shape().clone();
            TensordotGeneral {
                lhs_permutation: Permutation { dims: vec![0] },
                // Move the stacked k axis of (batch, k, c, h, w) to the front.
                rhs_permutation: Permutation {
                    dims: vec![1, 0, 2, 3, 4],
                },
                tensordot_fixed_position: TensordotFixedPosition {
                    len_uncontracted_lhs: 1,
                    len_uncontracted_rhs: output_shape.dims().iter().product::<usize>(),
                    len_contracted_axes: d1s.dim(1)?,
                    output_shape,
                },
                output_permutation: Permutation {
                    dims: vec![0, 1, 2, 3],
                },
            }
            .eval(&rhos_c.i(..rhos_c.dims()[0] - 1)?, &d1s)
        })
        .unwrap_or_else(|| Tensor::zeros_like(m0))?;
    // D1_t is the difference against the *fresh* model output at the current
    // timestep; using it is what makes UniC a corrector rather than a predictor.
    let d1_t = (model_t - m0)?;
    // x_t = x_t_ - alpha_t * B(h) * (corr_res + rhos_c[-1] * D1_t).
    let x_t = (x_t_
        - (alpha_t
            * b_h
            * (corr_res + rhos_c.i(rhos_c.dims()[0] - 1)?.broadcast_mul(&d1_t)?)?)?)?;
    Ok(x_t)
}
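
// A minimal standalone sketch of the coefficient solve above, specialized to
// order 2 and assuming SolverType::Bh2; the function name and the `hh`/`r1`
// parameters (negated log-SNR step and the single r_k ratio) are hypothetical,
// for illustration only. It spells out in plain f64 arithmetic what
// `linalg::inverse` + `broadcast_mul` + `sum(1)` compute on tensors.
fn rhos_c_order2_sketch(hh: f64, r1: f64) -> [f64; 2] {
    let h_phi_1 = hh.exp_m1();
    let mut h_phi_k = h_phi_1 / hh - 1.;
    let mut factorial_i = 1.;
    let b_h = hh.exp_m1(); // SolverType::Bh2
    let mut b = [0.0f64; 2];
    for i in 1..=2usize {
        b[i - 1] = h_phi_k * factorial_i / b_h;
        factorial_i *= i as f64 + 1.;
        h_phi_k = h_phi_k / hh - 1. / factorial_i;
    }
    // R = [[1, 1], [r1, 1]]: its rows are rks^0 and rks^1 with rks = [r1, 1.0].
    // rhos_c = R^{-1} b via the closed-form 2x2 inverse, with det(R) = 1 - r1.
    let det = 1. - r1;
    [(b[0] - b[1]) / det, (-r1 * b[0] + b[1]) / det]
}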