in cm/train_util.py [0:0]
def forward_backward(self, batch, cond):
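    """Compute losses on a batch and accumulate gradients, one microbatch at a time."""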
    self.mp_trainer.zero_grad()
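    # Walk over the batch in chunks of `self.microbatch` so large batches fit in
    # memory; gradients from all microbatches accumulate before the optimizer step.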
    for i in range(0, batch.shape[0], self.microbatch):
        micro = batch[i : i + self.microbatch].to(dist_util.dev())
        micro_cond = {
            k: v[i : i + self.microbatch].to(dist_util.dev())
            for k, v in cond.items()
        }
        last_batch = (i + self.microbatch) >= batch.shape[0]
        t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
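        # Query the step-dependent schedule: the target-EMA rate and the number of
        # discretization scales for the current global step (only num_scales is used here).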
        ema, num_scales = self.ema_scale_fn(self.global_step)
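        # Select and partially apply the loss function for the configured training mode.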
        if self.training_mode == "progdist":
            if num_scales == self.ema_scale_fn(0)[1]:
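                # While num_scales is still at its initial value, distill from the pretrained teacher.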
                compute_losses = functools.partial(
                    self.diffusion.progdist_losses,
                    self.ddp_model,
                    micro,
                    num_scales,
                    target_model=self.teacher_model,
                    target_diffusion=self.teacher_diffusion,
                    model_kwargs=micro_cond,
                )
            else:
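                # Otherwise, use the target model and the student's own diffusion as the distillation target.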
                compute_losses = functools.partial(
                    self.diffusion.progdist_losses,
                    self.ddp_model,
                    micro,
                    num_scales,
                    target_model=self.target_model,
                    target_diffusion=self.diffusion,
                    model_kwargs=micro_cond,
                )
        elif self.training_mode == "consistency_distillation":
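            # Consistency distillation uses both a teacher model and a target model.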
            compute_losses = functools.partial(
                self.diffusion.consistency_losses,
                self.ddp_model,
                micro,
                num_scales,
                target_model=self.target_model,
                teacher_model=self.teacher_model,
                teacher_diffusion=self.teacher_diffusion,
                model_kwargs=micro_cond,
            )
        elif self.training_mode == "consistency_training":
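            # Consistency training needs no teacher; only the target model is used.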
            compute_losses = functools.partial(
                self.diffusion.consistency_losses,
                self.ddp_model,
                micro,
                num_scales,
                target_model=self.target_model,
                model_kwargs=micro_cond,
            )
        else:
            raise ValueError(f"Unknown training mode {self.training_mode}")
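        # Skip DDP gradient synchronization on all but the last microbatch so the
        # all-reduce happens only once per full batch.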
        if last_batch or not self.use_ddp:
            losses = compute_losses()
        else:
            with self.ddp_model.no_sync():
                losses = compute_losses()
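        # Let a loss-aware schedule sampler refresh its per-timestep loss estimates.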
        if isinstance(self.schedule_sampler, LossAwareSampler):
            self.schedule_sampler.update_with_local_losses(
                t, losses["loss"].detach()
            )
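        # Apply the sampler's importance weights, log the weighted losses, and
        # backpropagate through the mixed-precision trainer.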
        loss = (losses["loss"] * weights).mean()
        log_loss_dict(
            self.diffusion, t, {k: v * weights for k, v in losses.items()}
        )
        self.mp_trainer.backward(loss)