in pai-python-sdk/training/pytorch_ddp/train_src/train_multinode.py [0:0]
def train(self, max_epochs: int):
    """Run the training loop from the resume point up to ``max_epochs``.

    Epoch indices run from ``self.epochs_run`` (inclusive) to ``max_epochs``
    (exclusive); each one is handed to ``self._run_epoch``. After an epoch
    finishes, the process whose ``global_rank`` is 0 calls
    ``self._save_snapshot`` when the epoch index is a multiple of
    ``self.save_every`` (this includes epoch 0). When the loop is done,
    ``self._save_model()`` is called.

    Args:
        max_epochs: Exclusive upper bound on the epoch index.
    """
    first_epoch = self.epochs_run
    for current in range(first_epoch, max_epochs):
        self._run_epoch(current)
        # Only the rank-0 process writes snapshots. The rank is checked before
        # the modulo, so other ranks never read ``save_every``.
        if self.global_rank != 0:
            continue
        if current % self.save_every == 0:
            self._save_snapshot(current)
    # Called on every rank that reaches this point (no rank guard here).
    # NOTE(review): confirm _save_model itself restricts writing to one rank.
    self._save_model()