optimum/neuron/models/inference/backend/modules/decoder/decoder_wrapper.py (242 lines of code) (raw):

# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from typing import List, Optional

import torch
import torch.nn.functional as F
from neuronx_distributed.trace.model_builder import BaseModelInstance
from torch_neuronx import BucketModelConfig
from transformers import PretrainedConfig

from ...config import NxDNeuronConfig
from ...model_wrapper import NxDModelWrapper
from ..autobucketing import (
    get_context_encoder_bk,
    get_generation_model_bk,
)
from ..generation.sampling import prepare_sampling_params


CONTEXT_ENCODING_MODEL_TAG = "context_encoding_model"
TOKEN_GENERATION_MODEL_TAG = "token_generation_model"
SPECULATION_MODEL_TAG = "speculation_model"


def get_bucket_model_config_from_tag(
    tag, config: PretrainedConfig, neuron_config: NxDNeuronConfig, buckets: List[int]
):
    """Build the torch_neuronx bucketing configuration for a model tag.

    Args:
        tag: one of CONTEXT_ENCODING_MODEL_TAG, TOKEN_GENERATION_MODEL_TAG or SPECULATION_MODEL_TAG.
        config: model config; its ``pad_token_id`` is passed to the context-encoding bucket kernel.
        neuron_config: provides ``padding_side`` for the bucket kernels.
        buckets: the bucket sizes the model is compiled for.

    Returns:
        A ``BucketModelConfig`` with one ``{"bucket_rank": i}`` kwargs entry per bucket, or ``None``
        when there is only one bucket (checked before the tag is validated).

    Raises:
        ValueError: if ``tag`` is not one of the supported tags.
    """
    bucket_degree = len(buckets)
    if bucket_degree == 1:
        return None

    pad_token = config.pad_token_id

    # NOTE: KV Cache preprocessing is done within the model and not the
    # shared buffer preprocessor due to lack of support of non-contiguous
    # slicing of nrt tensors via the NRT API.
    if tag == CONTEXT_ENCODING_MODEL_TAG:
        return BucketModelConfig(
            bucket_kernel=get_context_encoder_bk,
            bucket_kernel_constant_args=(
                torch.tensor(buckets),
                neuron_config.padding_side,
                pad_token,
            ),
            shared_state_buffer=None,
            func_kwargs=[{"bucket_rank": i} for i in range(bucket_degree)],
        )
    elif tag in (TOKEN_GENERATION_MODEL_TAG, SPECULATION_MODEL_TAG):
        return BucketModelConfig(
            bucket_kernel=get_generation_model_bk,
            bucket_kernel_constant_args=(
                torch.tensor(buckets),
                neuron_config.padding_side,
                0,
            ),
            shared_state_buffer=None,
            func_kwargs=[{"bucket_rank": i} for i in range(bucket_degree)],
        )
    else:
        raise ValueError(
            f"The supplied tag: {tag} is not supported for Bucketing. Only {CONTEXT_ENCODING_MODEL_TAG}, "
            f"{TOKEN_GENERATION_MODEL_TAG} and {SPECULATION_MODEL_TAG} are supported"
        )


class NxDDecoderWrapper(NxDModelWrapper):
    """Wrap a decoder model class for Neuron tracing and run the compiled model.

    ``forward`` converts int64 inputs to int32, pads sequences to the compiled length, and
    splits or pads the batch to the compiled ``neuron_config.batch_size``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        neuron_config: NxDNeuronConfig,
        buckets: List[int],
        bucket_n_active_tokens: bool,
        model_cls,
        tag="",
        priority_model_idx: Optional[int] = None,
        model_init_kwargs: Optional[dict] = None,
    ) -> None:
        """
        Args:
            config: model config. NOTE: ``pad_token_id`` is set to 0 in place when it is None.
            neuron_config: neuron config. NOTE: ``torch_dtype`` is set to float32 in place when unset.
            buckets: sequence-length buckets the model is compiled for.
            bucket_n_active_tokens: when True, the number of active tokens in the traced inputs equals
                the bucket size; otherwise ``neuron_config.n_active_tokens`` is used.
            model_cls: model class instantiated as ``model_cls(config, neuron_config, **model_init_kwargs)``.
            tag: model tag (e.g. CONTEXT_ENCODING_MODEL_TAG), forwarded to the base wrapper.
            priority_model_idx: forwarded to the base wrapper.
            model_init_kwargs: extra keyword arguments for ``model_cls`` (defaults to none).
        """
        super().__init__(tag, priority_model_idx)
        self.config = config
        self.neuron_config = neuron_config
        self.buckets = buckets
        self.bucket_n_active_tokens = bucket_n_active_tokens

        if not self.neuron_config.torch_dtype:
            self.neuron_config.torch_dtype = torch.float32

        if config.pad_token_id is None:
            config.pad_token_id = 0

        self.model_cls = model_cls
        self.model = None
        self.is_compiled = False
        self.serialize_base_path = None

        base_compile_work_dir = os.environ.get("BASE_COMPILE_WORK_DIR", "/tmp/nxd_model/")
        self.compiler_workdir = os.path.join(base_compile_work_dir, self.tag)
        # A None default avoids sharing one mutable dict between all instances.
        self.model_init_kwargs = model_init_kwargs if model_init_kwargs is not None else {}

        self.async_mode = self.neuron_config.async_mode

    def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
        """Instantiate ``model_cls`` and load ``state_dict`` into it.

        The model is built with the same ``model_init_kwargs`` used for the traced instance
        (see ``get_model_instance``) so that both have the same structure.
        """
        self.model = self.model_cls(self.config, self.neuron_config, **self.model_init_kwargs)
        self.model.load_state_dict(state_dict, strict=strict, assign=assign)

    def input_generator(
        self,
    ):
        """Return one tuple of zero-filled example inputs per bucket, used for tracing.

        Each tuple is ``(input_ids, attention_mask, position_ids, seq_ids, sampling_params)``.
        ``attention_mask`` has the bucket length; ``input_ids`` and ``position_ids`` have either
        the bucket length or ``neuron_config.n_active_tokens`` (see ``bucket_n_active_tokens``).
        """
        batch_size = self.neuron_config.batch_size
        # Get the count of sampling params currently supported (same for every bucket).
        sampling_params_len = prepare_sampling_params(1).shape[1]

        inputs = []
        for bucket in self.buckets:
            n_active_tokens = bucket if self.bucket_n_active_tokens else self.neuron_config.n_active_tokens

            input_ids = torch.zeros((batch_size, n_active_tokens), dtype=torch.int32)
            attention_mask = torch.zeros((batch_size, bucket), dtype=torch.int32)
            position_ids = torch.zeros((batch_size, n_active_tokens), dtype=torch.int32)
            seq_ids = torch.zeros((batch_size), dtype=torch.int32)
            sampling_params = torch.zeros((batch_size, sampling_params_len), dtype=torch.float32)

            inputs.append((input_ids, attention_mask, position_ids, seq_ids, sampling_params))

        return inputs

    def get_model_instance(self):
        """Return a ``DecoderModelInstance`` that builds ``model_cls`` with ``model_init_kwargs``."""
        return DecoderModelInstance(
            model_cls=self.model_cls,
            config=self.config,
            neuron_config=self.neuron_config,
            buckets=self.buckets,
            **self.model_init_kwargs,
        )

    def get_bucket_config(self):
        """Return the bucketing configuration for this wrapper's tag (None for a single bucket)."""
        return get_bucket_model_config_from_tag(self.tag, self.config, self.neuron_config, self.buckets)

    def _forward_with_pad(self, input_ids, attention_mask, position_ids, seq_ids, sampling_params):
        """Pad a partial batch up to the compiled batch size, run it, and strip the padded rows."""
        batch_size = self.neuron_config.batch_size

        # pad the inputs up to the compiled batch size in the end
        def pad_helper(tensor, pad_type="zeros"):
            VALID_PAD_TYPES = {"zeros", "ones", "repeat_first_batchline"}
            if pad_type not in VALID_PAD_TYPES:
                raise ValueError(f"Found {pad_type=}, but valid pad types are {VALID_PAD_TYPES}")
            if tensor is None or tensor.shape[0] == batch_size:
                return tensor

            if pad_type == "repeat_first_batchline":
                # pad with first batch line values instead of zeros, to reduce chances of NaN.
                # Repeat only along dim 0 so tensors of any rank are supported.
                padded_tensor = tensor[:1].repeat(batch_size, *([1] * (tensor.dim() - 1)))
            else:
                padded_shape = [batch_size, *tensor.shape[1:]]
                fill_value = 0 if pad_type == "zeros" else 1
                padded_tensor = torch.full(padded_shape, fill_value=fill_value, dtype=tensor.dtype)
            padded_tensor[: tensor.shape[0]] = tensor
            return padded_tensor

        padded_args = [
            pad_helper(arg, pad_type="repeat_first_batchline") for arg in (input_ids, attention_mask, position_ids)
        ]

        # need to handle seq_ids separately, when compiled batch is 4, if we pad seq_ids from [0,2,1] to [0,2,1,
        # 0]. then the kv cache of padded input could be written into the first cache line, so we need to pad as [0,
        # 2, 1, 3] instead. Only as many unused ids as needed to reach the compiled batch size are taken, so the
        # padded seq_ids length matches the other padded inputs even when max_batch_size > batch_size.
        seq_ids_list = seq_ids.tolist()
        used_seq_ids = set(seq_ids_list)
        unused_seq_ids = [x for x in range(self.neuron_config.max_batch_size) if x not in used_seq_ids]
        n_missing = batch_size - len(seq_ids_list)
        padded_seq_ids = torch.tensor(seq_ids_list + unused_seq_ids[:n_missing], dtype=seq_ids.dtype)
        padded_args.append(padded_seq_ids)

        # pad sampling params by repeating first batchline
        padded_args.append(pad_helper(sampling_params, pad_type="repeat_first_batchline"))

        outputs = self._forward(*padded_args)

        # note that we don't do index select here as it should already be handled, simply sliced out padding here
        return outputs[: seq_ids.shape[0]]

    def _forward(self, input_ids, attention_mask, position_ids, seq_ids, sampling_params):
        """Run the model on a batch of exactly the compiled batch size.

        For token generation with continuous batching, rows are sorted by ``seq_ids`` before the
        call and the outputs are put back into the caller's row order afterwards.
        """
        sorting_index = None
        if self.tag == TOKEN_GENERATION_MODEL_TAG and self.neuron_config.continuous_batching:
            # if continuous batching is enabled, we need to ensure that the inputs are at the expected positions
            expected_seq_ids = torch.arange(seq_ids.shape[0], dtype=seq_ids.dtype)
            if not torch.equal(seq_ids, expected_seq_ids):
                sorting_index = torch.argsort(seq_ids)
                seq_ids, input_ids, attention_mask, position_ids, sampling_params = (
                    torch.index_select(t, 0, sorting_index)
                    for t in (seq_ids, input_ids, attention_mask, position_ids, sampling_params)
                )

        outputs = self.model(input_ids, attention_mask, position_ids, seq_ids, sampling_params)

        if sorting_index is not None:
            # Sorted row j came from original row sorting_index[j]; the inverse permutation
            # argsort(sorting_index) maps the outputs back to the caller's order. This equals
            # indexing by the original seq_ids when they are a permutation of 0..n-1, and stays
            # correct (and in bounds) when they are not.
            outputs = torch.index_select(outputs, 0, torch.argsort(sorting_index))
        return outputs

    def convert_int64_to_int32(self, *args):
        """
        Convert int64 args to int32 to match compiled input types.
        Neuron compiler handles int32 better than int64. Context: P165494809
        """
        return [t.to(torch.int32) if t.dtype == torch.int64 else t for t in args]

    def pad_to_max_compiled_seq(self, *args):
        """Right-pad sequence inputs to the compiled sequence length.

        Context encoding: ``input_ids``, ``attention_mask`` and ``position_ids`` (the first three
        args) are padded to ``max_context_length`` with ``pad_token_id``, 0 and 1 respectively.
        Other tags: only ``attention_mask`` (second arg) is padded to ``sequence_length`` with 0.

        Raises:
            ValueError: if an input is longer than the compiled length. ``F.pad`` with a negative
                pad would otherwise silently crop it.
        """
        if self.tag == CONTEXT_ENCODING_MODEL_TAG:
            max_length = self.neuron_config.max_context_length
            tensor_pad_vals = (self.config.pad_token_id, 0, 1)
            padded_args = []
            for arg, pad_val in zip(args[:3], tensor_pad_vals):
                pad_len = max_length - arg.shape[1]
                if pad_len < 0:
                    raise ValueError(
                        f"Input sequence length {arg.shape[1]} exceeds the compiled maximum context length {max_length}."
                    )
                padded_args.append(F.pad(arg, (0, pad_len), "constant", pad_val))
            return (*padded_args, *args[3:])

        input_ids, attention_mask, *rest_of_args = args
        max_length = self.neuron_config.sequence_length
        pad_len = max_length - attention_mask.shape[1]
        if pad_len < 0:
            raise ValueError(
                f"Attention mask length {attention_mask.shape[1]} exceeds the compiled sequence length {max_length}."
            )
        padded_attention_mask = F.pad(attention_mask, (0, pad_len), "constant", 0)
        return (input_ids, padded_attention_mask, *rest_of_args)

    def _get_async_output(self, ranked_async_tensor):
        """Copy the first output of every rank to CPU and return rank 0's copy."""
        # .cpu() is called on every rank (not just rank 0); forward relies on this step to
        # "block on all requests".
        outputs = [[async_tensor[0].cpu()] for async_tensor in ranked_async_tensor]
        return outputs[0][0]

    def forward(self, input_ids, attention_mask, position_ids, seq_ids, sampling_params):
        """Run the compiled model on a batch of any size up to ``max_batch_size``.

        A batch equal to the compiled batch size runs directly. Otherwise it is split into
        compiled-size chunks, the last chunk is padded, and the outputs are concatenated on dim 0.

        Raises:
            ValueError: if the batch is larger than ``max_batch_size``.
        """
        input_ids, attention_mask, position_ids, seq_ids = self.convert_int64_to_int32(
            input_ids, attention_mask, position_ids, seq_ids
        )
        input_ids, attention_mask, position_ids, seq_ids = self.pad_to_max_compiled_seq(
            input_ids, attention_mask, position_ids, seq_ids
        )

        input_batch_size = seq_ids.shape[0]
        compiled_batch_size = self.neuron_config.batch_size

        if input_batch_size > self.neuron_config.max_batch_size:
            raise ValueError(
                f"Input batch size {input_batch_size} exceeds the maximum batch size {self.neuron_config.max_batch_size}."
            )
        elif input_batch_size == compiled_batch_size:
            return self._forward(input_ids, attention_mask, position_ids, seq_ids, sampling_params)

        logging.debug(
            "get input_batch_size as %s but compiled batch_size as %s", input_batch_size, compiled_batch_size
        )

        args = (input_ids, attention_mask, position_ids, seq_ids, sampling_params)
        output_logits = []
        for start in range(0, input_batch_size, compiled_batch_size):
            end = start + compiled_batch_size
            if end <= input_batch_size:
                # we only process part of the input to run
                logging.debug("running forward on batch %s:%s", start, end)
                outputs = self._forward(*[arg[start:end] for arg in args])
            else:
                # we need to pad the input to run
                logging.debug(
                    "running forward on batch %s:%s, padded up to %s", start, input_batch_size, compiled_batch_size
                )
                outputs = self._forward_with_pad(*[arg[start:input_batch_size] for arg in args])

            output_logits.append(outputs)

        if self.async_mode:
            # block on all requests here, since this is output manipulation
            output_logits = [self._get_async_output(ranked_logits) for ranked_logits in output_logits]

        return torch.cat(output_logits, dim=0)


class DecoderModelInstance(BaseModelInstance):
    """Model instance that builds the decoder module and its KV-cache input/output aliases for tracing."""

    def __init__(
        self, model_cls, config: PretrainedConfig, neuron_config: NxDNeuronConfig, buckets: List[int], **kwargs
    ):
        self.model_cls = model_cls
        self.module = None
        self.input_output_aliases = None
        self.config = config
        self.neuron_config = neuron_config
        self.buckets = buckets
        # Extra keyword arguments forwarded to model_cls in load_module (always a dict).
        self.kwargs = kwargs

    def initialize_process_group(self, world_size):
        """Delegate process-group initialization to ``model_cls``."""
        self.model_cls.initialize_process_group(world_size)

    def load_module(self):
        """Instantiate the model in eval mode and cast it to ``neuron_config.torch_dtype``.

        Only floating-point tensors are cast. float8 tensors (e4m3fn / e5m2) keep their dtype.
        """
        float_model = self.model_cls(self.config, self.neuron_config, **self.kwargs)
        float_model.eval()

        if self.neuron_config.torch_dtype != torch.float32:
            float_model._apply(
                lambda t: t.to(self.neuron_config.torch_dtype)
                if t.is_floating_point() and t.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]
                else t
            )
        self.module = float_model

    def get(self, bucket_rank, **kwargs):
        """Return the module configured for ``bucket_rank`` and its input/output alias map.

        Each KV-cache tensor i is aliased to traced output index ``num_output_from_trace + i``,
        i.e. the cache outputs come after the 1 (or 2, with ``output_logits``) leading outputs.
        """
        if bucket_rank is not None:
            self.module.n_positions = self.buckets[bucket_rank]

        # Currently we have to init an input_output_aliases map for
        # each buckets, otherwise it will fail the aliasing setup when
        # generating HLO
        self.input_output_aliases = {}
        num_output_from_trace = 2 if self.neuron_config.output_logits else 1
        # TODO: This else block is a short-term fix for Llava/ViT models to use DecoderModelInstance.
        #       Long-term, these models should use a different implementation of BaseModelInstance.
        if self.module.kv_mgr is not None:
            past_key_values = self.module.kv_mgr.past_key_values
        else:
            past_key_values = self.module.past_key_values
        for i, past_key_value in enumerate(past_key_values):
            self.input_output_aliases[past_key_value] = num_output_from_trace + i
        return self.module, self.input_output_aliases