def validate()

in src/sagemaker/huggingface/training_compiler/config.py [0:0]


    def validate(cls, estimator):
        """Checks if SageMaker Training Compiler is configured correctly.

        Validation runs in this order and raises on the first incompatibility:

        1. The parent class's ``validate`` is run first.
        2. If ``pytorch_version`` is set, it must lie within 1.9-1.11.
        3. ``image_uri`` must not be set (overriding the image is unsupported).
        4. If ``distribution`` is set:

           - when ``distribution["pytorchxla"]["enabled"]`` is truthy, the
             estimator must not set ``tensorflow_version`` and, if
             ``pytorch_version`` is set, it must be >= 1.11; a warning is logged
             when ``instance_type`` is not in ``SUPPORTED_INSTANCE_TYPES_WITH_EFA``.
           - when the ``pytorchxla`` key is absent, PyTorch >= 1.11 is rejected
             because ``pytorchxla`` is the only supported mechanism there.
           - when the key is present but not enabled, no further check is made.

        5. If ``distribution`` is not set but ``instance_count`` > 1 and
           PyTorch >= 1.11, a warning suggests using ``pytorchxla``.

        Args:
            estimator (:class:`sagemaker.huggingface.HuggingFace`): An estimator object.
                If SageMaker Training Compiler is enabled, it will validate whether
                the estimator is configured to be compatible with Training Compiler.

        Raises:
            ValueError: Raised if the requested configuration is not compatible
                        with SageMaker Training Compiler.
        """

        super(TrainingCompilerConfig, cls).validate(estimator)

        # Parse the PyTorch version once; every version gate below compares
        # against this same value. None means no PyTorch version was given.
        pt_version = Version(estimator.pytorch_version) if estimator.pytorch_version else None

        if pt_version is not None:
            if pt_version in SpecifierSet("< 1.9") or pt_version in SpecifierSet("> 1.11"):
                error_helper_string = (
                    "SageMaker Training Compiler is only supported "
                    "with HuggingFace PyTorch 1.9-1.11. "
                    "Received pytorch_version={} which is unsupported."
                )
                raise ValueError(error_helper_string.format(estimator.pytorch_version))

        if estimator.image_uri:
            # Note the trailing space after "Compiler." so the two sentences
            # do not run together once the literals are concatenated.
            error_helper_string = (
                "Overriding the image URI is currently not supported "
                "for SageMaker Training Compiler. "
                "Specify the following parameters to run the Hugging Face training job "
                "with SageMaker Training Compiler enabled: "
                "transformers_version, tensorflow_version or pytorch_version, and compiler_config."
            )
            raise ValueError(error_helper_string)

        if estimator.distribution:
            # "present" and "enabled" are tracked separately: a pytorchxla entry
            # that exists but is not enabled skips both branches below.
            pt_xla_present = "pytorchxla" in estimator.distribution
            pt_xla_enabled = estimator.distribution.get("pytorchxla", {}).get("enabled", False)
            if pt_xla_enabled:
                if estimator.tensorflow_version:
                    error_helper_string = (
                        "Distribution mechanism 'pytorchxla' is currently only supported for "
                        "PyTorch >= 1.11 when SageMaker Training Compiler is enabled. Received "
                        "tensorflow_version={} which is unsupported."
                    )
                    raise ValueError(error_helper_string.format(estimator.tensorflow_version))
                if pt_version is not None:
                    if pt_version in SpecifierSet("< 1.11"):
                        error_helper_string = (
                            "Distribution mechanism 'pytorchxla' is currently only supported for "
                            "PyTorch >= 1.11 when SageMaker Training Compiler is enabled."
                            " Received pytorch_version={} which is unsupported."
                        )
                        raise ValueError(error_helper_string.format(estimator.pytorch_version))
                    if estimator.instance_type not in cls.SUPPORTED_INSTANCE_TYPES_WITH_EFA:
                        logger.warning(
                            "Consider using instances with EFA support when "
                            "training with PyTorch >= 1.11 and SageMaker Training Compiler "
                            "enabled. SageMaker Training Compiler leverages EFA to provide better "
                            "performance for distributed training."
                        )
            if not pt_xla_present:
                # For PyTorch >= 1.11, any distribution without pytorchxla is rejected.
                if pt_version is not None and pt_version in SpecifierSet(">= 1.11"):
                    error_helper_string = (
                        "'pytorchxla' is the only distribution mechanism currently supported "
                        "for PyTorch >= 1.11 when SageMaker Training Compiler is enabled."
                        " Received distribution={} which is unsupported."
                    )
                    raise ValueError(error_helper_string.format(estimator.distribution))
        elif estimator.instance_count and estimator.instance_count > 1:
            # Multi-instance job with no distribution configured: advise only.
            if pt_version is not None and pt_version in SpecifierSet(">= 1.11"):
                logger.warning(
                    "Consider setting 'distribution' to 'pytorchxla' for distributed "
                    "training with PyTorch >= 1.11 and SageMaker Training Compiler enabled."
                )