in src/sagemaker/fw_utils.py [0:0]
def validate_mp_config(config):
    """Validate the configuration dictionary for model parallelism.

    Only keys that are present in ``config`` are checked; an absent key is
    skipped, except where one key explicitly requires another:

    * ``auto_partition`` set to ``False`` requires ``default_partition``.
    * ``ddp_port`` and ``ddp_dist_backend`` each require ``ddp``.

    Checks run in a fixed order and the first failure is raised.

    Args:
        config (dict): Dictionary holding configuration keys and values.

    Raises:
        ValueError: If any of the keys have incorrect values, or if a key
            is supplied without another key it depends on.
    """

    def validate_positive(key):
        """Require ``config[key]``, when present, to be an int >= 1."""
        # Keep the try body to the lookup only, so the KeyError handler can
        # never mask anything other than "key not supplied".
        try:
            value = config[key]
        except KeyError:
            return
        if not isinstance(value, int) or value < 1:
            raise ValueError(f"The number of {key} must be a positive integer.")

    def validate_in(key, vals):
        """Require ``config[key]``, when present, to be one of ``vals``."""
        try:
            value = config[key]
        except KeyError:
            return
        if value not in vals:
            raise ValueError(f"{key} must be a value in: {vals}.")

    def validate_bool(key):
        """Require ``config[key]``, when present, to compare equal to a bool."""
        validate_in(key, [True, False])

    # Enumerated string options.
    validate_in("pipeline", ["simple", "interleaved", "_only_forward"])
    validate_in("placement_strategy", ["spread", "cluster"])
    validate_in("optimize", ["speed", "memory"])

    # Counts that must be positive integers.
    for key in ["microbatches", "partitions", "active_microbatches"]:
        validate_positive(key)

    # Boolean flags.
    for key in [
        "auto_partition",
        "contiguous",
        "load_partition",
        "horovod",
        "ddp",
        "deterministic_server",
    ]:
        validate_bool(key)

    if "partition_file" in config and not isinstance(config.get("partition_file"), str):
        raise ValueError("'partition_file' must be a str.")

    if config.get("auto_partition") is False and "default_partition" not in config:
        raise ValueError("default_partition must be supplied if auto_partition is set to False!")

    # Compare against ``partitions`` only when it was supplied. Indexing it
    # unconditionally raised a bare KeyError for configs that set
    # ``default_partition`` without ``partitions``; like every other check
    # here, an absent key is skipped rather than treated as a crash.
    if (
        "default_partition" in config
        and "partitions" in config
        and config["default_partition"] >= config["partitions"]
    ):
        raise ValueError("default_partition must be less than the number of partitions!")

    if "memory_weight" in config and (
        config["memory_weight"] > 1.0 or config["memory_weight"] < 0.0
    ):
        raise ValueError("memory_weight must be between 0.0 and 1.0!")

    # The ddp_* options are only accepted alongside the ``ddp`` key itself.
    if "ddp_port" in config and "ddp" not in config:
        raise ValueError("`ddp_port` needs `ddp` to be set as well")

    if "ddp_dist_backend" in config and "ddp" not in config:
        raise ValueError("`ddp_dist_backend` needs `ddp` to be set as well")

    if "ddp_port" in config:
        if not isinstance(config["ddp_port"], int) or config["ddp_port"] < 0:
            value = config["ddp_port"]
            raise ValueError(f"Invalid port number {value}.")

    # The two data-parallel backends are mutually exclusive.
    if config.get("horovod", False) and config.get("ddp", False):
        raise ValueError("'ddp' and 'horovod' cannot be simultaneously enabled.")