def _distribution_configuration()

in src/sagemaker/estimator.py [0:0]


    def _distribution_configuration(self, distribution):
        """Returns a dict of distribution configurations.

        Each recognised key of ``distribution`` ("instance_groups", "pytorchxla",
        "parameter_server", "mpi", "smdistributed", "multi_worker_mirrored_strategy")
        contributes one or more entries to the returned dict. Most entries are keyed by
        class-level name constants such as ``self.LAUNCH_MPI_ENV_NAME``. Keys absent from
        ``distribution`` contribute nothing.

        Args:
            distribution (dict): A dictionary with information on how to run distributed training.

        Returns:
            dict: The distribution configuration entries derived from ``distribution``.

        Raises:
            ValueError: If "smdistributed" is requested with an image listed in
                ``Framework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM`` or on a p5.48xlarge
                instance while "torch_distributed" is not enabled; or if "instance_groups"
                is set while neither MPI nor SMDataParallel is enabled. Errors raised by
                ``self._validate_mwms_config`` also propagate.
        """
        distribution_config = {}

        mpi_enabled = False
        smdataparallel_enabled = False
        p5_enabled = False
        if "instance_groups" in distribution:
            distribution_config["sagemaker_distribution_instance_groups"] = distribution[
                "instance_groups"
            ]

        if "pytorchxla" in distribution:
            pt_xla_enabled = distribution["pytorchxla"].get("enabled", False)
            distribution_config[self.LAUNCH_PT_XLA_ENV_NAME] = pt_xla_enabled

        if "parameter_server" in distribution:
            ps_enabled = distribution["parameter_server"].get("enabled", False)
            distribution_config[self.LAUNCH_PS_ENV_NAME] = ps_enabled

        if "mpi" in distribution:
            mpi_dict = distribution["mpi"]
            mpi_enabled = mpi_dict.get("enabled", False)
            distribution_config[self.LAUNCH_MPI_ENV_NAME] = mpi_enabled

            processes_per_host = mpi_dict.get("processes_per_host")
            if processes_per_host:
                distribution_config[self.MPI_NUM_PROCESSES_PER_HOST] = processes_per_host

            distribution_config[self.MPI_CUSTOM_MPI_OPTIONS] = mpi_dict.get(
                "custom_mpi_options", ""
            )

        if "smdistributed" in distribution:
            # smdistributed strategy selected. Compute the model-parallel parameters once
            # rather than calling the helper twice.
            mp_parameters = get_mp_parameters(distribution)
            if mp_parameters:
                distribution_config["mp_parameters"] = mp_parameters

            # torch_distributed must be enabled to use p5 instances or DLC images that are
            # incompatible with SM parallelism (both checked below).
            torch_distributed_enabled = False
            if "torch_distributed" in distribution:
                torch_distributed_enabled = distribution["torch_distributed"].get(
                    "enabled", False
                )
            smdistributed = distribution["smdistributed"]
            smdataparallel_enabled = smdistributed.get("dataparallel", {}).get("enabled", False)

            if isinstance(self.instance_type, ParameterString):
                # A pipeline parameter may have no default value; treat that as "not p5"
                # instead of failing with TypeError on ``in None``.
                p5_enabled = "p5.48xlarge" in (self.instance_type.default_value or "")
            elif isinstance(self.instance_type, str):
                p5_enabled = "p5.48xlarge" in self.instance_type
            else:
                # No single instance type: inspect each configured instance group.
                # ``instance_groups`` may be None here, so iterate an empty list instead.
                p5_enabled = any(
                    "p5.48xlarge" in instance._to_request_dict().get("InstanceType", ())
                    for instance in (self.instance_groups or [])
                )

            img_uri = "" if self.image_uri is None else self.image_uri
            for unsupported_image in Framework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM:
                if (
                    unsupported_image in img_uri and not torch_distributed_enabled
                ):  # disabling DLC images without SMDataParallel or SMModelParallel
                    raise ValueError(
                        f"SMDistributed is currently incompatible with DLC image: {img_uri}. "
                        "(Could be due to CUDA version being greater than 11.)"
                    )
            if (
                not torch_distributed_enabled and p5_enabled
            ):  # disabling p5 when torch distributed is disabled
                raise ValueError(
                    "SMModelParallel and SMDataParallel currently do not support p5 instances."
                )
            # smdistributed strategy selected with supported instance type
            distribution_config[self.LAUNCH_SM_DDP_ENV_NAME] = smdataparallel_enabled
            distribution_config[self.INSTANCE_TYPE] = self.instance_type
            if smdataparallel_enabled:
                distribution_config[self.SM_DDP_CUSTOM_MPI_OPTIONS] = smdistributed[
                    "dataparallel"
                ].get("custom_mpi_options", "")

        if "multi_worker_mirrored_strategy" in distribution:
            mwms_enabled = distribution["multi_worker_mirrored_strategy"].get("enabled", False)
            if mwms_enabled:
                self._validate_mwms_config(distribution)
            distribution_config[self.LAUNCH_MWMS_ENV_NAME] = mwms_enabled

        # Instance groups are only meaningful when a strategy that uses them is enabled.
        if not (mpi_enabled or smdataparallel_enabled) and distribution_config.get(
            "sagemaker_distribution_instance_groups"
        ) not in [None, []]:
            raise ValueError(
                "Don't set training instance groups while no distribution strategies enabled!"
            )

        return distribution_config