# pai-python-sdk/training/pytorch_ddp/train_src/train_multinode.py
def _run_epoch(self, epoch):
    # Peek at the first batch to report the per-GPU batch size.
    b_sz = len(next(iter(self.train_data))[0])
    print(
        f"[GPU-{self.global_rank}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}"
    )
    # Reseed the sampler so each epoch shuffles differently and all
    # ranks agree on the new ordering.
    self.train_data.sampler.set_epoch(epoch)
    for source, targets in self.train_data:
        # Move the batch onto this process's local GPU before the step.
        source = source.to(self.local_rank)
        targets = targets.to(self.local_rank)
        self._run_batch(source, targets)
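

# The set_epoch() call above only has an effect when the DataLoader was
# built with torch.utils.data.distributed.DistributedSampler. Below is a
# minimal sketch of that setup; the helper name prepare_dataloader and
# its arguments are illustrative assumptions, not taken from this file.

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler


def prepare_dataloader(dataset: Dataset, batch_size: int) -> DataLoader:
    # DistributedSampler shards the dataset across ranks so each GPU
    # iterates over a disjoint subset; set_epoch(epoch) reseeds the
    # shuffle so the shards change between epochs.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,  # shuffling is delegated to the sampler
        sampler=DistributedSampler(dataset),
    )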