in src/sagemaker/huggingface/training_compiler/config.py [0:0]
def validate(cls, estimator):
    """Checks if SageMaker Training Compiler is configured correctly.

    After the base-class checks, the following rules are enforced:

    * ``pytorch_version``, when given, must be within 1.9-1.11.
    * ``image_uri`` must not be overridden.
    * When ``distribution`` enables ``pytorchxla``, the job must not use
      ``tensorflow_version`` and must use PyTorch >= 1.11. A warning is logged
      if ``instance_type`` is not in ``cls.SUPPORTED_INSTANCE_TYPES_WITH_EFA``.
    * When ``distribution`` is set but has no ``pytorchxla`` entry, PyTorch
      >= 1.11 is rejected (``pytorchxla`` is the only mechanism supported there).
    * When ``distribution`` is not set and ``instance_count`` > 1 with
      PyTorch >= 1.11, a warning suggests using ``pytorchxla``.

    Args:
        estimator (:class:`sagemaker.huggingface.HuggingFace`): An estimator object.
            If SageMaker Training Compiler is enabled, it will validate whether
            the estimator is configured to be compatible with Training Compiler.

    Raises:
        ValueError: Raised if the requested configuration is not compatible
            with SageMaker Training Compiler.
    """
    super(TrainingCompilerConfig, cls).validate(estimator)

    # Parse the PyTorch version once; every PyTorch rule below reuses it.
    pt_version = Version(estimator.pytorch_version) if estimator.pytorch_version else None

    if pt_version is not None and (
        pt_version in SpecifierSet("< 1.9") or pt_version in SpecifierSet("> 1.11")
    ):
        error_helper_string = (
            "SageMaker Training Compiler is only supported "
            "with HuggingFace PyTorch 1.9-1.11. "
            "Received pytorch_version={} which is unsupported."
        )
        raise ValueError(error_helper_string.format(estimator.pytorch_version))

    if estimator.image_uri:
        # Note the trailing space after "Compiler." and the estimator's actual
        # parameter name ``transformers_version``.
        error_helper_string = (
            "Overriding the image URI is currently not supported "
            "for SageMaker Training Compiler. "
            "Specify the following parameters to run the Hugging Face training job "
            "with SageMaker Training Compiler enabled: "
            "transformers_version, tensorflow_version or pytorch_version, and compiler_config."
        )
        raise ValueError(error_helper_string)

    if estimator.distribution:
        pt_xla_present = "pytorchxla" in estimator.distribution
        # ``or {}`` also covers an explicit ``{"pytorchxla": None}`` entry, which
        # would otherwise raise AttributeError on ``.get``.
        pt_xla_config = estimator.distribution.get("pytorchxla") or {}
        pt_xla_enabled = pt_xla_config.get("enabled", False)

        if pt_xla_enabled:
            if estimator.tensorflow_version:
                error_helper_string = (
                    "Distribution mechanism 'pytorchxla' is currently only supported for "
                    "PyTorch >= 1.11 when SageMaker Training Compiler is enabled. Received "
                    "tensorflow_version={} which is unsupported."
                )
                raise ValueError(error_helper_string.format(estimator.tensorflow_version))
            if pt_version is not None and pt_version in SpecifierSet("< 1.11"):
                error_helper_string = (
                    "Distribution mechanism 'pytorchxla' is currently only supported for "
                    "PyTorch >= 1.11 when SageMaker Training Compiler is enabled."
                    " Received pytorch_version={} which is unsupported."
                )
                raise ValueError(error_helper_string.format(estimator.pytorch_version))
            if estimator.instance_type not in cls.SUPPORTED_INSTANCE_TYPES_WITH_EFA:
                logger.warning(
                    "Consider using instances with EFA support when "
                    "training with PyTorch >= 1.11 and SageMaker Training Compiler "
                    "enabled. SageMaker Training Compiler leverages EFA to provide better "
                    "performance for distributed training."
                )

        # A distribution is configured, but not pytorchxla: reject PyTorch >= 1.11,
        # where pytorchxla is the only supported mechanism.
        if not pt_xla_present and pt_version is not None and pt_version in SpecifierSet(
            ">= 1.11"
        ):
            error_helper_string = (
                "'pytorchxla' is the only distribution mechanism currently supported "
                "for PyTorch >= 1.11 when SageMaker Training Compiler is enabled."
                " Received distribution={} which is unsupported."
            )
            raise ValueError(error_helper_string.format(estimator.distribution))
    elif estimator.instance_count and estimator.instance_count > 1:
        # Multi-instance job with no distribution configured: only advise.
        if pt_version is not None and pt_version in SpecifierSet(">= 1.11"):
            logger.warning(
                "Consider setting 'distribution' to 'pytorchxla' for distributed "
                "training with PyTorch >= 1.11 and SageMaker Training Compiler enabled."
            )