in cm/train_util.py
def reset_training_for_progdist(self):
    assert self.training_mode == "progdist", "Training mode must be progdist"

    if self.global_step > 0:
        scales = self.ema_scale_fn(self.global_step)[1]
        scales2 = self.ema_scale_fn(self.global_step - 1)[1]
        # A drop in the scale count between consecutive global steps marks the
        # start of a new distillation stage.
        if scales != scales2:
            with th.no_grad():
                # Copy the student weights into the teacher: an EMA rate of 0.0
                # overwrites the teacher parameters rather than blending them.
                update_ema(
                    self.teacher_model.parameters(),
                    self.model.parameters(),
                    0.0,
                )
            # Reset optimizer state for the new stage.
            self.opt = RAdam(
                self.mp_trainer.master_params,
                lr=self.lr,
                weight_decay=self.weight_decay,
            )
            # Restart the EMA copies from the current master params.
            self.ema_params = [
                copy.deepcopy(self.mp_trainer.master_params)
                for _ in range(len(self.ema_rate))
            ]
            # When only two scales remain, double the LR anneal budget.
            if scales == 2:
                self.lr_anneal_steps *= 2

            self.teacher_model.eval()
            # Per-stage step counter restarts; global_step keeps counting.
            self.step = 0
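
The reset above only relies on one property of `self.ema_scale_fn`: the second element of the tuple it returns (the scale count) changes at stage boundaries. The sketch below is a minimal, hypothetical schedule illustrating that contract; the factory name `make_progdist_scale_fn`, its parameters, and the exact halving rule are assumptions for illustration, not the repository's implementation.

def make_progdist_scale_fn(start_scales, distill_steps_per_iter):
    """Hypothetical sketch: halve the scale count every distillation iteration."""

    def ema_scale_fn(step):
        stage = step // distill_steps_per_iter
        # Clamp at two scales for illustration, matching the `scales == 2` branch above.
        scales = max(start_scales // (2 ** stage), 2)
        target_ema = 0.0  # progdist overwrites the teacher, so this rate is unused by the reset
        return target_ema, scales

    return ema_scale_fn

For example, with start_scales=16 and distill_steps_per_iter=50000, the scale count drops 16 → 8 → 4 → 2; each drop makes `ema_scale_fn(global_step)[1]` differ from `ema_scale_fn(global_step - 1)[1]`, which is exactly the condition that triggers the teacher copy, optimizer reset, and EMA restart in `reset_training_for_progdist`.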