# optimum/neuron/models/inference/backend/modules/generation/generation_utils.py
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import Any, Dict, Optional, Union
import torch
from transformers import GenerationConfig
from transformers.generation import GenerationMixin, SampleDecoderOnlyOutput
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.modeling_outputs import ModelOutput
from .sampling import (
Sampler,
prepare_sampling_params,
)


class NxDGenerationMixin(GenerationMixin):
    """A generation mixin that can be used to extend NxDPreTrainedModel-based classes."""
    # These attributes are expected by the GenerationMixin code
main_input_name = "input_ids"
_is_stateful = False
_supports_cache_class = False

    def __init__(self, config, *args, **kwargs):
super().__init__(config, *args, **kwargs)
assert hasattr(self, "neuron_config") # Must be set by the super class
# Initialize default generation config
self.generation_config = GenerationConfig.from_model_config(config)
self.sampler = None

    def can_generate(self):
# Still required in transformers <= 4.50
return True

    def generate(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
generation_config: Optional["GenerationConfig"] = None,
**kwargs,
):
# Sanity check
batch_size, sequence_length = input_ids.shape
if sequence_length > self.neuron_config.sequence_length:
raise ValueError(
f"The input sequence length ({sequence_length}) exceeds the model static sequence length ({self.neuron_config.sequence_length})"
)
if batch_size > self.neuron_config.batch_size:
raise ValueError(
f"The specified batch_size ({batch_size}) exceeds the model static batch size ({self.neuron_config.batch_size})"
)
# Keep generation stateless.
self.reset()
return super().generate(
input_ids, attention_mask=attention_mask, generation_config=generation_config, **kwargs
)

    # TODO: Remove _sample and define a separate flow for on-device sampling that doesn't use HF.
def _sample(
self,
input_ids: torch.LongTensor,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
logits_warper: Optional[LogitsProcessorList] = None,
**model_kwargs,
) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
r"""
        We override the GenerationMixin sample function (_sample for transformers>=4.39.0) to add support for right-side padding.
"""
# init values
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
pad_token_id = generation_config._pad_token_tensor
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
do_sample = generation_config.do_sample
batch_size = model_kwargs["attention_mask"].shape[0]
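        # Greedy decoding (do_sample=False) is encoded as top_k=1 / top_p=1.0 /
        # temperature=1.0; prepare_sampling_params bundles these per-sequence settings
        # so they can be passed to the model alongside the regular inputs.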
top_k = generation_config.top_k if do_sample else 1
top_p = generation_config.top_p if do_sample else 1.0
temperature = generation_config.temperature if do_sample else 1.0
sampling_params = prepare_sampling_params(
batch_size=batch_size,
top_k=top_k,
top_p=top_p,
temperature=temperature,
)
model_kwargs["sampling_params"] = sampling_params
# init scores / logits tuples
scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
# keep track of which sequences are already finished
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
this_peer_finished = False
is_for_token_generation = False
# auto-regressive generation
while not this_peer_finished:
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(
input_ids, is_decode=is_for_token_generation, **model_kwargs
)
model_kwargs["attention_mask"] = model_inputs.get("attention_mask")
# forward pass to get next token
outputs = self.forward(**model_inputs, return_dict=True)
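            # With on-device sampling the compiled model directly returns the selected
            # tokens; otherwise the next token is sampled on CPU from the returned logits.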
if self.neuron_config.on_device_sampling:
next_tokens = outputs.tokens
else:
next_token_logits = outputs.logits[:, -1, :].clone()
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
if do_sample:
next_token_scores = logits_warper(input_ids, next_token_scores)
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores,)
if output_logits:
raw_logits += (next_token_logits,)
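                # Lazily instantiate a CPU-side sampler on first use. It is created with
                # do_sample=True; greedy decoding is expected to come from the per-sequence
                # sampling_params (top_k=1) built earlier.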
if self.sampler is None:
self.sampler = Sampler(self.neuron_config, do_sample=True, on_cpu=True)
next_tokens = self.sampler(next_token_scores, sampling_params)
# finished sentences should have their next token be a padding token
if has_eos_stopping_criteria:
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
is_for_token_generation = True
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_for_token_generation=is_for_token_generation
)
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, None)
this_peer_finished = unfinished_sequences.max() == 0
if return_dict_in_generate:
return SampleDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
)
else:
return input_ids

    def prepare_inputs_for_generation(
self,
input_ids,
is_decode,
attention_mask=None,
sampling_params=None,
seq_ids=None,
**kwargs,
):
if is_decode:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
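            # During decode only the last position is fed to the model, so keep the
            # largest (i.e. current) position per sequence.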
if is_decode:
position_ids = torch.amax(position_ids, 1, keepdim=True)
if seq_ids is None:
seq_ids = torch.arange(input_ids.shape[0])
model_inputs = {
"input_ids": input_ids,
"position_ids": position_ids,
"attention_mask": attention_mask,
"sampling_params": sampling_params,
"seq_ids": seq_ids,
}
# WARNING: This is needed for propagating additional kwargs to the neuron model
additional_kwargs = self.get_required_kwargs()
for arg in additional_kwargs:
model_inputs.update({arg: kwargs.get(arg, None)})
return model_inputs

    def prepare_inputs_for_prefill(
self,
input_ids,
attention_mask=None,
sampling_params=None,
**kwargs,
):
return self.prepare_inputs_for_generation(
input_ids,
is_decode=False,
attention_mask=attention_mask,
sampling_params=sampling_params,
**kwargs,
)

    def prepare_inputs_for_decode(
self,
input_ids,
attention_mask=None,
sampling_params=None,
**kwargs,
):
return self.prepare_inputs_for_generation(
input_ids,
is_decode=True,
attention_mask=attention_mask,
sampling_params=sampling_params,
**kwargs,
)

    # We override this function because we want to change the way attention_mask
# is updated each iteration.
def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_for_token_generation: bool,
) -> Dict[str, Any]:
if getattr(outputs, "state", None) is not None:
model_kwargs["state"] = outputs.state
# update token_type_ids with last value
if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
# update attention mask
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
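            # After each generated token a new valid (ones) column is added to the mask.
            # With left padding it is appended on the right and the oldest column is
            # dropped, keeping the mask length constant; otherwise (right padding) it is
            # prepended and the mask grows by one.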
if is_for_token_generation:
if self.neuron_config.padding_side == "left":
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
dim=-1,
)
attention_mask = attention_mask[:, 1:]
else:
attention_mask = torch.cat(
[attention_mask.new_ones((attention_mask.shape[0], 1)), attention_mask],
dim=-1,
)
model_kwargs["attention_mask"] = attention_mask
return model_kwargs

    def _assisted_decoding(
self,
input_ids: torch.LongTensor,
candidate_generator: "CandidateGenerator", # noqa
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
**model_kwargs,
):
pad_token_id = generation_config.pad_token_id
eos_token_id = generation_config.eos_token_id
assistant_model = candidate_generator.assistant_model
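        # The speculation loop below reads the assistant model's logits on the host,
        # so the assistant cannot use on-device sampling.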
if assistant_model.neuron_config.on_device_sampling:
raise ValueError("Assistant model must not use on-device sampling")
# Init values
if eos_token_id is not None and pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        # Prepare the assistant model's input kwargs
assistant_kwargs = copy.deepcopy(model_kwargs)
# Other auxiliary variables
max_len = stopping_criteria[0].max_length
cur_len = input_ids.shape[-1]
spec_len = self.neuron_config.speculation_length
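        # Overall flow: the target model first prefills the prompt, then the loop below
        # alternates between up to spec_len greedy decode steps on the draft (assistant)
        # model and a single verification pass on the target model over the drafted tokens.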
# Run the target model once and get the first generated token
model_inputs = self.prepare_inputs_for_prefill(input_ids, **model_kwargs)
outputs = self.forward(**model_inputs)
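        # curr_pos is the index of the last real prompt token (where the position id is largest)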
curr_pos = model_inputs["position_ids"][0].argmax(dim=-1)
new_token = outputs.logits[:, 0].argmax(dim=-1, keepdim=True)
# Prepare the input ids and attention mask for the draft model
candidate_input_ids = input_ids
        # These are the ids that will eventually be returned; append the first generated token to them
returned_ids = torch.cat((input_ids[:, : curr_pos + 1], new_token), dim=1)
# Speculation loop
while True:
            # 1. Token generation using the draft model
is_for_token_generation = assistant_model.kv_cache_populated
for _ in range(spec_len):
# 1.1 Prepare assistant model inputs
assistant_inputs = assistant_model.prepare_inputs_for_generation(
candidate_input_ids,
is_decode=is_for_token_generation,
**assistant_kwargs,
)
# 1.2 Use the assistant model to obtain the next candidate logits
assistant_model_outputs = assistant_model.forward(**assistant_inputs)
assistant_new_token = assistant_model_outputs.logits[:, 0, :].argmax(dim=-1)
# 1.3 Update inputs and args for next iteration
candidate_input_ids = torch.cat((candidate_input_ids, assistant_new_token[:, None]), dim=-1)
assistant_kwargs = assistant_model._update_model_kwargs_for_generation(
assistant_model_outputs,
assistant_kwargs,
is_for_token_generation,
)
# 1.4 Stop assistant generation on EOS
if eos_token_id_tensor is not None:
last_assistant_token_is_eos = assistant_new_token.tile(eos_token_id_tensor.shape[0], 1)
last_assistant_token_is_eos = (
~last_assistant_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool()
)
if last_assistant_token_is_eos:
break
else:
last_assistant_token_is_eos = False
            # 2. Validation of the draft model output using the original model
# The length could be shorter if the draft loop ends earlier
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
# 2.1 Prepare the input arguments
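            # The verification input is the last token produced by the target model
            # followed by the drafted tokens (minus the final one), so the target model
            # can score every drafted position.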
input_ids = torch.cat((new_token, candidate_input_ids[:, -candidate_length:-1]), dim=-1)
attention_mask = model_inputs["attention_mask"]
pos = curr_pos + 1
position_ids = torch.arange(pos, pos + spec_len).expand(1, spec_len)
            # Pad input_ids to the static speculation length if needed
if input_ids.shape[-1] < spec_len:
input_ids = torch.cat(
(input_ids, torch.full((1, spec_len - input_ids.shape[-1]), pad_token_id)),
dim=-1,
)
# 2.2. Run a forward pass on the candidate sequence
outputs = self.forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
)
# 2.3. Process the new logits
new_tokens = outputs.logits.argmax(dim=-1)
selected_tokens = outputs.logits[:, : candidate_length - 1].argmax(dim=-1)
# 3. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
# the assistant forecasted tokens until the first mismatch, or until the max length is reached.
candidate_new_tokens = candidate_input_ids[:, -candidate_length:-1]
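            # n_matches counts the leading draft tokens whose target-model argmax agrees
            # with them: the running mismatch count stays below 1 only up to the first
            # disagreement.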
n_matches = ((~(candidate_new_tokens == selected_tokens)).cumsum(dim=-1) < 1).sum()
# 4. Ensure we don't generate beyond max_len or an EOS token
if last_assistant_token_is_eos and n_matches == candidate_length:
n_matches -= 1
n_matches = min(n_matches, max_len - cur_len - 1)
# 5. Get the valid continuation, after the matching tokens. We also consider the extra token
# generated by the original model. Update the return ids accordingly
valid_tokens = new_tokens[:, : n_matches + 1]
returned_ids = torch.cat((returned_ids, valid_tokens), dim=1)
# 6. Update the args for the next iteration.
# Feed the last correct token to the next loop
new_token = valid_tokens[:, -1:]
if new_token[0] in torch.tensor(eos_token_id):
break
input_ids = valid_tokens[:, -1:]
candidate_input_ids = valid_tokens[:, -1:]
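            # Mark the n_matches + 1 newly accepted positions as valid in the target
            # model's attention mask; the mask is first extended with zero columns to
            # cover the new positions.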
model_inputs_attn_mask = model_inputs["attention_mask"]
n_matches_concat_tensor = torch.zeros(1, n_matches + 1, dtype=model_inputs_attn_mask.dtype)
model_inputs_attn_mask = torch.cat([model_inputs_attn_mask, n_matches_concat_tensor], dim=-1)
model_inputs["attention_mask"] = model_inputs_attn_mask.index_fill(
1, torch.arange(curr_pos + 1, curr_pos + 1 + n_matches + 1), 1
)
curr_pos = curr_pos + n_matches + 1
assistant_kwargs["attention_mask"] = copy.deepcopy(model_inputs["attention_mask"])
# 7. Update with the generated token length and check for stopping condition.
cur_len = cur_len + n_matches + 1
if cur_len >= max_len:
break
            # 8. If the remaining length is smaller than the speculation length, run the target model directly to finish
if max_len - cur_len < spec_len:
                # TODO(@yihsian): Complete the remaining tokens using the target token-generation model
break
return returned_ids

    @property
def device(self) -> torch.device:
"""
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
device).
"""
        # We don't want HF to move parameters to a device
return torch.device("cpu")