optimum/neuron/models/inference/backend/modules/decoder/decoder_wrapper.py
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from typing import List, Optional

import torch
import torch.nn.functional as F
from neuronx_distributed.trace.model_builder import BaseModelInstance
from torch_neuronx import BucketModelConfig
from transformers import PretrainedConfig

from ...config import NxDNeuronConfig
from ...model_wrapper import NxDModelWrapper
from ..autobucketing import (
get_context_encoder_bk,
get_generation_model_bk,
)
from ..generation.sampling import prepare_sampling_params


CONTEXT_ENCODING_MODEL_TAG = "context_encoding_model"
TOKEN_GENERATION_MODEL_TAG = "token_generation_model"
SPECULATION_MODEL_TAG = "speculation_model"


def get_bucket_model_config_from_tag(
tag, config: PretrainedConfig, neuron_config: NxDNeuronConfig, buckets: List[int]
):
bucket_degree = len(buckets)
if bucket_degree == 1:
return None
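    # Illustrative example (hypothetical values): buckets == [1024, 2048] gives bucket_degree == 2,
    # so a BucketModelConfig with func_kwargs == [{"bucket_rank": 0}, {"bucket_rank": 1}] is built below.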
pad_token = config.pad_token_id
    # NOTE: KV cache preprocessing is done within the model rather than in the
    # shared buffer preprocessor because the NRT API does not support
    # non-contiguous slicing of NRT tensors.
if tag == CONTEXT_ENCODING_MODEL_TAG:
return BucketModelConfig(
bucket_kernel=get_context_encoder_bk,
bucket_kernel_constant_args=(
torch.tensor(buckets),
neuron_config.padding_side,
pad_token,
),
shared_state_buffer=None,
func_kwargs=[{"bucket_rank": i} for i in range(bucket_degree)],
)
elif tag == TOKEN_GENERATION_MODEL_TAG or tag == SPECULATION_MODEL_TAG:
return BucketModelConfig(
bucket_kernel=get_generation_model_bk,
bucket_kernel_constant_args=(
torch.tensor(buckets),
neuron_config.padding_side,
0,
),
shared_state_buffer=None,
func_kwargs=[{"bucket_rank": i} for i in range(bucket_degree)],
)
else:
raise ValueError(
f"The supplied tag: {tag} is not supported for Bucketing. Only {CONTEXT_ENCODING_MODEL_TAG} and {TOKEN_GENERATION_MODEL_TAG} are supported"
)


class NxDDecoderWrapper(NxDModelWrapper):
def __init__(
self,
config: PretrainedConfig,
neuron_config: NxDNeuronConfig,
buckets: List[int],
bucket_n_active_tokens: bool,
model_cls,
tag="",
        priority_model_idx: Optional[int] = None,
        model_init_kwargs: Optional[dict] = None,
) -> None:
super().__init__(tag, priority_model_idx)
self.config = config
self.neuron_config = neuron_config
self.buckets = buckets
self.bucket_n_active_tokens = bucket_n_active_tokens
if not self.neuron_config.torch_dtype:
self.neuron_config.torch_dtype = torch.float32
if config.pad_token_id is None:
config.pad_token_id = 0
self.model_cls = model_cls
self.model = None
self.is_compiled = False
self.serialize_base_path = None
base_compile_work_dir = os.environ.get("BASE_COMPILE_WORK_DIR", "/tmp/nxd_model/")
self.compiler_workdir = os.path.join(base_compile_work_dir, self.tag)
        self.model_init_kwargs = model_init_kwargs if model_init_kwargs is not None else {}
self.async_mode = self.neuron_config.async_mode

    def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
self.model = self.model_cls(self.config, self.neuron_config)
self.model.load_state_dict(state_dict, strict=strict, assign=assign)

    def input_generator(
self,
):
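        # Illustrative example (hypothetical values): with batch_size == 2 and buckets == [128, 512],
        # two input tuples are generated; input_ids has shape (2, 128) and (2, 512) when
        # bucket_n_active_tokens is True (typically the context encoding model), while
        # attention_mask always spans the full bucket.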
inputs = []
for bucket in self.buckets:
n_active_tokens = bucket if self.bucket_n_active_tokens else self.neuron_config.n_active_tokens
input_ids = torch.zeros((self.neuron_config.batch_size, n_active_tokens), dtype=torch.int32)
attention_mask = torch.zeros((self.neuron_config.batch_size, bucket), dtype=torch.int32)
position_ids = torch.zeros((self.neuron_config.batch_size, n_active_tokens), dtype=torch.int32)
seq_ids = torch.zeros((self.neuron_config.batch_size), dtype=torch.int32)
# Get the count of sampling params currently supported.
sampling_params_len = prepare_sampling_params(1).shape[1]
sampling_params = torch.zeros((self.neuron_config.batch_size, sampling_params_len), dtype=torch.float32)
inputs.append((input_ids, attention_mask, position_ids, seq_ids, sampling_params))
return inputs

    def get_model_instance(self):
return DecoderModelInstance(
model_cls=self.model_cls,
config=self.config,
neuron_config=self.neuron_config,
buckets=self.buckets,
**self.model_init_kwargs,
)

    def get_bucket_config(self):
return get_bucket_model_config_from_tag(self.tag, self.config, self.neuron_config, self.buckets)

    def _forward_with_pad(self, input_ids, attention_mask, position_ids, seq_ids, sampling_params):
        # pad the inputs along the batch dimension up to the compiled batch size
def pad_helper(tensor, pad_type="zeros"):
VALID_PAD_TYPES = {"zeros", "ones", "repeat_first_batchline"}
assert pad_type in VALID_PAD_TYPES, f"Found {pad_type=}, but valid pad types are {VALID_PAD_TYPES}"
if tensor is None or tensor.shape[0] == self.neuron_config.batch_size:
return tensor
padded_shape = list(tensor.shape)
padded_shape[0] = self.neuron_config.batch_size
if pad_type == "repeat_first_batchline":
# pad with first batch line values instead of zeros, to reduce chances of NaN
padded_tensor = tensor[0].unsqueeze(0).repeat(padded_shape[0], 1).to(tensor.dtype)
else:
fill_value = 0 if pad_type == "zeros" else 1
padded_tensor = torch.full(padded_shape, fill_value=fill_value, dtype=tensor.dtype)
padded_tensor[: tensor.shape[0]] = tensor
return padded_tensor
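        # Illustrative example (hypothetical values): with a compiled batch_size of 4, a (2, 5) tensor
        # padded with "repeat_first_batchline" becomes (4, 5), where rows 2-3 are copies of row 0 and
        # rows 0-1 keep the original values.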

        padded_args = []
for arg in (input_ids, attention_mask, position_ids):
padded_args.append(pad_helper(arg, pad_type="repeat_first_batchline"))
        # seq_ids need special handling: with a compiled batch size of 4, padding seq_ids from [0, 2, 1]
        # to [0, 2, 1, 0] would let the padded row overwrite the KV cache of the first cache line, so we
        # pad with the unused ids instead, e.g. [0, 2, 1, 3]
seq_ids_list = seq_ids.tolist()
padded_seq_ids = torch.tensor(
seq_ids_list + [x for x in range(self.neuron_config.max_batch_size) if x not in seq_ids_list],
dtype=seq_ids.dtype,
)
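        # e.g. seq_ids_list == [0, 2, 1] with max_batch_size == 4:
        #   [x for x in range(4) if x not in [0, 2, 1]] == [3]  ->  padded_seq_ids == [0, 2, 1, 3]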
padded_args.append(padded_seq_ids)
# pad sampling params by repeating first batchline
padded_sampling_params = pad_helper(sampling_params, pad_type="repeat_first_batchline")
padded_args.append(padded_sampling_params)
outputs = self._forward(*padded_args)
        # no index_select is needed here: the ordering is already handled in _forward, we simply slice out the padding
logits = outputs
return logits[: seq_ids.shape[0]]

    def _forward(self, input_ids, attention_mask, position_ids, seq_ids, sampling_params):
        """Run the traced model, reordering the batch by seq_ids when continuous batching is enabled."""
needs_reordering = False
if self.tag == TOKEN_GENERATION_MODEL_TAG and self.neuron_config.continuous_batching:
# if continuous batching is enabled, we need to ensure that the inputs are at the expected positions
orig_seq_ids = seq_ids.clone()
needs_reordering = not torch.equal(seq_ids, torch.arange(seq_ids.shape[0]))
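            # e.g. (hypothetical): seq_ids == [2, 0, 1] gives sorting_index == [1, 2, 0]; after the
            # index_select below the rows line up with KV cache lines 0, 1, 2, and the outputs are
            # scattered back to the original order further down.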
if needs_reordering:
sorting_index = torch.argsort(seq_ids)
seq_ids = torch.index_select(seq_ids, 0, sorting_index)
input_ids = torch.index_select(input_ids, 0, sorting_index)
attention_mask = torch.index_select(attention_mask, 0, sorting_index)
position_ids = torch.index_select(position_ids, 0, sorting_index)
sampling_params = torch.index_select(sampling_params, 0, sorting_index)
outputs = self.model(input_ids, attention_mask, position_ids, seq_ids, sampling_params)
if needs_reordering:
# if we reordered the inputs, we need to reorder the outputs as well
outputs = torch.index_select(outputs, 0, orig_seq_ids)
return outputs

    def convert_int64_to_int32(self, *args):
"""
Convert int64 args to int32 to match compiled input types.
Neuron compiler handles int32 better than int64. Context: P165494809
"""
return [t.to(torch.int32) if t.dtype == torch.int64 else t for t in args]

    def pad_to_max_compiled_seq(self, *args):
if self.tag == CONTEXT_ENCODING_MODEL_TAG:
to_pad = args[:3]
pad_lengths = [self.neuron_config.max_context_length - arg.shape[1] for arg in to_pad]
tensor_pad_vals = [self.config.pad_token_id, 0, 1]
padded_args = [
F.pad(arg, (0, pad_len), "constant", pad_val)
for arg, pad_val, pad_len in zip(to_pad, tensor_pad_vals, pad_lengths)
]
args = (*padded_args, *args[3:])
else:
input_ids, attention_mask, *rest_of_args = args
pad_len = self.neuron_config.sequence_length - attention_mask.shape[1]
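            # e.g. (hypothetical): with sequence_length == 2048 and a (2, 100) attention_mask,
            # pad_len == 1948, so 1948 zero columns are appended below to match the compiled KV cache length.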
padded_attention_mask = F.pad(attention_mask, (0, pad_len), "constant", 0)
args = (input_ids, padded_attention_mask, *rest_of_args)
return args

    def _get_async_output(self, ranked_async_tensor):
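        # Blocking step: moving the rank-0 async tensor to CPU waits for the Neuron execution to complete.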
outputs = [[async_tensor[0].cpu()] for async_tensor in ranked_async_tensor]
return outputs[0][0]

    def forward(self, input_ids, attention_mask, position_ids, seq_ids, sampling_params):
        """Convert dtypes, pad to the compiled shapes, then run the traced model, splitting or padding the batch as needed."""
input_ids, attention_mask, position_ids, seq_ids = self.convert_int64_to_int32(
input_ids, attention_mask, position_ids, seq_ids
)
input_ids, attention_mask, position_ids, seq_ids = self.pad_to_max_compiled_seq(
input_ids, attention_mask, position_ids, seq_ids
)
input_batch_size = seq_ids.shape[0]
if input_batch_size > self.neuron_config.max_batch_size:
raise ValueError(
f"Input batch size {input_batch_size} exceeds the maximum batch size {self.neuron_config.max_batch_size}."
)
elif input_batch_size == self.neuron_config.batch_size:
return self._forward(input_ids, attention_mask, position_ids, seq_ids, sampling_params)
cur_batch = 0
output_logits = []
        logging.debug(
            f"input batch size {input_batch_size} does not match the compiled batch size {self.neuron_config.batch_size}"
        )
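        # Illustrative example (hypothetical values): input_batch_size == 6 with a compiled batch_size
        # of 4 runs one full chunk [0:4] through _forward and the remainder [4:6] through
        # _forward_with_pad, then concatenates the logits.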
args = (input_ids, attention_mask, position_ids, seq_ids, sampling_params)
while cur_batch < input_batch_size:
if cur_batch + self.neuron_config.batch_size <= input_batch_size:
                # a full compiled batch worth of inputs is available, so run it directly
                logging.debug(f"running forward on batch {cur_batch}:{cur_batch + self.neuron_config.batch_size}")
outputs = self._forward(*[arg[cur_batch : cur_batch + self.neuron_config.batch_size] for arg in args])
else:
                # the remaining inputs do not fill a compiled batch, so pad them before running
logging.debug(
f"running forward on batch {cur_batch}:{input_batch_size}, padded up to {self.neuron_config.batch_size}"
)
outputs = self._forward_with_pad(*[arg[cur_batch:input_batch_size] for arg in args])
output_logits.append(outputs)
cur_batch += self.neuron_config.batch_size
if self.async_mode:
# block on all requests here, since this is output manipulation
output_logits = [self._get_async_output(ranked_logits) for ranked_logits in output_logits]
return torch.cat(output_logits, dim=0)


class DecoderModelInstance(BaseModelInstance):
def __init__(
self, model_cls, config: PretrainedConfig, neuron_config: NxDNeuronConfig, buckets: List[int], **kwargs
):
self.model_cls = model_cls
self.module = None
self.input_output_aliases = None
self.config = config
self.neuron_config = neuron_config
self.buckets = buckets
self.kwargs = kwargs if kwargs is not None else {}

    def initialize_process_group(self, world_size):
self.model_cls.initialize_process_group(world_size)

    def load_module(self):
float_model = self.model_cls(self.config, self.neuron_config, **self.kwargs)
float_model.eval()
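        # Illustrative note (hypothetical dtype): with torch_dtype == torch.bfloat16, the cast below
        # converts floating-point parameters to bf16 while leaving float8 (e4m3fn / e5m2) weights untouched.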
if self.neuron_config.torch_dtype != torch.float32:
float_model._apply(
lambda t: t.to(self.neuron_config.torch_dtype)
if t.is_floating_point() and t.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]
else t
)
self.module = float_model

    def get(self, bucket_rank, **kwargs):
if bucket_rank is not None:
self.module.n_positions = self.buckets[bucket_rank]
        # Currently we have to initialize an input_output_aliases map for each
        # bucket, otherwise the aliasing setup fails when generating the HLO.
self.input_output_aliases = {}
num_output_from_trace = 1 if not self.neuron_config.output_logits else 2
# TODO: This else block is a short-term fix for Llava/ViT models to use DecoderModelInstance.
# Long-term, these models should use a different implementation of BaseModelInstance.
if self.module.kv_mgr is not None:
past_key_values = self.module.kv_mgr.past_key_values
else:
past_key_values = self.module.past_key_values
for i in range(len(past_key_values)):
self.input_output_aliases[past_key_values[i]] = num_output_from_trace + i
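        # e.g. (hypothetical): a 32-layer model with separate K and V caches has 64 cache tensors,
        # aliased to trace outputs 1..64 (or 2..65 when output_logits is enabled).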
return self.module, self.input_output_aliases