in training/trainer.py
def save_checkpoint(self, epoch, checkpoint_names=None):
    checkpoint_folder = self.checkpoint_conf.save_dir
    makedir(checkpoint_folder)
    if checkpoint_names is None:
        checkpoint_names = ["checkpoint"]
        if (
            self.checkpoint_conf.save_freq > 0
            and (int(epoch) % self.checkpoint_conf.save_freq == 0)
        ) or int(epoch) in self.checkpoint_conf.save_list:
            checkpoint_names.append(f"checkpoint_{int(epoch)}")

    checkpoint_paths = []
    for ckpt_name in checkpoint_names:
        checkpoint_paths.append(os.path.join(checkpoint_folder, f"{ckpt_name}.pt"))
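
    # Unwrap the DDP container (if any) and drop parameters matching the
    # configured skip patterns before saving the model weights.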
    state_dict = unwrap_ddp_if_wrapped(self.model).state_dict()
    state_dict = exclude_params_matching_unix_pattern(
        patterns=self.checkpoint_conf.skip_saving_parameters, state_dict=state_dict
    )

    checkpoint = {
        "model": state_dict,
        "optimizer": self.optim.optimizer.state_dict(),
        "epoch": epoch,
        "loss": self.loss.state_dict(),
        "steps": self.steps,
        "time_elapsed": self.time_elapsed_meter.val,
        "best_meter_values": self.best_meter_values,
    }
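    # Persist the AMP grad scaler state so mixed-precision training resumes
    # with the same loss scale.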
    if self.optim_conf.amp.enabled:
        checkpoint["scaler"] = self.scaler.state_dict()

    # DDP checkpoints are only saved on rank 0 (all workers are identical)
    if self.distributed_rank != 0:
        return

    for checkpoint_path in checkpoint_paths:
        self._save_checkpoint(checkpoint, checkpoint_path)
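
The two helpers called when building the model state, unwrap_ddp_if_wrapped and exclude_params_matching_unix_pattern, are defined elsewhere in the training utilities. Below is a minimal sketch of what they are assumed to do; the names and signatures follow the call sites above, but the bodies are illustrative, not the repository's actual implementation.

import fnmatch
from typing import Dict, List

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def unwrap_ddp_if_wrapped(model: nn.Module) -> nn.Module:
    # Return the underlying module when the model is wrapped in DDP so that
    # checkpoint keys are not prefixed with "module.".
    if isinstance(model, DistributedDataParallel):
        return model.module
    return model


def exclude_params_matching_unix_pattern(
    patterns: List[str], state_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    # Drop state_dict entries whose names match any Unix-style glob pattern,
    # e.g. "image_encoder.*" (pattern shown here only as an illustration).
    if not patterns:
        return state_dict
    return {
        name: tensor
        for name, tensor in state_dict.items()
        if not any(fnmatch.fnmatch(name, pattern) for pattern in patterns)
    }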