optimum_benchmark/launchers/torchrun/config.py (30 lines of code) (raw):
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
from ..config import LauncherConfig
@dataclass
class TorchrunConfig(LauncherConfig):
    """Configuration of the torchrun launcher.

    The fields below are the settings consumed by the launcher referenced in
    ``_target_``. ``__post_init__`` validates ``start_method`` and requires
    ``min_nodes == max_nodes``.
    """

    name: str = "torchrun"
    _target_: str = "optimum_benchmark.launchers.torchrun.launcher.TorchrunLauncher"

    # Minimum amount of nodes that the user function will be launched on.
    # Elastic agent ensures that the user function starts only when the min_nodes amount enters the rendezvous.
    min_nodes: int = 1
    # Maximum amount of nodes that the user function will be launched on.
    # Must be equal to min_nodes (enforced in __post_init__).
    max_nodes: int = 1
    # On each node the elastic agent will launch this amount of workers that will execute the user defined function.
    nproc_per_node: int = 2
    # User defined role of the worker.
    role: str = "benchmarker"
    # The interval in seconds that is used by the elastic agent as a period of monitoring workers.
    monitor_interval: int = 30
    # The name of the rdzv store.
    # Generated per instance: a plain `str(uuid.uuid4())` default would be evaluated once
    # at class-definition time and shared by every config created in the process.
    rdzv_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    # Backend to use in the rendezvous.
    rdzv_backend: str = "c10d"
    # The endpoint of the rdzv sync. storage.
    rdzv_endpoint: str = "localhost:0"
    # Key, value pairs that specify rendezvous specific configuration.
    rdzv_configs: Dict[str, Any] = field(default_factory=lambda: {"rank": 0, "timeout": -1})
    # The timeout in seconds that is used by the elastic agent to wait for the workers to enter the rendezvous.
    rdzv_timeout: int = -1
    # The maximum amount of restarts that the elastic agent will conduct on workers before failure.
    max_restarts: int = 0
    # The method used by the elastic agent to start the workers.
    # Only "spawn" and "fork" are accepted (see __post_init__).
    start_method: str = "spawn"
    # Address of the local node if any. If not set, a lookup on the local machine's FQDN will be performed.
    local_addr: Optional[str] = None
    # The socket ifname.
    socket_ifname: Optional[str] = None

    def __post_init__(self):
        """Run the parent post-init, then validate launcher-specific fields.

        Raises:
            ValueError: if ``start_method`` is neither ``"spawn"`` nor ``"fork"``,
                or if ``min_nodes`` differs from ``max_nodes``.
        """
        super().__post_init__()

        if self.start_method not in ("spawn", "fork"):
            raise ValueError(f"start_method must be one of ['spawn', 'fork'], got {self.start_method}")

        # An elastic node range would make the number of workers vary between runs.
        if self.min_nodes != self.max_nodes:
            raise ValueError(
                f"min_nodes and max_nodes must be equal for a reproducible benchmark, got {self.min_nodes} and {self.max_nodes}"
            )