def train()

in pai-python-sdk/training/pytorch_ddp/train_src/train_multinode.py [0:0]


    def train(self, max_epochs: int):
        """Run the epoch loop, periodically snapshotting, then save the final model.

        Epochs are processed from ``self.epochs_run`` (inclusive) up to
        ``max_epochs`` (exclusive), so a trainer that resumed from a snapshot
        continues where it left off.

        Args:
            max_epochs: Exclusive upper bound on the epoch index to train.

        Notes:
            Only the process with ``self.global_rank == 0`` writes snapshots,
            and only on epochs that are a multiple of ``self.save_every``.
            With an integer ``save_every`` of 0, that modulo raises
            ``ZeroDivisionError`` on rank 0.
            ``self._save_model()`` is invoked on every rank once the loop
            ends. NOTE(review): confirm ``_save_model`` itself restricts
            writing to a single rank if that is required.
        """
        for current_epoch in range(self.epochs_run, max_epochs):
            self._run_epoch(current_epoch)
            # Snapshots are written by the global-rank-0 process only.
            if self.global_rank != 0:
                continue
            if current_epoch % self.save_every != 0:
                continue
            self._save_snapshot(current_epoch)
        # Persist the final model once the epoch loop has finished.
        self._save_model()