in src/sagemaker/estimator.py [0:0]
def _distribution_configuration(self, distribution):
"""Returns a dict of distribution configurations.
Args:
distribution (dict): A dictionary with information on how to run distributed training.
Returns:
dict that
"""
distribution_config = {}
mpi_enabled = False
smdataparallel_enabled = False
p5_enabled = False
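    # Forward the instance groups that the distribution strategy is scoped to
    # (used with heterogeneous cluster training).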
if "instance_groups" in distribution:
distribution_config["sagemaker_distribution_instance_groups"] = distribution[
"instance_groups"
]
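    # Launcher flag for PyTorch/XLA ("pytorchxla") distributed training.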
if "pytorchxla" in distribution:
pt_xla_enabled = distribution.get("pytorchxla").get("enabled", False)
distribution_config[self.LAUNCH_PT_XLA_ENV_NAME] = pt_xla_enabled
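    # Launcher flag for parameter-server based distributed training.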
if "parameter_server" in distribution:
ps_enabled = distribution.get("parameter_server").get("enabled", False)
distribution_config[self.LAUNCH_PS_ENV_NAME] = ps_enabled
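    # MPI launcher: enable flag, optional processes per host, and custom mpirun options.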
if "mpi" in distribution:
mpi_dict = distribution["mpi"]
mpi_enabled = mpi_dict.get("enabled", False)
distribution_config[self.LAUNCH_MPI_ENV_NAME] = mpi_enabled
if mpi_dict.get("processes_per_host"):
distribution_config[self.MPI_NUM_PROCESSES_PER_HOST] = mpi_dict.get(
"processes_per_host"
)
distribution_config[self.MPI_CUSTOM_MPI_OPTIONS] = mpi_dict.get(
"custom_mpi_options", ""
)
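    # SageMaker distributed training libraries (SMModelParallel / SMDataParallel).
    # Model-parallel parameters come from get_mp_parameters(); data parallelism is
    # toggled via distribution["smdistributed"]["dataparallel"]["enabled"].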
if "smdistributed" in distribution:
# smdistributed strategy selected
        mp_parameters = get_mp_parameters(distribution)
        if mp_parameters:
            distribution_config["mp_parameters"] = mp_parameters
# first make sure torch_distributed is enabled if instance type is p5
torch_distributed_enabled = False
if "torch_distributed" in distribution:
torch_distributed_enabled = distribution.get("torch_distributed").get(
"enabled", False
)
smdistributed = distribution["smdistributed"]
smdataparallel_enabled = smdistributed.get("dataparallel", {}).get("enabled", False)
if isinstance(self.instance_type, ParameterString):
p5_enabled = "p5.48xlarge" in self.instance_type.default_value
elif isinstance(self.instance_type, str):
p5_enabled = "p5.48xlarge" in self.instance_type
else:
for instance in self.instance_groups:
if "p5.48xlarge" in instance._to_request_dict().get("InstanceType", ()):
p5_enabled = True
break
img_uri = "" if self.image_uri is None else self.image_uri
for unsupported_image in Framework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM:
if (
unsupported_image in img_uri and not torch_distributed_enabled
            ):  # reject DLC images that do not include SMDataParallel / SMModelParallel
raise ValueError(
f"SMDistributed is currently incompatible with DLC image: {img_uri}. "
"(Could be due to CUDA version being greater than 11.)"
)
if (
not torch_distributed_enabled and p5_enabled
        ):  # p5 instances are only supported when torch_distributed is enabled
raise ValueError(
"SMModelParallel and SMDataParallel currently do not support p5 instances."
)
# smdistributed strategy selected with supported instance type
distribution_config[self.LAUNCH_SM_DDP_ENV_NAME] = smdataparallel_enabled
distribution_config[self.INSTANCE_TYPE] = self.instance_type
if smdataparallel_enabled:
distribution_config[self.SM_DDP_CUSTOM_MPI_OPTIONS] = smdistributed[
"dataparallel"
].get("custom_mpi_options", "")
if "multi_worker_mirrored_strategy" in distribution:
mwms_enabled = distribution.get("multi_worker_mirrored_strategy").get("enabled", False)
if mwms_enabled:
self._validate_mwms_config(distribution)
distribution_config[self.LAUNCH_MWMS_ENV_NAME] = mwms_enabled
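    # Instance groups are only meaningful together with a launcher (MPI or
    # SMDataParallel); reject configurations that set groups without enabling either.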
if not (mpi_enabled or smdataparallel_enabled) and distribution_config.get(
"sagemaker_distribution_instance_groups"
) not in [None, []]:
raise ValueError(
"Don't set training instance groups while no distribution strategies enabled!"
)
return distribution_config
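
# Illustrative sketch (not part of this method): how callers typically supply the
# `distribution` argument parsed above, via the public PyTorch estimator of the
# SageMaker Python SDK. All argument values below are placeholders.
#
#   from sagemaker.pytorch import PyTorch
#
#   estimator = PyTorch(
#       entry_point="train.py",
#       role="arn:aws:iam::123456789012:role/SageMakerRole",
#       framework_version="2.0",
#       py_version="py310",
#       instance_count=2,
#       instance_type="ml.p4d.24xlarge",
#       distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
#   )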