in src/sagemaker_training/pytorch_xla.py [0:0]
def _setup(self):  # type: () -> None
    logger.info("Starting distributed training through PT-XLA Runtime.")
    self._check_compatibility()
    # Set NCCL logging to info to debug customer issues
    os.environ["NCCL_DEBUG"] = "info"
    # Use `simple` protocol to handle the out-of-order data delivery from EFA
    os.environ["NCCL_PROTO"] = "simple"
    # Use GPU RDMA when available (available only in p4d.24xlarge)
    os.environ["FI_EFA_USE_DEVICE_RDMA"] = "1"
    # Use multiple connections per GPU to better saturate the EFA bandwidth
    os.environ["OFI_NCCL_NIC_DUP_CONNS"] = str(self._num_gpus)
    # Set cluster configuration for XLA runtime
    os.environ["XRT_HOST_ORDINAL"] = str(self._rank)
    os.environ["XRT_SHARD_WORLD_SIZE"] = str(self._num_hosts)
    address = "localservice:{};{}:" + str(self.WORKER_PORT)
    os.environ["XRT_WORKERS"] = "|".join(
        [address.format(i, host) for i, host in enumerate(self._hosts)]
    )
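    # Number of GPU devices on this host made available to the XLA runtime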
os.environ["GPU_NUM_DEVICES"] = str(self._num_gpus)
    if self._num_hosts > 1:
        os.environ[
            "XRT_MESH_SERVICE_ADDRESS"
        ] = f"{self._master_hostname}:{self.MESH_SERVICE_PORT}"
logger.info("Completed environment setup for distributed training through PT-XLA Runtime.")