in src/sagemaker_training/mpi.py [0:0]
def _modelparallel_environment_command(instance_type):  # type: (str) -> list[str]
    """When a task uses modelparallel:

    1. For torch DLCs, add the NCCL_PROTO environment variable.
    2. If the torch major version is greater than 1, also add NCCL_ALGO.
    3. If ddp_dist_backend is auto, use smddpmprun to set up the necessary
       environment variables if possible.
    """
    command = []
    env = environment.Environment()
    if env.is_modelparallel_enabled:
        pytorch_version = os.environ.get("SM_DLC_TORCH_VERSION")
        if pytorch_version:
            logger.info(f"PyTorch version is {pytorch_version}")
            # Set NCCL_PROTO
            command.extend(["-x", "NCCL_PROTO=simple"])
            # Set NCCL_ALGO to avoid potential hang, starting from torch 2.0.0
            if int(pytorch_version.split(".")[0]) > 1:
                command.extend(["-x", "NCCL_ALGO=ring"])
        # Use smddpmprun to set up environment variables
        mp_parameters = json.loads(os.environ.get(params.SM_HP_MP_PARAMETERS, "{}"))
        ddp_dist_backend = mp_parameters.get("ddp_dist_backend", "auto")
        if ddp_dist_backend == "auto":
            if env.is_smddpmprun_installed:
                command.extend(["smddpmprun", "-i", instance_type, "--allow-bypass"])
        else:
            logger.info(f"{ddp_dist_backend} is used as DDP backend for training")
    return command
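
For context, the list returned above is meant to be spliced into the larger mpirun invocation that the MPI runner assembles elsewhere in this module. A minimal sketch of that composition follows; the base command, host list, instance type, and training entry point are hypothetical values chosen only for illustration, not the toolkit's actual assembly code.

# Illustration only; not part of mpi.py. Values below are assumptions.
modelparallel_flags = [
    "-x", "NCCL_PROTO=simple",   # step 1: Open MPI's -x exports the variable to every rank
    "-x", "NCCL_ALGO=ring",      # step 2: added only when the torch major version is > 1
    "smddpmprun", "-i", "ml.p4d.24xlarge", "--allow-bypass",  # step 3: launcher prefix
]
base_mpi_command = ["mpirun", "-np", "8", "--host", "algo-1:8"]  # hypothetical base command
full_command = base_mpi_command + modelparallel_flags + ["python", "train.py"]
print(" ".join(full_command))

The "-x NAME=value" pairs are consumed by mpirun itself, while the trailing smddpmprun tokens act as a wrapper placed in front of the training command when ddp_dist_backend is auto and smddpmprun is installed.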