def validate_mp_config()

in src/sagemaker/fw_utils.py [0:0]


def validate_mp_config(config):
    """Validate the configuration dictionary for model parallelism.

    Only keys present in ``config`` are type/value checked; absent keys are not
    reported, except for the cross-key requirements listed below.

    Args:
       config (dict): Dictionary holding configuration keys and values.

    Raises:
        ValueError: If any of the keys have incorrect values, or if a key is
            supplied without a key it depends on: ``default_partition`` when
            ``auto_partition`` is False, ``partitions`` when
            ``default_partition`` is given, and ``ddp`` when ``ddp_port`` or
            ``ddp_dist_backend`` is given. Also raised when both ``ddp`` and
            ``horovod`` are enabled.
    """

    def is_int(value):
        # bool is a subclass of int; True/False must not count as 1/0 here.
        return isinstance(value, int) and not isinstance(value, bool)

    def validate_positive(key):
        if key in config and (not is_int(config[key]) or config[key] < 1):
            raise ValueError(f"The number of {key} must be a positive integer.")

    def validate_in(key, vals):
        if key in config and config[key] not in vals:
            raise ValueError(f"{key} must be a value in: {vals}.")

    def validate_bool(key):
        # A membership test against [True, False] would accept 0 and 1
        # (0 == False, 1 == True). The identity check on auto_partition below
        # (``is False``) would then silently miss such values.
        if key in config and not isinstance(config[key], bool):
            raise ValueError(f"{key} must be a value in: {[True, False]}.")

    validate_in("pipeline", ["simple", "interleaved", "_only_forward"])
    validate_in("placement_strategy", ["spread", "cluster"])
    validate_in("optimize", ["speed", "memory"])

    for key in ["microbatches", "partitions", "active_microbatches"]:
        validate_positive(key)

    for key in [
        "auto_partition",
        "contiguous",
        "load_partition",
        "horovod",
        "ddp",
        "deterministic_server",
    ]:
        validate_bool(key)

    if "partition_file" in config and not isinstance(config["partition_file"], str):
        raise ValueError("'partition_file' must be a str.")

    if config.get("auto_partition") is False and "default_partition" not in config:
        raise ValueError("default_partition must be supplied if auto_partition is set to False!")

    if "default_partition" in config:
        default_partition = config["default_partition"]
        # default_partition is compared against partitions below, so it must be
        # an integer index; reject non-ints up front instead of leaking TypeError.
        if not is_int(default_partition) or default_partition < 0:
            raise ValueError("default_partition must be a non-negative integer!")
        # Without partitions the bound cannot be checked; report it as a config
        # error rather than letting a KeyError escape.
        if "partitions" not in config:
            raise ValueError("default_partition requires 'partitions' to be set!")
        if default_partition >= config["partitions"]:
            raise ValueError("default_partition must be less than the number of partitions!")

    if "memory_weight" in config:
        memory_weight = config["memory_weight"]
        # The chained comparison is False for NaN, so NaN is rejected as well.
        if (
            isinstance(memory_weight, bool)
            or not isinstance(memory_weight, (int, float))
            or not 0.0 <= memory_weight <= 1.0
        ):
            raise ValueError("memory_weight must be between 0.0 and 1.0!")

    if "ddp_port" in config and "ddp" not in config:
        raise ValueError("`ddp_port` needs `ddp` to be set as well")

    if "ddp_dist_backend" in config and "ddp" not in config:
        raise ValueError("`ddp_dist_backend` needs `ddp` to be set as well")

    if "ddp_port" in config:
        port = config["ddp_port"]
        # TCP/UDP port numbers are 16-bit unsigned values.
        if not is_int(port) or not 0 <= port <= 65535:
            raise ValueError(f"Invalid port number {port}.")

    if config.get("horovod", False) and config.get("ddp", False):
        raise ValueError("'ddp' and 'horovod' cannot be simultaneously enabled.")