def save()

in pytorch/sagemakercv/utils/checkpoint.py [0:0]


    def save(self, name, **kwargs):
        """Write a checkpoint named ``<name>.pth`` into ``self.save_dir``.

        The checkpoint is a dict holding the model ``state_dict`` under
        ``"model"``, the optimizer and scheduler ``state_dict`` under
        ``"optimizer"`` / ``"scheduler"`` when those objects are not ``None``,
        and every entry of ``kwargs`` (applied last, so a keyword with the
        same name overrides the entries above; the ``nhwc`` flag itself is
        stored too when it is passed).

        Nothing is written, logged or tagged when ``self.save_dir`` or
        ``self.save_to_disk`` is falsy.

        Args:
            name: Base file name of the checkpoint, without the ``.pth``
                suffix.
            **kwargs: Extra entries to store in the checkpoint. The optional
                ``nhwc`` flag (default ``False``) requests that the model and
                optimizer state be transposed from NHWC to NCHW before the
                file is written.

        Raises:
            Whatever ``torch.save`` or ``self.tag_last_checkpoint`` raise.
            When ``nhwc`` is set, the optimizer state is converted back to
            NHWC before such an error propagates.
        """
        if not self.save_dir:
            return

        if not self.save_to_disk:
            return

        nhwc = kwargs.get("nhwc", False)
        # Resolve the target path first so nothing that can fail for trivial
        # reasons sits between the layout conversion and its undo below.
        save_file = os.path.join(self.save_dir, "{}.pth".format(name))

        data = {}
        data["model"] = self.model.state_dict()
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()
        data.update(kwargs)

        # Checkpoints are always stored in NCHW; transpose first if NHWC is used.
        if nhwc:
            transpose_checkpoint_model_state_nhwc_to_nchw(data["model"])
            if self.optimizer is not None:
                transpose_optimizer_state_nhwc_to_nchw(self.model, data["optimizer"])

        try:
            self.logger.info("Saving checkpoint to {}".format(save_file))
            torch.save(data, save_file)
            self.tag_last_checkpoint(save_file)
        finally:
            # Convert the optimizer buffers back to NHWC even when saving or
            # tagging failed; otherwise a caller that handles the error would
            # keep training with optimizer state in the wrong layout.
            # NOTE(review): presumably the transpose above rewrites the live
            # optimizer buffers through the state dict -- confirm in helper.
            if nhwc and self.optimizer is not None:
                transpose_optimizer_state_nchw_to_nhwc(self.model, self.optimizer.state_dict())