in pai-python-sdk/training/pytorch_ddp/train_src/train_multinode.py [0:0]
def _save_snapshot(self, epoch):
snapshot = {
"MODEL_STATE": self.model.module.state_dict(),
"EPOCHS_RUN": epoch,
}
torch.save(snapshot, self.get_snapshot_path())
print(f"Epoch {epoch} | Training snapshot saved at {self.output_model_path}")