in pytorch/sagemakercv/utils/checkpoint.py [0:0]
def save(self, name, **kwargs):
    if not self.save_dir:
        return
    if not self.save_to_disk:
        return
    nhwc = kwargs.get("nhwc", False)
    data = {}
    data["model"] = self.model.state_dict()
    if self.optimizer is not None:
        data["optimizer"] = self.optimizer.state_dict()
    if self.scheduler is not None:
        data["scheduler"] = self.scheduler.state_dict()
    data.update(kwargs)
    # transpose to NCHW before saving as checkpoint if NHWC is used
    if nhwc:
        transpose_checkpoint_model_state_nhwc_to_nchw(data["model"])
        if self.optimizer is not None:
            transpose_optimizer_state_nhwc_to_nchw(self.model, data["optimizer"])
    save_file = os.path.join(self.save_dir, "{}.pth".format(name))
    self.logger.info("Saving checkpoint to {}".format(save_file))
    torch.save(data, save_file)
    self.tag_last_checkpoint(save_file)
    # Convert back to NHWC if NHWC layout is used, needed for optimizer buffers
    if nhwc:
        if self.optimizer is not None:
            transpose_optimizer_state_nchw_to_nhwc(self.model, self.optimizer.state_dict())
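
A minimal usage sketch, assuming the enclosing class (called Checkpointer here) exposes the model, optimizer, scheduler, save_dir, and save_to_disk attributes referenced above; the constructor signature shown is an assumption, not taken from this file:

    # Hypothetical setup; only the attribute names used by save() are known from the file.
    checkpointer = Checkpointer(model, optimizer, scheduler,
                                save_dir="/opt/ml/checkpoints", save_to_disk=True)
    # Extra kwargs (e.g. iteration) are stored in the checkpoint via data.update(kwargs);
    # nhwc=True triggers the NHWC -> NCHW transpose before torch.save.
    checkpointer.save("model_0005000", nhwc=True, iteration=5000)

For intuition only, a sketch of what an NHWC -> NCHW transpose over a model state_dict can look like; this is not the transpose_checkpoint_model_state_nhwc_to_nchw helper itself, and it assumes 4-D conv weights are stored as [K, H, W, C]:

    import torch

    def nhwc_to_nchw_state(state_dict):
        # Permute only 4-D tensors (conv weights); 1-D biases and BN params stay as-is.
        for key, tensor in state_dict.items():
            if tensor.dim() == 4:
                state_dict[key] = tensor.permute(0, 3, 1, 2).contiguous()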