in cm/train_util.py [0:0]
def run_step(self, batch, cond):
    """Execute one training iteration.

    Runs the forward/backward pass, lets the mixed-precision trainer apply
    the optimizer, and — only when an optimizer step was actually taken —
    performs the post-step bookkeeping (EMA update, target-model EMA,
    progressive-distillation reset, step counters). LR annealing and step
    logging happen every call regardless.

    :param batch: training batch, passed through to ``forward_backward``.
    :param cond: conditioning inputs for the batch, passed through to
        ``forward_backward``.
    """
    self.forward_backward(batch, cond)
    # NOTE(review): optimize() presumably returns False when the update was
    # skipped (e.g. a bad gradient scale) — confirm in mp_trainer.
    stepped = self.mp_trainer.optimize(self.opt)
    if not stepped:
        # No parameter update: skip EMA/counter bookkeeping entirely.
        self._anneal_lr()
        self.log_step()
        return
    self._update_ema()
    if self.target_model:
        self._update_target_ema()
    if self.training_mode == "progdist":
        self.reset_training_for_progdist()
    self.step = self.step + 1
    self.global_step = self.global_step + 1
    self._anneal_lr()
    self.log_step()