in cm/train_util.py [0:0]
def save(self):
    """Write the EMA, optimizer and model checkpoints for the current step.

    File names embed ``self.step + self.resume_step`` zero-padded to six
    digits: ``ema_{rate}_{step}.pt`` for each EMA parameter set,
    ``opt{step}.pt`` for the optimizer state and ``model{step}.pt`` for the
    master parameters. Only the process whose ``dist.get_rank()`` is 0
    writes files; every process finishes by calling ``dist.barrier()``.
    """
    step = self.step + self.resume_step

    def save_checkpoint(rate, params):
        # The conversion runs on every rank; only rank 0 goes on to write.
        weights = self.mp_trainer.master_params_to_state_dict(params)
        if dist.get_rank() != 0:
            return
        logger.log(f"saving model {rate}...")
        # A falsy rate (0) denotes the plain model rather than an EMA copy.
        filename = f"ema_{rate}_{step:06d}.pt" if rate else f"model{step:06d}.pt"
        with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as out:
            th.save(weights, out)

    for ema_rate, ema_params in zip(self.ema_rate, self.ema_params):
        save_checkpoint(ema_rate, ema_params)

    if dist.get_rank() == 0:
        opt_path = bf.join(get_blob_logdir(), f"opt{step:06d}.pt")
        with bf.BlobFile(opt_path, "wb") as out:
            th.save(self.opt.state_dict(), out)

    # Save model parameters last to prevent race conditions where a restart
    # loads model at step N, but opt/ema state isn't saved for step N.
    save_checkpoint(0, self.mp_trainer.master_params)
    dist.barrier()