in improved_diffusion/train_util.py [0:0]
def forward_backward(self, batch, cond):
    """
    Run the forward and backward passes for one batch, microbatch by microbatch.

    `zero_grad` is called once on `self.model_params` up front; the batch is
    then walked in slices of `self.microbatch` along its first dimension and
    `backward()` is called once per slice.

    :param batch: object indexable by slice along its first dimension and
                  exposing `.shape[0]` and `.to(device)`.
    :param cond: a dict whose values are sliced with the same indices as
                 `batch` and passed to the loss function as `model_kwargs`.
    :return: None. Results are left as side effects (backward passes, loss
             logging, and sampler updates).
    """
    zero_grad(self.model_params)
    total = batch.shape[0]
    for start in range(0, total, self.microbatch):
        stop = start + self.microbatch

        # Move the current slice of the data and its conditioning to the
        # device reported by dist_util.
        chunk = batch[start:stop].to(dist_util.dev())
        chunk_cond = {}
        for key, value in cond.items():
            chunk_cond[key] = value[start:stop].to(dist_util.dev())

        is_last = stop >= total
        t, weights = self.schedule_sampler.sample(chunk.shape[0], dist_util.dev())

        if not is_last and self.use_ddp:
            # Non-final microbatch under DDP: compute the losses inside
            # no_sync().
            # NOTE(review): presumably ddp_model is a torch
            # DistributedDataParallel and this defers gradient
            # synchronization to the final microbatch -- confirm.
            with self.ddp_model.no_sync():
                terms = self.diffusion.training_losses(
                    self.ddp_model, chunk, t, model_kwargs=chunk_cond
                )
        else:
            terms = self.diffusion.training_losses(
                self.ddp_model, chunk, t, model_kwargs=chunk_cond
            )

        # Loss-aware samplers are fed the detached per-timestep losses.
        if isinstance(self.schedule_sampler, LossAwareSampler):
            local_losses = terms["loss"].detach()
            self.schedule_sampler.update_with_local_losses(t, local_losses)

        # NOTE(review): the mean is taken per microbatch and backward() runs
        # once per microbatch, so accumulated gradients presumably sum these
        # per-microbatch means -- confirm that scaling is intended.
        loss = (terms["loss"] * weights).mean()

        weighted_terms = {name: value * weights for name, value in terms.items()}
        log_loss_dict(self.diffusion, t, weighted_terms)

        if not self.use_fp16:
            loss.backward()
        else:
            # fp16 path: scale the loss by 2 ** lg_loss_scale before backward.
            scaled_loss = loss * (2 ** self.lg_loss_scale)
            scaled_loss.backward()