optimum/neuron/models/inference/backend/modules/generation/generation_utils.py (298 lines of code) (raw):

# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import Any, Dict, Optional, Union

import torch
from transformers import GenerationConfig
from transformers.generation import GenerationMixin, SampleDecoderOnlyOutput
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.modeling_outputs import ModelOutput

from .sampling import (
    Sampler,
    prepare_sampling_params,
)


class NxDGenerationMixin(GenerationMixin):
    """A generation Mixin that can be used to extend NxDPreTrainedModel based classes.

    It plugs the Neuron model into the Hugging Face `GenerationMixin` machinery. It overrides:
    - the sampling loop;
    - input preparation;
    - attention-mask bookkeeping;
    - assisted (speculative) decoding.
    The overrides respect the static shapes stored in `self.neuron_config`.
    """

    # These are expected to be set by the GenerationMixin code
    main_input_name = "input_ids"
    _is_stateful = False
    _supports_cache_class = False

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        assert hasattr(self, "neuron_config")  # Must be set by the super class
        # Initialize default generation config
        self.generation_config = GenerationConfig.from_model_config(config)
        # CPU sampler, created lazily the first time `_sample` needs one.
        self.sampler = None

    def can_generate(self):
        """Return True so that `GenerationMixin` accepts this model."""
        # Still required in transformers <= 4.50
        return True

    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        generation_config: Optional["GenerationConfig"] = None,
        **kwargs,
    ):
        """Check inputs against the model's static shapes, reset it, then delegate to `GenerationMixin.generate`.

        Args:
            input_ids: 2D tensor of prompt token ids, unpacked as `(batch_size, sequence_length)`.
            attention_mask: Optional attention mask, forwarded unchanged.
            generation_config: Optional generation config, forwarded unchanged.
            **kwargs: Extra arguments forwarded to `GenerationMixin.generate`.

        Raises:
            ValueError: If the input sequence length or batch size exceeds
                `neuron_config.sequence_length` / `neuron_config.batch_size`.
        """
        # Sanity check
        batch_size, sequence_length = input_ids.shape
        if sequence_length > self.neuron_config.sequence_length:
            raise ValueError(
                f"The input sequence length ({sequence_length}) exceeds the model static sequence length ({self.neuron_config.sequence_length})"
            )
        if batch_size > self.neuron_config.batch_size:
            raise ValueError(
                f"The specified batch_size ({batch_size}) exceeds the model static batch size ({self.neuron_config.batch_size})"
            )
        # Keep generation stateless.
        self.reset()
        return super().generate(
            input_ids, attention_mask=attention_mask, generation_config=generation_config, **kwargs
        )

    # TODO: Remove _sample and define separate flow for on-device sampling that doesn't use HF.
    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        logits_warper: Optional[LogitsProcessorList] = None,
        **model_kwargs,
    ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
        r"""
        We override the GenerationMixin sample function (_sample for transformers>=4.39.0) to add support for
        right side padding.

        Each step calls `prepare_inputs_for_generation` and `forward`. The next token comes from one of two sources:
        - `outputs.tokens`, when `neuron_config.on_device_sampling` is set;
        - otherwise, the logits processors/warpers followed by a lazily created CPU `Sampler`.

        Greedy decoding is expressed as sampling with top_k=1, top_p=1.0 and temperature=1.0.

        Returns:
            A `SampleDecoderOnlyOutput` when `generation_config.return_dict_in_generate` is set,
            otherwise the generated `input_ids` tensor (prompt + new tokens).
        """
        # init values
        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
        pad_token_id = generation_config._pad_token_tensor
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
        do_sample = generation_config.do_sample

        # Greedy decoding is mapped onto sampling parameters that always pick the argmax.
        batch_size = model_kwargs["attention_mask"].shape[0]
        top_k = generation_config.top_k if do_sample else 1
        top_p = generation_config.top_p if do_sample else 1.0
        temperature = generation_config.temperature if do_sample else 1.0
        sampling_params = prepare_sampling_params(
            batch_size=batch_size,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
        )
        model_kwargs["sampling_params"] = sampling_params

        # init scores / logits tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        this_peer_finished = False
        # The first iteration is the prefill (context encoding); every later one is a decode step.
        is_for_token_generation = False
        # auto-regressive generation
        while not this_peer_finished:
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(
                input_ids, is_decode=is_for_token_generation, **model_kwargs
            )

            model_kwargs["attention_mask"] = model_inputs.get("attention_mask")

            # forward pass to get next token
            outputs = self.forward(**model_inputs, return_dict=True)

            if self.neuron_config.on_device_sampling:
                next_tokens = outputs.tokens
            else:
                next_token_logits = outputs.logits[:, -1, :].clone()

                # pre-process distribution
                next_token_scores = logits_processor(input_ids, next_token_logits)
                if do_sample:
                    next_token_scores = logits_warper(input_ids, next_token_scores)

                if return_dict_in_generate:
                    if output_scores:
                        scores += (next_token_scores,)
                    if output_logits:
                        raw_logits += (next_token_logits,)

                if self.sampler is None:
                    self.sampler = Sampler(self.neuron_config, do_sample=True, on_cpu=True)
                next_tokens = self.sampler(next_token_scores, sampling_params)

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            is_for_token_generation = True
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_for_token_generation=is_for_token_generation
            )

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, None)
            this_peer_finished = unfinished_sequences.max() == 0

        if return_dict_in_generate:
            return SampleDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
            )
        else:
            return input_ids

    def prepare_inputs_for_generation(
        self,
        input_ids,
        is_decode,
        attention_mask=None,
        sampling_params=None,
        seq_ids=None,
        **kwargs,
    ):
        """Build the keyword arguments passed to `self.forward` for one generation step.

        Args:
            input_ids: Token ids generated so far. When `is_decode` is True, only the last column is kept.
            is_decode: False for the prefill step, True for token-generation (decode) steps.
            attention_mask: Optional mask. When given and no `position_ids` kwarg is supplied, position ids are
                derived from it: cumulative sum minus one, with masked positions filled with 1. In decode mode the
                row-wise maximum is used.
            sampling_params: Forwarded unchanged.
            seq_ids: Sequence ids. Defaults to `torch.arange(batch_size)`.
            **kwargs: May contain `position_ids`, plus any names returned by `self.get_required_kwargs()`.

        Returns:
            A dict with keys `input_ids`, `position_ids`, `attention_mask`, `sampling_params` and `seq_ids`, plus
            one entry per required kwarg (None when absent from `kwargs`).
        """
        if is_decode:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if is_decode:
                position_ids = torch.amax(position_ids, 1, keepdim=True)

        if seq_ids is None:
            seq_ids = torch.arange(input_ids.shape[0])

        model_inputs = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "sampling_params": sampling_params,
            "seq_ids": seq_ids,
        }

        # WARNING: This is needed for propagating additional kwargs to the neuron model
        additional_kwargs = self.get_required_kwargs()
        for arg in additional_kwargs:
            model_inputs.update({arg: kwargs.get(arg, None)})

        return model_inputs

    def prepare_inputs_for_prefill(
        self,
        input_ids,
        attention_mask=None,
        sampling_params=None,
        **kwargs,
    ):
        """Shortcut for `prepare_inputs_for_generation(..., is_decode=False)`."""
        return self.prepare_inputs_for_generation(
            input_ids,
            is_decode=False,
            attention_mask=attention_mask,
            sampling_params=sampling_params,
            **kwargs,
        )

    def prepare_inputs_for_decode(
        self,
        input_ids,
        attention_mask=None,
        sampling_params=None,
        **kwargs,
    ):
        """Shortcut for `prepare_inputs_for_generation(..., is_decode=True)`."""
        return self.prepare_inputs_for_generation(
            input_ids,
            is_decode=True,
            attention_mask=attention_mask,
            sampling_params=sampling_params,
            **kwargs,
        )

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_for_token_generation: bool,
    ) -> Dict[str, Any]:
        """Update `model_kwargs` after a forward pass.

        We override the `GenerationMixin` version to change how the attention mask is updated at each iteration.

        - Copies `outputs.state` into `model_kwargs["state"]` when present.
        - Extends `token_type_ids` by repeating its last column.
        - During token generation, updates the attention mask:
          - with `neuron_config.padding_side == "left"`, a column of ones is appended and the first column dropped,
            so the width is unchanged;
          - otherwise a column of ones is prepended, so the width grows by one.

        Returns:
            The same `model_kwargs` dict, updated in place.
        """
        if getattr(outputs, "state", None) is not None:
            model_kwargs["state"] = outputs.state

        # update token_type_ids with last value
        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

        # update attention mask
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            if is_for_token_generation:
                if self.neuron_config.padding_side == "left":
                    attention_mask = torch.cat(
                        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
                        dim=-1,
                    )
                    attention_mask = attention_mask[:, 1:]
                else:
                    attention_mask = torch.cat(
                        [attention_mask.new_ones((attention_mask.shape[0], 1)), attention_mask],
                        dim=-1,
                    )
            model_kwargs["attention_mask"] = attention_mask
        return model_kwargs

    def _assisted_decoding(
        self,
        input_ids: torch.LongTensor,
        candidate_generator: "CandidateGenerator",  # noqa
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        **model_kwargs,
    ):
        """Greedy speculative decoding with the draft model `candidate_generator.assistant_model`.

        Each round works as follows:
        1. The draft model proposes up to `neuron_config.speculation_length` tokens, stopping early on EOS.
        2. The target model validates them in a single forward pass.
        3. The longest matching prefix plus one target token is appended to the output.

        The padding tensor and `position_ids` are built with a leading dimension of 1, so a batch size of 1 is
        presumably assumed — confirm before using larger batches.

        Returns:
            The tensor of returned token ids (prompt up to the last attended position + generated tokens).

        Raises:
            ValueError: If the assistant model uses on-device sampling, or if `eos_token_id` is set without a
                `pad_token_id`.
        """
        pad_token_id = generation_config.pad_token_id
        eos_token_id = generation_config.eos_token_id

        assistant_model = candidate_generator.assistant_model
        if assistant_model.neuron_config.on_device_sampling:
            raise ValueError("Assistant model must not use on-device sampling")

        # Init values
        if eos_token_id is not None and pad_token_id is None:
            raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

        # Prepare assistant model's keys of inputs
        assistant_kwargs = copy.deepcopy(model_kwargs)

        # Other auxiliary variables
        # NOTE(review): assumes the first stopping criterion exposes `max_length` (e.g. MaxLengthCriteria) — confirm.
        max_len = stopping_criteria[0].max_length
        cur_len = input_ids.shape[-1]
        spec_len = self.neuron_config.speculation_length

        # Run the target model once and get the first generated token
        model_inputs = self.prepare_inputs_for_prefill(input_ids, **model_kwargs)
        outputs = self.forward(**model_inputs)

        curr_pos = model_inputs["position_ids"][0].argmax(dim=-1)
        new_token = outputs.logits[:, 0].argmax(dim=-1, keepdim=True)

        # Prepare the input ids and attention mask for the draft model
        candidate_input_ids = input_ids

        # Final sequence to return; append the first generated token
        returned_ids = torch.cat((input_ids[:, : curr_pos + 1], new_token), dim=1)

        # Speculation loop
        while True:
            # 1 Token generation using draft model
            # Read once per speculation round; the same flag is used for every draft step below.
            is_for_token_generation = assistant_model.kv_cache_populated
            for _ in range(spec_len):
                # 1.1 Prepare assistant model inputs
                assistant_inputs = assistant_model.prepare_inputs_for_generation(
                    candidate_input_ids,
                    is_decode=is_for_token_generation,
                    **assistant_kwargs,
                )

                # 1.2 Use the assistant model to obtain the next candidate logits
                assistant_model_outputs = assistant_model.forward(**assistant_inputs)
                assistant_new_token = assistant_model_outputs.logits[:, 0, :].argmax(dim=-1)

                # 1.3 Update inputs and args for next iteration
                candidate_input_ids = torch.cat((candidate_input_ids, assistant_new_token[:, None]), dim=-1)
                assistant_kwargs = assistant_model._update_model_kwargs_for_generation(
                    assistant_model_outputs,
                    assistant_kwargs,
                    is_for_token_generation,
                )

                # 1.4 Stop assistant generation on EOS
                if eos_token_id_tensor is not None:
                    last_assistant_token_is_eos = assistant_new_token.tile(eos_token_id_tensor.shape[0], 1)
                    last_assistant_token_is_eos = (
                        ~last_assistant_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool()
                    )
                    if last_assistant_token_is_eos:
                        break
                else:
                    last_assistant_token_is_eos = False

            # 2 Validation of draft model output using the original model
            # The length could be shorter if the draft loop ends earlier
            candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]

            # 2.1 Prepare the input arguments
            input_ids = torch.cat((new_token, candidate_input_ids[:, -candidate_length:-1]), dim=-1)
            attention_mask = model_inputs["attention_mask"]
            pos = curr_pos + 1
            position_ids = torch.arange(pos, pos + spec_len).expand(1, spec_len)
            # Pad the input_ids if needed, so the target model always sees `spec_len` tokens
            if input_ids.shape[-1] < spec_len:
                input_ids = torch.cat(
                    (input_ids, torch.full((1, spec_len - input_ids.shape[-1]), pad_token_id)),
                    dim=-1,
                )

            # 2.2. Run a forward pass on the candidate sequence
            outputs = self.forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )

            # 2.3. Process the new logits
            new_tokens = outputs.logits.argmax(dim=-1)
            selected_tokens = outputs.logits[:, : candidate_length - 1].argmax(dim=-1)

            # 3. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
            # the assistant forecasted tokens until the first mismatch, or until the max length is reached.
            candidate_new_tokens = candidate_input_ids[:, -candidate_length:-1]
            n_matches = ((~(candidate_new_tokens == selected_tokens)).cumsum(dim=-1) < 1).sum()

            # 4. Ensure we don't generate beyond max_len or an EOS token
            # NOTE(review): only `candidate_length - 1` tokens are compared above, so n_matches <= candidate_length - 1
            # and this condition appears never to hold — confirm intent.
            if last_assistant_token_is_eos and n_matches == candidate_length:
                n_matches -= 1
            n_matches = min(n_matches, max_len - cur_len - 1)

            # 5. Get the valid continuation, after the matching tokens. We also consider the extra token
            # generated by the original model. Update the return ids accordingly
            valid_tokens = new_tokens[:, : n_matches + 1]
            returned_ids = torch.cat((returned_ids, valid_tokens), dim=1)

            # 6. Update the args for the next iteration.
            # Feed the last correct token to the next loop
            new_token = valid_tokens[:, -1:]
            # Reuse the precomputed EOS tensor: `torch.tensor(None)` would raise when no EOS token is configured.
            if eos_token_id_tensor is not None and new_token[0] in eos_token_id_tensor:
                break
            input_ids = valid_tokens[:, -1:]
            candidate_input_ids = valid_tokens[:, -1:]
            # Grow the target attention mask by the accepted tokens and mark their positions as attended
            model_inputs_attn_mask = model_inputs["attention_mask"]
            n_matches_concat_tensor = torch.zeros(1, n_matches + 1, dtype=model_inputs_attn_mask.dtype)
            model_inputs_attn_mask = torch.cat([model_inputs_attn_mask, n_matches_concat_tensor], dim=-1)
            model_inputs["attention_mask"] = model_inputs_attn_mask.index_fill(
                1, torch.arange(curr_pos + 1, curr_pos + 1 + n_matches + 1), 1
            )

            curr_pos = curr_pos + n_matches + 1

            # Keep the draft model's mask in sync with the target model's mask
            assistant_kwargs["attention_mask"] = copy.deepcopy(model_inputs["attention_mask"])

            # 7. Update with the generated token length and check for stopping condition.
            cur_len = cur_len + n_matches + 1
            if cur_len >= max_len:
                break

            # 8. If the rest length is smaller than speculation length, we directly run the target model to finish
            if max_len - cur_len < spec_len:
                # @yihsian: TODO: complete with using target tokengen model
                break

        return returned_ids

    @property
    def device(self) -> torch.device:
        """
        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
        device).
        """
        # We don't want HF to move parameters to device
        return torch.device("cpu")