optimum/neuron/models/training/config.py
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union

import torch

from ...configuration_utils import NeuronConfig, register_neuron_config
from ...utils import is_neuronx_distributed_available
from ...utils.torch_xla_and_neuronx_initialization import init_process_group

if is_neuronx_distributed_available():
    from neuronx_distributed.parallel_layers import parallel_state

@dataclass
@register_neuron_config
class TrainingNeuronConfig(NeuronConfig):
    """Training-time Neuron configuration: tensor and pipeline parallelism, checkpointing,
    and attention-related options."""

    # Tensor parallelism.
    tensor_parallel_size: int = 1
    parallelize_embeddings: bool = True
    sequence_parallel_enabled: bool = False
    kv_size_multiplier: Optional[int] = None

    # Pipeline parallelism.
    pipeline_parallel_size: int = 1
    virtual_pipeline_parallel_size: int = 1
    pipeline_parallel_num_microbatches: int = 1
    pipeline_parallel_use_zero1_optimizer: bool = False

    # Activation recomputation and checkpointing.
    gradient_checkpointing: bool = False
    checkpoint_dir: Optional[Union[str, Path]] = None
    num_local_ranks_per_step: int = 8
    use_xser: bool = True
    async_save: bool = False

    # Attention-related options.
    fuse_qkv: bool = False
    recompute_causal_mask: bool = True

    def __post_init__(self):
        if self.tensor_parallel_size < 1:
            raise ValueError(f"The tensor parallel size must be >= 1, but {self.tensor_parallel_size} was given.")
        if self.pipeline_parallel_size < 1:
            raise ValueError(
                f"The pipeline parallel size must be >= 1, but {self.pipeline_parallel_size} was given."
            )
        if isinstance(self.checkpoint_dir, str):
            self.checkpoint_dir = Path(self.checkpoint_dir)

        # Neuron training always runs in a distributed setting, so make sure the default
        # process group and the model-parallel groups are initialized.
        if not torch.distributed.is_initialized():
            init_process_group()
        if not torch.distributed.is_initialized():
            raise ValueError(
                "Neuron training requires torch.distributed to be initialized. "
                "You can initialize it by launching your script with `torchrun`."
            )
        if not parallel_state.model_parallel_is_initialized():
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size=self.tensor_parallel_size,
                pipeline_model_parallel_size=self.pipeline_parallel_size,
            )
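
    # Illustrative note (assumed setup, not part of the original module): the parallel
    # groups created above follow the usual Megatron-style layout of neuronx_distributed,
    # so with a world size of 32 processes launched via `torchrun`, tensor_parallel_size=8
    # and pipeline_parallel_size=2 leave a data-parallel degree of 32 // (8 * 2) == 2.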

    def auto_kv_size_multiplier(self, num_key_value_heads: int) -> int:
        """Infers the key/value size multiplier from the tensor parallel size and the number of KV heads."""
        kv_size_multiplier = max(1, self.tensor_parallel_size // num_key_value_heads)
        if self.kv_size_multiplier is not None and self.kv_size_multiplier != kv_size_multiplier:
            raise ValueError(
                "A kv size multiplier was already specified and is different from the inferred one: "
                f"specified {self.kv_size_multiplier}, inferred {kv_size_multiplier}."
            )
        return kv_size_multiplier
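
    # Worked example (illustrative, values assumed): for a model with 8 KV heads trained
    # with tensor_parallel_size=32, auto_kv_size_multiplier returns max(1, 32 // 8) == 4,
    # i.e. each key/value head is replicated 4 times so every tensor-parallel rank holds one.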

    @property
    def should_parallelize(self) -> bool:
        return self.tensor_parallel_size > 1 or self.pipeline_parallel_size > 1
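

# Minimal usage sketch (illustrative only; assumes the script is launched with `torchrun`,
# and the argument values below are made up):
#
#     trn_config = TrainingNeuronConfig(
#         tensor_parallel_size=8,
#         pipeline_parallel_size=2,
#         sequence_parallel_enabled=True,
#         gradient_checkpointing=True,
#     )
#     assert trn_config.should_parallelize
#
# Instantiating the config runs `__post_init__`, which initializes torch.distributed
# (if needed) and the neuronx_distributed model-parallel groups with the requested
# tensor and pipeline parallel sizes.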