src/optimum/nvidia/export/config.py (161 lines of code) (raw):

from dataclasses import dataclass
from logging import getLogger
from os import PathLike
from typing import TYPE_CHECKING, Optional, Union
from warnings import warn

from tensorrt_llm import BuildConfig
from tensorrt_llm import Mapping as ShardingInfo
from tensorrt_llm.bindings import QuantMode
from tensorrt_llm.plugin import PluginConfig
from tensorrt_llm.plugin.plugin import ContextFMHAType
from transformers import AutoConfig

from optimum.nvidia.lang import DataType
from optimum.nvidia.utils.nvml import is_post_hopper
from optimum.utils import NormalizedConfig


if TYPE_CHECKING:
    from transformers import PretrainedConfig


# Sentinel for `auto_parallel`: any world_size < 1 triggers GPU discovery through NVML.
INFER_NUM_LOCAL_GPUS = -1

# Module-scoped logger (records still propagate to the root logger's handlers).
LOGGER = getLogger(__name__)


@dataclass
class ExportConfig:
    """
    Build-time parameters used to compile a model into a TensorRT-LLM engine.

    Instances are usually created through `from_pretrained` / `from_config`, refined with
    `with_sharding` (or the module-level `auto_parallel` / `sharded` helpers) and finally
    converted into a TensorRT-LLM `BuildConfig` with `to_builder_config`.
    """

    dtype: str
    max_input_len: int
    max_output_len: int
    max_batch_size: int

    # Optional parameters
    max_beam_width: int = 1
    # -1 means "infer it in validate()"
    max_num_tokens: int = -1
    enabled_chunked_context: bool = False
    # None is treated as a single, non-sharded rank
    sharding: Optional[ShardingInfo] = None
    optimization_level: int = 3

    def __post_init__(self):
        if self.max_batch_size < 1:
            raise ValueError(
                f"max_batch_size should be >= 1, got {self.max_batch_size}"
            )

    @staticmethod
    def from_pretrained(
        model_id_or_path: Union[str, PathLike], max_batch_size: int = 1
    ) -> "ExportConfig":
        """
        Create an `ExportConfig` from a model identifier or a local path.

        :param model_id_or_path: Value forwarded to `AutoConfig.from_pretrained`
        :param max_batch_size: Maximum batch size the engine should support (>= 1)
        :return: `ExportConfig`
        """
        return ExportConfig.from_config(
            AutoConfig.from_pretrained(model_id_or_path), max_batch_size
        )

    @staticmethod
    def from_config(
        config: Union[NormalizedConfig, "PretrainedConfig"], max_batch_size: int = 1
    ) -> "ExportConfig":
        """
        Create an `ExportConfig` from an already loaded model configuration.

        The dtype is derived from `torch_dtype`. Both the maximum input and output lengths are
        set to `max_position_embeddings`. The result is initialized with a single shard and
        validated before being returned.

        :param config: Model configuration, wrapped in a `NormalizedConfig` if it is not one already
        :param max_batch_size: Maximum batch size the engine should support (>= 1)
        :return: `ExportConfig`
        """
        if not isinstance(config, NormalizedConfig):
            config = NormalizedConfig(config)

        dtype = DataType.from_torch(config.torch_dtype).value
        max_input_len = config.max_position_embeddings
        max_output_len = config.max_position_embeddings

        econfig = ExportConfig(
            dtype=dtype,
            max_input_len=max_input_len,
            max_output_len=max_output_len,
            max_batch_size=max_batch_size,
        )

        # Initialize sharding information with a single shard
        econfig.with_sharding()
        econfig.validate()
        return econfig

    def validate(self) -> "ExportConfig":
        """
        Check the configuration and fill in inferred values in place.

        When `max_num_tokens` is -1 it is replaced by 8192 if chunked context is enabled.
        Otherwise it becomes `max_batch_size * max_input_len // 2`.

        :raises ValueError: If `optimization_level` is negative
        :return: `self`, to allow chaining
        """
        if self.optimization_level < 0:
            raise ValueError(
                f"optimization_level should be >= 0, got {self.optimization_level}"
            )

        if self.max_num_tokens == -1:
            if self.enabled_chunked_context:
                # Should be N * tokens_per_block (8192 is the default)
                self.max_num_tokens = 8192  # hardcode for now
                warn(
                    f"max_num_tokens set to {self.max_num_tokens} with chunked context enabled might not be optimal."
                )
            else:
                self.max_num_tokens = self.max_batch_size * self.max_input_len // 2

            LOGGER.debug(f"Inferred max_num_tokens={self.max_num_tokens}")

        return self

    @property
    def plugin_config(self) -> "PluginConfig":
        """
        Build the default TensorRT-LLM plugin configuration for this export configuration.

        The result enables the "auto" GEMM and GPT attention plugins, context FMHA and a paged
        KV cache. It also enables NCCL when sharded across more than one rank, and the
        SwiGLU GEMM plugin for float8. A new `PluginConfig` is created on every access.
        """
        config = PluginConfig()
        config.gemm_plugin = "auto"
        config.gpt_attention_plugin = "auto"
        config.set_context_fmha(ContextFMHAType.enabled)
        config.enable_paged_kv_cache(32)
        config.use_paged_context_fmha = True

        # `sharding` defaults to None on a directly constructed ExportConfig.
        # Treat that as a single rank instead of failing with AttributeError.
        if self.sharding is not None and self.sharding.world_size > 1:
            config.set_nccl_plugin()

        if DataType(self.dtype) == DataType.FLOAT8:
            config.gemm_swiglu_plugin = True

        return config

    def to_builder_config(
        self,
        qmode: Optional["QuantMode"] = None,
        plugin_config: Optional[PluginConfig] = None,
    ) -> "BuildConfig":
        """
        Validate this configuration and convert it into a TensorRT-LLM `BuildConfig`.

        Note: a caller-provided `plugin_config` is modified in place (for example,
        `multiple_profiles` is set to True).

        :param qmode: Optional quantization mode, used to enable the matching plugins
        :param plugin_config: Plugin configuration to use; defaults to `self.plugin_config`
        :return: `BuildConfig`
        """
        self.validate()

        plugin_config = plugin_config or self.plugin_config
        plugin_config.multiple_profiles = True

        if qmode:
            plugin_config.use_fp8_context_fmha = (
                qmode.has_fp8_qdq() or qmode.has_fp8_kv_cache()
            )

            # Low latency GeMM plugin is only available for sm90+ and float8 weights and activations
            if qmode.has_fp8_qdq() and qmode.has_fp8_kv_cache() and is_post_hopper():
                plugin_config.low_latency_gemm_plugin = "fp8"

            if qmode.is_weight_only():
                plugin_config.weight_only_groupwise_quant_matmul_plugin = "auto"
            weight_sparsity = qmode.sparsity is not None
        else:
            weight_sparsity = False

        return BuildConfig(
            max_input_len=self.max_input_len,
            max_seq_len=self.max_output_len,
            max_batch_size=self.max_batch_size,
            max_beam_width=self.max_beam_width,
            max_num_tokens=self.max_num_tokens,
            plugin_config=plugin_config,
            use_fused_mlp=True,
            weight_sparsity=weight_sparsity,
        )

    def with_sharding(
        self,
        tp: int = 1,
        pp: int = 1,
        gpus_per_node: int = 8,
        sharding: Optional[ShardingInfo] = None,
    ) -> "ExportConfig":
        """
        Set the sharding information in place.

        If `sharding` is given it is used as-is. Otherwise a `ShardingInfo` is built from
        `tp`, `pp` and `gpus_per_node`, with `world_size = tp * pp`.

        :return: `self`, to allow chaining
        """
        self.sharding = sharding or ShardingInfo(
            tp_size=tp, pp_size=pp, world_size=tp * pp, gpus_per_node=gpus_per_node
        )

        return self


def auto_parallel(
    config: "ExportConfig", world_size: int = INFER_NUM_LOCAL_GPUS
) -> "ExportConfig":
    """
    Helper to infer the most suitable parallelization strategy to apply to the model with respect to the local hardware.
    :param config: `ExportConfig` the parallelization strategy should be applied to
    :param world_size: Number of GPUs to consider when discovering automatic parallelization strategies
        (any value < 1 queries the number of local GPUs)
    :raises ValueError: If no GPU is found or the number of GPUs is not one of 1, 2, 4 or 8
    :return: `ExportConfig`
    """
    # Infer number of GPUs on the system
    if world_size < 1:
        from optimum.nvidia.utils.nvml import get_device_count

        world_size = get_device_count()

        LOGGER.info(f"Found {world_size} GPUs on the system")

    # Handle all the different cases (0, 1, N > 1)
    if world_size == 0:
        raise ValueError("No GPU found")
    elif world_size == 1:
        return config.with_sharding(tp=1, pp=1, gpus_per_node=world_size)
    else:
        LOGGER.info(f"Creating auto-parallelization strategy on {world_size}-GPUs")
        LOGGER.warning(
            "Auto-parallelization strategy is currently in beta and might not be optimal"
        )

        if world_size == 2:
            return config.with_sharding(tp=2, pp=1, gpus_per_node=world_size)
        elif world_size == 4:
            return config.with_sharding(tp=2, pp=2, gpus_per_node=world_size)
        elif world_size == 8:
            return config.with_sharding(tp=4, pp=2, gpus_per_node=world_size)
        else:
            raise ValueError(
                f"Unsupported number of GPUs: {world_size}. "
                "Please open an issue on the optimum-nvidia repository: "
                "https://github.com/huggingface/optimum-nvidia"
            )


def sharded(config: "ExportConfig", tp: int = 1, pp: int = 1) -> "ExportConfig":
    """
    Helper to specify the parallelization strategy to apply to the model.
    :param config: `ExportConfig` the parallelization strategy should be applied to
    :param tp: Tensor Parallelism degree to apply (`int` >= 1)
    :param pp: Pipeline Parallelism degree to apply (`int` >= 1)
    :raises ValueError: If `tp` or `pp` is lower than 1
    :return: `ExportConfig`
    """
    if tp < 1:
        raise ValueError(f"Tensor Parallelism (tp) should be >= 1 (got: tp={tp})")

    if pp < 1:
        raise ValueError(f"Pipeline Parallelism (pp) should be >= 1 (got: pp={pp})")

    return config.with_sharding(
        sharding=ShardingInfo(tp_size=tp, pp_size=pp, world_size=tp * pp)
    )