def _modelparallel_environment_command()

in src/sagemaker_training/mpi.py


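# Note: this excerpt relies on module-level names defined elsewhere in
# mpi.py: json, os, logger, and the sagemaker_training environment and
# params modules.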
def _modelparallel_environment_command(instance_type):  # type: (str) -> list[str]
    """When a task is of modelparallel
    1. For torch DLC, we add NCCL_PROTO environment.
    2. If the torch major version is greater 1, we add NCCL_ALGO environment.
    3. If ddp_dist_backend is auto, we use smddpmprun to set up necessary environment
       variables if possible.
    """
    command = []
    env = environment.Environment()
    if env.is_modelparallel_enabled:
        pytorch_version = os.environ.get("SM_DLC_TORCH_VERSION")
        if pytorch_version:
            logger.info(f"PyTorch version is {pytorch_version}")
            # Set NCCL_PROTO
            command.extend(["-x", "NCCL_PROTO=simple"])
            # Set NCCL_ALGO to avoid a potential hang, starting from torch 2.0.0
            if int(pytorch_version.split(".")[0]) > 1:
                command.extend(["-x", "NCCL_ALGO=ring"])

        # Use smddpmprun to set up environment variables
        mp_parameters = json.loads(os.environ.get(params.SM_HP_MP_PARAMETERS, "{}"))
        ddp_dist_backend = mp_parameters.get("ddp_dist_backend", "auto")
        if ddp_dist_backend == "auto":
            if env.is_smddpmprun_installed:
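                # --allow-bypass presumably lets smddpmprun fall back to an
                # unmodified launch on unsupported instance types.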
                command.extend(["smddpmprun", "-i", instance_type, "--allow-bypass"])
        else:
            logger.info(f"{ddp_dist_backend} is used as DDP backend for training")
    return command
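
For concreteness, here is a minimal sketch tracing the function through a
hypothetical model-parallel job on a PyTorch 2.x DLC. The environment-variable
values, the instance type, and the assumption that both
env.is_modelparallel_enabled and env.is_smddpmprun_installed are True are
illustrative, not values the toolkit guarantees.

import json
import os

# Illustrative values only; in a real job the DLC and the toolkit set these.
os.environ["SM_DLC_TORCH_VERSION"] = "2.0.1"
os.environ["SM_HP_MP_PARAMETERS"] = json.dumps(
    {"partitions": 2, "ddp": True, "ddp_dist_backend": "auto"}
)

# Under those assumptions, _modelparallel_environment_command("ml.p4d.24xlarge")
# would return:
#   ["-x", "NCCL_PROTO=simple",        # step 1: NCCL_PROTO for the torch DLC
#    "-x", "NCCL_ALGO=ring",           # step 2: torch major version > 1
#    "smddpmprun", "-i", "ml.p4d.24xlarge", "--allow-bypass"]  # step 3

The "-x NAME=VALUE" pairs are OpenMPI-style environment exports for mpirun,
while the smddpmprun fragment is a launcher prefix that, per the docstring,
sets up the necessary environment variables before handing off to the wrapped
command.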