in train.py [0:0]
def training_step(H, data_input, target, vae, ema_vae, optimizer, iterate):
    """Run one forward/backward pass and, if it looks healthy, an optimizer step.

    Args:
        H: hyperparameter object; this function reads ``H.grad_clip`` (max
            gradient norm passed to ``clip_grad_norm_``), ``H.skip_threshold``
            (gradient-norm ceiling above which the step is skipped; ``-1``
            disables that check) and ``H.ema_rate`` (forwarded to
            ``update_ema``).
        data_input: model input, passed straight to ``vae.forward``.
        target: reconstruction target, passed straight to ``vae.forward``.
        vae: the model being trained. Its ``forward`` must return a dict
            holding at least ``'elbo'`` (the value back-propagated), plus
            ``'distortion'`` and ``'rate'`` (tensors checked for NaNs).
        ema_vae: the averaged copy of the model, updated via ``update_ema``
            only when an optimizer step is actually taken.
        optimizer: optimizer whose ``step()`` applies the gradients.
        iterate: accepted for interface compatibility; not used here.

    Returns:
        The stats dict after ``get_cpu_stats_over_ranks``, extended with
        ``skipped_updates`` (1 if the step was skipped, else 0), ``iter_time``
        (wall-clock seconds for this call) and ``grad_norm`` (the total
        gradient norm reported by ``clip_grad_norm_``, as a Python float).
    """
    start = time.time()

    vae.zero_grad()
    step_stats = vae.forward(data_input, target)
    step_stats['elbo'].backward()
    # Clips the gradients in place; the returned value is the total norm of
    # the gradients (before clipping), converted to a Python float.
    norm = torch.nn.utils.clip_grad_norm_(vae.parameters(), H.grad_clip).item()

    n_distortion_nan = torch.isnan(step_stats['distortion']).sum()
    n_rate_nan = torch.isnan(step_stats['rate']).sum()
    # Collapse the NaN counts into 0/1 flags so they can be combined with the
    # other stats (presumably reduced across ranks by the helper below).
    step_stats.update(dict(
        rate_nans=1 if n_rate_nan else 0,
        distortion_nans=1 if n_distortion_nan else 0,
    ))
    step_stats = get_cpu_stats_over_ranks(step_stats)

    # Step only when neither NaN flag is set and the gradient norm is
    # acceptable; a skip_threshold of -1 turns the norm check off.
    can_step = (
        step_stats['distortion_nans'] == 0
        and step_stats['rate_nans'] == 0
        and (H.skip_threshold == -1 or norm < H.skip_threshold)
    )
    if can_step:
        optimizer.step()
        update_ema(vae, ema_vae, H.ema_rate)
    skipped_updates = 0 if can_step else 1

    elapsed = time.time() - start
    step_stats.update(skipped_updates=skipped_updates, iter_time=elapsed, grad_norm=norm)
    return step_stats