def _setup()

in src/sagemaker_training/pytorch_xla.py


    def _setup(self):  # type: () -> None
        logger.info("Starting distributed training through PT-XLA Runtime.")
        self._check_compatibility()

        # Set NCCL logging to info to debug customer issues
        os.environ["NCCL_DEBUG"] = "info"

        # Use `simple` protocol to handle the out-of-order data delivery from EFA
        os.environ["NCCL_PROTO"] = "simple"

        # Use GPU RDMA when available (currently supported only on p4d.24xlarge instances)
        os.environ["FI_EFA_USE_DEVICE_RDMA"] = "1"

        # Use multiple connections per GPU to better saturate the EFA bandwidth
        os.environ["OFI_NCCL_NIC_DUP_CONNS"] = str(self._num_gpus)

        # Set cluster configuration for XLA runtime
        os.environ["XRT_HOST_ORDINAL"] = str(self._rank)
        os.environ["XRT_SHARD_WORLD_SIZE"] = str(self._num_hosts)
        os.environ["XRT_WORKERS"] = "|".join(
            f"localservice:{i};{host}:{self.WORKER_PORT}" for i, host in enumerate(self._hosts)
        )
        os.environ["GPU_NUM_DEVICES"] = str(self._num_gpus)
        if self._num_hosts > 1:
            os.environ[
                "XRT_MESH_SERVICE_ADDRESS"
            ] = f"{self._master_hostname}:{self.MESH_SERVICE_PORT}"

        logger.info("Completed environment setup for distributed training through PT-XLA Runtime.")