# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/aws-neuron/neuronx-distributed-inference/blob/9993358ce052fd7a1bb4a7497a6318aac36ed95c/src/neuronx_distributed_inference/models/config.py

from typing import List, Optional, Union

import torch

from ....configuration_utils import NeuronConfig, register_neuron_config
from ....utils import map_torch_dtype

NEURON_CONFIG_FILE = "neuron_config.json"


def to_dict(obj):
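    """Recursively convert an object into a JSON-serializable dict.

    Dicts and lists are converted element-wise, objects exposing a ``__dict__``
    are converted attribute-wise, and ``torch.dtype`` values are serialized as
    their short string name:

        >>> to_dict({"torch_dtype": torch.bfloat16})
        {'torch_dtype': 'bfloat16'}
    """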
if type(obj) is dict:
return {k: to_dict(v) for k, v in obj.items()}
elif type(obj) is list:
return [to_dict(v) for v in obj]
elif hasattr(obj, "__dict__"):
return {k: to_dict(v) for k, v in obj.__dict__.items()}
elif type(obj) is torch.dtype:
return str(obj).split(".")[1]
    else:
        return obj


class IncompatibleConfigError(ValueError):
    pass


@register_neuron_config
class NxDNeuronConfig(NeuronConfig):
"""
Base config class for inference in NxD.
This class contains attributes that are needed for various inference
optimization/features in NxD.
These attributes are a subset of the attributes in the original NxDI config class.
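
    Example (illustrative values only; the appropriate settings depend on your
    model and target Neuron hardware):

        neuron_config = NxDNeuronConfig(
            batch_size=4,
            sequence_length=2048,
            tp_degree=2,
            torch_dtype="bfloat16",
        )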
"""
def __init__(
self,
        checkpoint_id: Optional[str] = None,
        checkpoint_revision: Optional[str] = None,
batch_size: Optional[int] = 1,
max_batch_size: Optional[int] = None,
continuous_batching: Optional[bool] = False,
speculation_length: Optional[int] = 0,
sequence_length: Optional[int] = 128,
tp_degree: Optional[int] = 1,
ep_degree: Optional[int] = 1,
pp_degree: Optional[int] = 1,
torch_dtype: Optional[Union[str, torch.dtype]] = torch.bfloat16,
rpl_reduce_dtype: Optional[Union[str, torch.dtype]] = None,
n_active_tokens: Optional[int] = None,
max_context_length: Optional[int] = None,
output_logits: Optional[bool] = False,
padding_side: Optional[str] = "right",
fused_qkv: Optional[bool] = False,
vocab_parallel: Optional[bool] = False,
sequence_parallel_enabled: Optional[bool] = False,
is_chunked_prefill: Optional[bool] = False,
flash_decoding_enabled: Optional[bool] = False,
async_mode: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
attn_kernel_enabled: Optional[bool] = False,
qkv_kernel_enabled: Optional[bool] = False,
mlp_kernel_enabled: Optional[bool] = False,
mlp_kernel_fuse_residual_add: Optional[bool] = False,
enable_bucketing: Optional[bool] = False,
        target: Optional[str] = None,  # Set to "trn2" to target trn2 chips
logical_nc_config: Optional[int] = 1,
cc_pipeline_tiling_factor: Optional[int] = 2,
num_cores_per_group: Optional[int] = 1,
on_device_sampling: Optional[bool] = False,
max_topk: Optional[int] = 256,
start_rank_id: Optional[int] = 0,
local_ranks_size: Optional[int] = None,
        capacity_factor: Optional[float] = None,
glu_mlp: bool = True,
) -> None:
        # TODO: these flags are supposed to work in NxDI. Either make them work or remove them.
if is_chunked_prefill:
raise ValueError("`is_chunked_prefill` is not supported in optimum-neuron.")
if flash_decoding_enabled:
raise ValueError("`flash_decoding_enabled` is not supported in optimum-neuron.")
if async_mode:
raise ValueError("`async_mode` is not supported in optimum-neuron.")
        if qkv_kernel_enabled or mlp_kernel_enabled:
            raise ValueError("`qkv_kernel_enabled` and `mlp_kernel_enabled` are not supported on trn1 chips.")
if vocab_parallel:
raise ValueError("`vocab_parallel` is not supported in optimum-neuron.")
if qk_layernorm:
raise ValueError(
"`qk_layernorm` is not supported in optimum-neuron. It is actually a modeling flag that affects the attention layer."
)
# Required to retrieve a checkpoint from the hub
self.checkpoint_id = checkpoint_id
self.checkpoint_revision = checkpoint_revision
# Basic config for inference in NxD
self.batch_size = batch_size
self.sequence_length = sequence_length
self.tp_degree = tp_degree
self.torch_dtype = torch_dtype
if isinstance(self.torch_dtype, str):
self.torch_dtype = map_torch_dtype(self.torch_dtype)
self.n_active_tokens = self.sequence_length if n_active_tokens is None else n_active_tokens
self.output_logits = output_logits
self.padding_side = padding_side
self.rpl_reduce_dtype = torch_dtype if rpl_reduce_dtype is None else rpl_reduce_dtype
if isinstance(self.rpl_reduce_dtype, str):
self.rpl_reduce_dtype = map_torch_dtype(self.rpl_reduce_dtype)
        # The fallback to sequence_length is for compatibility with vLLM.
self.max_context_length = max_context_length
if self.max_context_length is None:
self.max_context_length = sequence_length
# Graph transforms
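        # fused_qkv fuses the query, key and value projections into a single
        # linear layer, replacing three matmuls with one.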
self.fused_qkv = fused_qkv
# Functional parallelism
self.vocab_parallel = vocab_parallel
self.sequence_parallel_enabled = sequence_parallel_enabled
self.is_chunked_prefill = is_chunked_prefill
# Continuous batching
        # TODO: Check if we really need different batch sizes for CTE and TKG, given
        # that we provide two different config instances for them anyway.
self.continuous_batching = continuous_batching
self.max_batch_size = batch_size if max_batch_size is None else max_batch_size
# On-device sampling
self.on_device_sampling = on_device_sampling
self.max_topk = max_topk
        # Async mode
self.async_mode = async_mode
# Bucketing
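        # Bucketing compiles graphs for several sequence/batch sizes so that
        # variable-sized requests do not trigger recompilation.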
self.enable_bucketing = enable_bucketing
# Speculative decoding
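        # speculation_length is the number of draft tokens verified per decoding
        # step; 0 disables speculative decoding.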
self.speculation_length = speculation_length
if self.speculation_length > 0:
if self.async_mode:
raise IncompatibleConfigError("Speculative Decoding is not yet supported with async.")
if self.on_device_sampling:
raise IncompatibleConfigError("Speculative decoding is incompatible with on-device sampling")
# Distributed config
self.pp_degree = pp_degree
self.ep_degree = ep_degree
# QK layer normalization
self.qk_layernorm = qk_layernorm
# Multi-node
        # TODO: Check whether start_rank_id can be modified dynamically at runtime.
        # Otherwise, we need multiple exports for different start_rank_id values.
self.start_rank_id = start_rank_id
self.local_ranks_size = local_ranks_size
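        # When unset, default to the full world size (i.e., a single-node setup).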
if self.local_ranks_size is None:
self.local_ranks_size = self.world_size
# Flash decoding
self.flash_decoding_enabled = flash_decoding_enabled
self.num_cores_per_group = num_cores_per_group
# Kernels
self.attn_kernel_enabled = attn_kernel_enabled
self.qkv_kernel_enabled = qkv_kernel_enabled
self.mlp_kernel_enabled = mlp_kernel_enabled
self.mlp_kernel_fuse_residual_add = mlp_kernel_fuse_residual_add
# compiler flags
self.logical_nc_config = logical_nc_config
self.cc_pipeline_tiling_factor = cc_pipeline_tiling_factor
self.target = target
# MoE specific
self.capacity_factor = float(capacity_factor) if capacity_factor is not None else None
self.glu_mlp = glu_mlp
    @property
    def ctx_batch_size(self) -> int:
        """Batch size used for context encoding (CTE) graphs.

        With continuous batching, prompts are encoded one request at a time.
        """
        return 1 if self.continuous_batching else self.batch_size

    @property
    def tkg_batch_size(self) -> int:
        """Batch size used for token generation (TKG) graphs."""
        return self.batch_size

    @property
    def world_size(self) -> int:
        """The total number of ranks in the distributed setup."""
        return self.tp_degree * self.pp_degree * self.ep_degree

    @property
    def weights_to_skip_layout_optimization(self) -> List[str]:
        """Weights to skip during layout optimization.

        Can be overridden by subclasses to specify weights that should not be optimized.
        """
        return []