def _save_snapshot()

in pai-python-sdk/training/pytorch_ddp/train_src/train_multinode.py [0:0]


    def _save_snapshot(self, epoch):
        """Write a training snapshot to the path from ``get_snapshot_path()``.

        The snapshot is a dict with two entries: ``"MODEL_STATE"``, the
        ``state_dict()`` of ``self.model.module``, and ``"EPOCHS_RUN"``, the
        ``epoch`` value passed in. The parent directory of the destination is
        created if it does not exist yet, and a one-line confirmation naming
        the written file is printed.

        Args:
            epoch: Value stored under ``"EPOCHS_RUN"`` and echoed in the log
                line (presumably the epoch that just finished -- confirm
                against the caller).
        """
        # Local import: the file's top-level import block is not visible here.
        import os

        # Resolve the destination once so the file that is written and the
        # path reported in the log line cannot disagree.
        snapshot_path = self.get_snapshot_path()

        # torch.save does not create missing parent directories. An empty
        # dirname means a bare filename in the current working directory.
        parent_dir = os.path.dirname(snapshot_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        snapshot = {
            # NOTE(review): `.module` suggests self.model is a wrapper such as
            # DistributedDataParallel -- confirm where self.model is built.
            "MODEL_STATE": self.model.module.state_dict(),
            "EPOCHS_RUN": epoch,
        }
        torch.save(snapshot, snapshot_path)
        print(f"Epoch {epoch} | Training snapshot saved at {snapshot_path}")