in candle-transformers/src/models/stable_diffusion/uni_pc.rs [383:483]
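// UniPC predictor (UniP) update: extrapolates the next sample from the cached
// model outputs without an extra model evaluation. This mirrors the update of
// the same name in diffusers' UniPCMultistepScheduler (UniPC, Zhao et al., 2023).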
fn multistep_uni_p_bh_update(&self, sample: &Tensor, timestep: usize) -> Result<Tensor> {
let step_index = self.step_index(timestep);
let ns = &self.schedule;
let model_outputs = self.state.model_outputs();
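// m0 is the most recent cached model output (the data prediction at the
// current step).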
let Some(m0) = &model_outputs[model_outputs.len() - 1] else {
return Err(Error::Msg(
"Expected model output for predictor update".to_string(),
));
};
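// t0 is the current timestep and tt the timestep being stepped to; lambda is
// the half-logSNR log(alpha) - log(sigma), and h is the step size in lambda
// space.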
let (t0, tt) = (timestep, self.timestep(step_index + 1));
let (sigma_t, sigma_s0) = (ns.sigma_t(tt), ns.sigma_t(t0));
let (alpha_t, _alpha_s0) = (ns.alpha_t(tt), ns.alpha_t(t0));
let (lambda_t, lambda_s0) = (ns.lambda_t(tt), ns.lambda_t(t0));
let h = lambda_t - lambda_s0;
let device = sample.device();
let (mut rks, mut d1s) = (vec![], vec![]);
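// For each earlier cached output m_i, accumulate the lambda-space ratio
// r_i = (lambda_{s_i} - lambda_{s0}) / h and the scaled difference
// D1_i = (m_i - m0) / r_i that feeds the higher-order correction.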
for i in 1..self.state.order() {
let ti = self.timestep(step_index.saturating_sub(i + 1));
let Some(mi) = model_outputs
.get(model_outputs.len().saturating_sub(i + 1))
.into_iter()
.flatten()
.next()
else {
return Err(Error::Msg(
"Expected model output for predictor update".to_string(),
));
};
let (alpha_si, sigma_si) = (ns.alpha_t(ti), ns.sigma_t(ti));
let lambda_si = alpha_si.ln() - sigma_si.ln();
let rk = (lambda_si - lambda_s0) / h;
rks.push(rk);
d1s.push(((mi - m0)? / rk)?);
}
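// The ratio for the current step itself is 1 by definition.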
rks.push(1.0);
let rks = Tensor::new(rks, device)?;
let (mut r, mut b) = (vec![], vec![]);
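// hh = -h; h_phi_1 = e^{hh} - 1, and h_phi_k follows the phi-function
// recurrence from the UniPC paper, producing the right-hand side b of the
// linear system solved below.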
let hh = h.neg();
let h_phi_1 = hh.exp_m1();
let mut h_phi_k = h_phi_1 / hh - 1.;
let mut factorial_i = 1.;
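// B(h) selects the solver variant: Bh1 takes B(h) = h, Bh2 takes
// B(h) = e^h - 1 (both evaluated at hh).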
let b_h = match self.config.solver_type {
SolverType::Bh1 => hh,
SolverType::Bh2 => hh.exp_m1(),
};
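// Build the Vandermonde-style matrix R (rows rks^{i-1}) and the coefficient
// vector b; together they determine the predictor weights rhos_p.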
for i in 1..=self.state.order() {
r.push(rks.powf(i as f64 - 1.)?);
b.push(h_phi_k * factorial_i / b_h);
factorial_i *= i as f64 + 1.; // accumulate (i+1)! for the phi recurrence
h_phi_k = h_phi_k / hh - 1. / factorial_i;
}
let (r, b) = (Tensor::stack(&r, 0)?, Tensor::new(b, device)?);
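// Solve R * rhos_p = b for the predictor weights. At order 2 this reduces to
// the closed form [0.5]; otherwise invert the leading
// (order - 1) x (order - 1) block of R and apply it to the truncated b.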
let (d1s, rhos_p) = match d1s.len() {
0 => (None, None),
_ => {
let rhos_p = match self.state.order() {
2 => Tensor::new(&[0.5f64], m0.device())?.to_dtype(m0.dtype())?,
_ => {
let ((r1, r2), b1) = (r.dims2()?, b.dims1()?);
let inverse = linalg::inverse(&r.i((..(r1 - 1), ..(r2 - 1)))?)?;
let b = b.i(..(b1 - 1))?;
b.broadcast_mul(&inverse)?.sum(1)?.to_dtype(m0.dtype())?
}
};
(Some(Tensor::stack(&d1s, 1)?), Some(rhos_p))
}
};
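// First-order part of the update in data-prediction form:
// x_t = (sigma_t / sigma_s0) * x_s - alpha_t * (e^{hh} - 1) * m0.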
let x_t_ = ((sigma_t / sigma_s0 * sample)? - (alpha_t * h_phi_1 * m0)?)?;
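// Higher-order correction: contract rhos_p with the stacked differences,
// pred_res = sum_k rhos_p[k] * D1s[:, k] (an einsum "k,bkchw->bchw"), then
// subtract alpha_t * B(h) * pred_res from the first-order estimate.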
if let (Some(d1s), Some(rhos_p)) = (d1s, rhos_p) {
use linalg::{Permutation, TensordotFixedPosition, TensordotGeneral};
let output_shape = m0.shape().clone();
let pred_res = TensordotGeneral {
lhs_permutation: Permutation { dims: vec![0] },
rhs_permutation: Permutation {
dims: vec![1, 0, 2, 3, 4],
},
tensordot_fixed_position: TensordotFixedPosition {
len_uncontracted_lhs: 1,
len_uncontracted_rhs: output_shape.dims().iter().product::<usize>(),
len_contracted_axes: d1s.dim(1)?,
output_shape,
},
output_permutation: Permutation {
dims: vec![0, 1, 2, 3],
},
}
.eval(&rhos_p, &d1s)?;
x_t_ - (alpha_t * b_h * pred_res)?
} else {
Ok(x_t_)
}
}