in threestudio/models/guidance/zero123_guidance.py [0:0]
def guidance_eval(self, cond, t_orig, latents_noisy, noise_pred):
    """Decode diagnostic images showing what the guidance does at ``t_orig``.

    For the first ``bs`` items of the batch, this snaps each timestep onto a
    50-step schedule, takes a single scheduler step from the noisy latents,
    and then keeps denoising through the remaining schedule with guidance
    to obtain a final sample.

    Args:
        cond: Mapping with ``"c_crossattn"`` and ``"c_concat"`` entries; each
            is a sequence whose first element is indexed along dim 0.
            NOTE(review): assumed to stack two equally sized halves along
            dim 0, in the same order as the ``chunk(2)`` split of the model
            output below — confirm against the code that builds ``cond``.
        t_orig: Per-item timesteps; only the first ``bs`` entries are used.
        latents_noisy: Noised latents; only the first ``bs`` items are used.
        noise_pred: Noise prediction for ``latents_noisy``, used for the
            single scheduler step.

    Returns:
        dict with keys:
            ``"bs"``: number of items evaluated.
            ``"noise_levels"``: list of ``t / num_train_timesteps`` values.
            ``"imgs_noisy"``: decoded ``latents_noisy[:bs]``.
            ``"imgs_1step"``: decoded latents after one scheduler step.
            ``"imgs_1orig"``: decoded ``pred_original_sample`` of that step.
            ``"imgs_final"``: decoded latents after the full denoising loop.
        All ``imgs_*`` values are the ``decode_latents`` output permuted
        with ``(0, 2, 3, 1)``.

    Side effects:
        Calls ``self.scheduler.set_timesteps(50)`` and stores a device copy
        of the timesteps on ``self.scheduler.timesteps_gpu``.
    """
    # use only 50 timesteps, and find nearest of those to t
    self.scheduler.set_timesteps(50)
    self.scheduler.timesteps_gpu = self.scheduler.timesteps.to(self.device)
    bs = (
        min(self.cfg.max_items_eval, latents_noisy.shape[0])
        if self.cfg.max_items_eval > 0
        else latents_noisy.shape[0]
    )  # batch size
    large_enough_idxs = self.scheduler.timesteps_gpu.expand([bs, -1]) > t_orig[
        :bs
    ].unsqueeze(
        -1
    )  # sized [bs,50] > [bs,1]
    # min over a boolean row yields the index of its first False entry, i.e.
    # the first schedule timestep that is not greater than t_orig (index 0
    # if every entry is True).
    idxs = torch.min(large_enough_idxs, dim=1)[1]
    t = self.scheduler.timesteps_gpu[idxs]

    fracs = list((t / self.scheduler.config.num_train_timesteps).cpu().numpy())
    imgs_noisy = self.decode_latents(latents_noisy[:bs]).permute(0, 2, 3, 1)

    # get prev latent
    latents_1step = []
    pred_1orig = []
    for b in range(bs):
        # Step item by item because every item has its own timestep.
        step_output = self.scheduler.step(
            noise_pred[b : b + 1], t[b], latents_noisy[b : b + 1], eta=1
        )
        latents_1step.append(step_output["prev_sample"])
        pred_1orig.append(step_output["pred_original_sample"])
    latents_1step = torch.cat(latents_1step)
    pred_1orig = torch.cat(pred_1orig)
    imgs_1step = self.decode_latents(latents_1step).permute(0, 2, 3, 1)
    imgs_1orig = self.decode_latents(pred_1orig).permute(0, 2, 3, 1)

    # The matching row in the second half of each conditioning tensor sits
    # half of that tensor's length away. Using len(idxs) (== bs) as the
    # offset is only right when no items were dropped by max_items_eval;
    # otherwise both rows would come from the first half.
    crossattn_all = cond["c_crossattn"][0]
    concat_all = cond["c_concat"][0]
    crossattn_half = crossattn_all.shape[0] // 2
    concat_half = concat_all.shape[0] // 2

    latents_final = []
    for b, i in enumerate(idxs):
        latents = latents_1step[b : b + 1]
        c = {
            "c_crossattn": [crossattn_all[[b, b + crossattn_half], ...]],
            "c_concat": [concat_all[[b, b + concat_half], ...]],
        }
        # Continue from the schedule entry right after the one already taken.
        for t_step in tqdm(self.scheduler.timesteps[i + 1 :], leave=False):
            # pred noise
            x_in = torch.cat([latents] * 2)
            t_in = torch.cat([t_step.reshape(1)] * 2).to(self.device)
            step_noise_pred = self.model.apply_model(x_in, t_in, c)
            # perform guidance
            noise_pred_uncond, noise_pred_cond = step_noise_pred.chunk(2)
            step_noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )
            # get prev latent
            latents = self.scheduler.step(
                step_noise_pred, t_step, latents, eta=1
            )["prev_sample"]
        latents_final.append(latents)

    latents_final = torch.cat(latents_final)
    imgs_final = self.decode_latents(latents_final).permute(0, 2, 3, 1)

    return {
        "bs": bs,
        "noise_levels": fracs,
        "imgs_noisy": imgs_noisy,
        "imgs_1step": imgs_1step,
        "imgs_1orig": imgs_1orig,
        "imgs_final": imgs_final,
    }