optimum/neuron/generation/utils.py (1,056 lines of code) (raw):

# coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for generation with Neuron.""" import copy import inspect import warnings from functools import wraps from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.distributed as dist from optimum.neuron.utils.import_utils import is_torch_xla_available from ..utils.import_utils import is_neuronx_distributed_available from ..utils.misc import args_and_kwargs_to_kwargs_only if is_torch_xla_available(): import torch_xla.core.xla_model as xm from transformers import GenerationMixin from transformers.generation.beam_search import BeamScorer, BeamSearchScorer from transformers.generation.configuration_utils import GenerationConfig from transformers.generation.logits_process import ( LogitsProcessorList, ) from transformers.generation.stopping_criteria import ( MaxLengthCriteria, MaxTimeCriteria, StoppingCriteriaList, validate_stopping_criteria, ) from transformers.generation.utils import ( BeamSearchDecoderOnlyOutput, BeamSearchEncoderDecoderOutput, BeamSearchOutput, GenerateOutput, GreedySearchDecoderOnlyOutput, GreedySearchEncoderDecoderOutput, GreedySearchOutput, ) from transformers.utils import ModelOutput, logging logger = logging.get_logger(__name__) if is_neuronx_distributed_available(): from neuronx_distributed.parallel_layers import parallel_state def _move_dict_args_to_device(kwargs: Dict[str, Any], device: str = "cpu") -> Dict[str, Any]: """ Takes keyword arguments which will be passed to a model's forward function and moves its values to `device` if they are of type `torch.Tensor`. If the key is a dictionary it does the same to the respective values. Args: kwargs: (`Dict[str, Any]`): The kwargs to be passed to the models forward function. device: (`str`, defaults to `cpu`): The target device to which tensors should be moved. Returns: `Dict[str, Any]`: The kwargs dict with its tensors moved to `device`. """ def needs_move(src_device, tgt_device): return src_device != tgt_device for k, v in kwargs.items(): # Handle nested dicts if isinstance(v, dict): for k_, v_ in v.items(): if isinstance(v_, torch.Tensor): if needs_move(v_.device, device): v[k_] = v_.to(device=device) # Handle tensor types elif isinstance(v, torch.Tensor): if needs_move(v.device, device): kwargs[k] = v.to(device=device) # Handle past_key_value tuples elif k == "past_key_values": if v is not None: new_past_key_values = () for layer_past in v: new_layer_past = () for past_state in layer_past: if needs_move(past_state.device, device): new_layer_past += (past_state.to(device=device),) else: new_layer_past += (past_state,) new_past_key_values += (new_layer_past,) kwargs[k] = new_past_key_values return kwargs def _pad_input_ids_for_general_sampling( input_ids: torch.Tensor, num_padding_values: int, pad_token_id: int ) -> torch.Tensor: """ Pads `input_ids` with `num_padding_values` padding tokens along the second dimension. Args: input_ids (`torch.Tensor`): Input ids to be padded. num_padding_values (`int`): Number of padding values to add. pad_token_id (`int`): Token ID of padding token. Returns: `torch.Tensor`: Padded `input_ids`. """ bsz = input_ids.size(0) input_ids = torch.cat( [input_ids, torch.ones((bsz, num_padding_values), device=input_ids.device, dtype=torch.long) * pad_token_id], 1 ) return input_ids def _get_fwd_for_general_sampling( current_fwd: Callable, generation_config: GenerationConfig, is_encoder_decoder: bool, vocab_size: int, main_device: str, to_device: str = "cpu", output_dtype: torch.dtype = torch.float32, ) -> Callable: """ Wraps the passed forward function and extends it such that before each forward call the `decoder_input_ids` are padded and all tensors are moved to `main_device` (e.g. XLA). Then the original forward passed is called followed by a `xm.mark_step`. Subsequently, an "unpadding" of the logits is performed. This way, all functions that process the logits can be called without making any changes. Args: current_fwd (`Callable`): The current forward function of the model. generation_config (`GenerationConfig`): The GenerationConfig of the model. is_encoder_decoder (`bool`): Defines if this is a encoder-decoder model. vocab_size (`int`): The total number of vocabs of the current model. main_device (`str`): The device on which the forward pass should be executed. to_device (`str`, defaults to `cpu`): The device on which all other processing should be executed. output_dtype (`torch.dtype`, defaults to `torch.float32`): The expected data type of the output logits. Returns: `Callable`: The extended forward function. """ @wraps(current_fwd) def new_fwd(*args, **kwargs): # Pad input to max length cur_len = None input_ids_string = "decoder_input_ids" if is_encoder_decoder else "input_ids" if input_ids_string in kwargs: current_input_ids = kwargs[input_ids_string] batch_size, cur_len = current_input_ids.shape num_padding_values = generation_config.max_length - cur_len kwargs[input_ids_string] = _pad_input_ids_for_general_sampling( current_input_ids, num_padding_values, generation_config.pad_token_id ) # For decoder only models, pad decoder attention mask in addition to prompts if "attention_mask" in kwargs and not is_encoder_decoder and num_padding_values > 0: kwargs["attention_mask"] = torch.cat( [ kwargs["attention_mask"], torch.zeros((batch_size, (generation_config.max_length - cur_len))) .long() .to(kwargs["attention_mask"].device), ], 1, ) # create position_ids on the fly for batch generation if "position_ids" in set(inspect.signature(current_fwd).parameters.keys()): position_ids = kwargs["attention_mask"].long().cumsum(-1) - 1 position_ids.masked_fill_(kwargs["attention_mask"] == 0, 1) kwargs["position_ids"] = position_ids # Move inputs to device _move_dict_args_to_device(kwargs, main_device) # Forward kwargs = args_and_kwargs_to_kwargs_only(current_fwd, args, kwargs) outputs = current_fwd(**kwargs) # Gather outputs if NxD tensor parallelism is applied and the output logits have not been gathered. if ( is_neuronx_distributed_available() and parallel_state.model_parallel_is_initialized() and parallel_state.get_tensor_model_parallel_size() > 1 and outputs["logits"].shape[-1] != vocab_size ): outputs["logits"] = xm.all_gather( outputs["logits"], dim=-1, groups=parallel_state.get_tensor_model_parallel_group(as_list=True), ) xm.mark_step() # Move to CPU _move_dict_args_to_device(outputs, to_device) # Post-process output as a function of cur_len outputs["logits"] = outputs["logits"][:, :cur_len, ...].to(output_dtype) return outputs return new_fwd class GeneralNeuronGenerationMixin(GenerationMixin): """ A class containing all functions for auto-regressive text generation on Trn1, to be used as a mixin in [`PreTrainedModel`]. The generation will be handled on both CPU and TRN1 in the following way: 1. Model forward pass will be executed on TRN1 2. All other logics including padding, searching, and sampling will be handled by general device (CPU). This implementation allows us to support general searching and sampling methods with minimal code changes. """ @torch.no_grad() def generate( self, inputs: Optional[torch.Tensor] = None, generation_config: Optional[GenerationConfig] = None, **kwargs, ): # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() # priority: `generation_config` argument > `model.generation_config` (the default generation config) if generation_config is None: # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior, # two conditions must be met # 1) the generation config must have been created from the model config (`_from_model_config` field); # 2) the generation config must have seen no modification since its creation (the hash is the same). if self.generation_config._from_model_config: new_generation_config = GenerationConfig.from_model_config(self.config) if new_generation_config != self.generation_config: warnings.warn( "You have modified the pretrained model configuration to control generation. This is a" " deprecated strategy to control generation and will be removed soon, in a future version." " Please use and modify the model generation configuration (see" " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )" ) self.generation_config = new_generation_config generation_config = self.generation_config generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs generation_config.validate() self._validate_model_kwargs(model_kwargs.copy()) # 2. Set generation parameters if not already defined if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: if model_kwargs.get("attention_mask", None) is None: logger.warning( "The attention mask and the pad token id were not set. As a consequence, you may observe " "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." ) eos_token_id = generation_config.eos_token_id if isinstance(eos_token_id, list): eos_token_id = eos_token_id[0] logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") generation_config.pad_token_id = eos_token_id # 3. Define model inputs and move to CPU general_device = "cpu" if "input_ids" in kwargs and kwargs["input_ids"] is not None: kwargs["input_ids"] = kwargs["input_ids"].to(general_device) if inputs is not None: inputs = inputs.to(general_device) input_ids, model_input_name, model_kwargs = self._prepare_model_inputs( inputs, generation_config.bos_token_id, model_kwargs ) # 4. Set Neuron specific generation configurations original_forward = copy.deepcopy(self.forward) try: general_forward = _get_fwd_for_general_sampling( self.forward, generation_config, self.config.is_encoder_decoder, self.config.vocab_size, self.device, ) self.forward = general_forward if generation_config.use_cache: warnings.warn( "use_cache is not supported for generation on Neuron devices, switching to use_cache=False." ) # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are # generating the first new token or not, and we only want to use the embeddings for the first new token) if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": raise ValueError("Decoder-only models with inputs_embeds forwarding must use `use_cache=True`") generation_config.use_cache = False if generation_config.max_new_tokens is not None: generation_config.max_length = generation_config.max_new_tokens + input_ids.shape[-1] # 5. Run HuggingFace generate function return super().generate(inputs, generation_config, **kwargs) finally: self.forward = original_forward def _prepare_encoder_decoder_kwargs_for_generation( self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None ) -> Dict[str, Any]: """Move the input tensor to XLA device and move the output tensors back to CPU.""" output = super()._prepare_encoder_decoder_kwargs_for_generation( inputs_tensor.to(self.device), model_kwargs, model_input_name ) _move_dict_args_to_device(output, "cpu") return output class NeuronGenerationMixin(GenerationMixin): """ A class containing all functions for auto-regressive text generation on Trn1, to be used as a mixin in [`PreTrainedModel`]. The class exposes [`~generation.GenerationMixin.generate`], which can be used for: - *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and `do_sample=False` - *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0` and `top_k>1` - *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and `do_sample=True` - *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and `do_sample=False` - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if `num_beams>1` and `do_sample=True` - *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if `num_beams>1` and `num_beam_groups>1` - *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if `constraints!=None` or `force_words_ids!=None` You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies). """ @staticmethod def _initialize_attention( model_kwargs, num_padding_values, batch_size, device, is_encoder_decoder, ): """Initializes the appropriate attention mask -- encoder-decoder models use `decoder_attention_mask`""" if is_encoder_decoder: # One 1 for decoder_start_token_id, 0s for the currently-unfilled locations in the past_key_values tensor, # 1s for the actual input_ids decoder_attention_mask = torch.cat( [ torch.zeros((batch_size, num_padding_values), dtype=torch.int32), torch.ones((batch_size, 2), dtype=torch.int32), ], axis=1, ).to(device) mask = {"decoder_attention_mask": decoder_attention_mask} else: attention_mask = model_kwargs.pop("attention_mask") # 0s for the currently-unfilled locations in the past_key_values tensor, 1s for the actual input_ids attention_mask = torch.cat( [ torch.zeros( (batch_size, num_padding_values), dtype=attention_mask.dtype, device=attention_mask.device ), attention_mask, torch.ones((batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device), ], axis=1, ) mask = {"attention_mask": attention_mask} return mask @staticmethod def _update_attention(model_kwargs, batch_size, is_encoder_decoder): """Updates the appropriate attention mask -- encoder-decoder models use `decoder_attention_mask`""" attention_mask_name = "decoder_attention_mask" if is_encoder_decoder else "attention_mask" attention_mask = model_kwargs.pop(attention_mask_name) attention_mask_update_slice = torch.ones( (batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device ) attention_mask = torch.cat([attention_mask[:, 1:], attention_mask_update_slice], dim=-1) mask = {attention_mask_name: attention_mask} return mask @staticmethod def _initialize_past(past_key_values, num_padding_values): """Initialize past_key_values with zeros -- the structure depends on `batch_axis`""" new_past = () for past_layer in past_key_values: new_past_layer = list(past_layer) for i in range(len(new_past_layer[:2])): b, n_heads, _, head_dim = past_layer[i].shape new_past_layer[i] = torch.cat( [ torch.zeros( (b, n_heads, num_padding_values, head_dim), dtype=past_layer[i].dtype, device=past_layer[i].device, ), past_layer[i], ], dim=2, ) new_past += (tuple(new_past_layer),) return new_past @staticmethod def _update_past(past_key_values): new_past = () for past_layer in past_key_values: new_past_layer = list(past_layer) for i, _ in enumerate(new_past_layer[:2]): new_past_layer[i] = past_layer[i][:, :, 1:] new_past += (tuple(new_past_layer),) return new_past def _update_model_kwargs_for_xla_generation( self, outputs: ModelOutput, model_kwargs: Dict[str, Any], batch_size: int, is_encoder_decoder: bool = False, standardize_cache_format: bool = False, max_length: Optional[int] = None, seq_length: Optional[int] = None, use_cache: bool = True, ) -> Dict[str, Any]: if use_cache: past_key_values = self._extract_past_from_model_output(outputs) if past_key_values is None: raise ValueError( "No known `past_key_values variable` found in model outputs (model outputs keys:" f" {list(outputs.keys())})" ) is_past_initialized = model_kwargs.pop("past_key_values", None) is not None if not is_past_initialized: # The padded version of `past_key_values` has a length of `max_length - 1`, as `past_key_values` holds information relative to # previous autoregressive generation steps (step 0 has no past_key_values, step 1 has 1 past_key_values value, ..., the last step # has `max_length - 1` past_key_values values). num_padding_values = max_length - seq_length mask = self._initialize_attention( model_kwargs, num_padding_values, batch_size, outputs.logits.device, is_encoder_decoder ) new_past = self._initialize_past(past_key_values, num_padding_values) else: mask = self._update_attention(model_kwargs, batch_size, is_encoder_decoder) new_past = self._update_past(past_key_values) # sets the updated variables (mask and past_key_values) model_kwargs.update(mask) model_kwargs["past_key_values"] = tuple(new_past) else: model_kwargs["past_key_values"] = None if "token_type_ids" in model_kwargs: token_type_ids = model_kwargs["token_type_ids"] model_kwargs["token_type_ids"] = torch.cat( [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1 ) if not is_encoder_decoder: # update attention mask if "attention_mask" in model_kwargs: batch_size = model_kwargs["attention_mask"].shape[0] update_indices = torch.stack( [torch.arange(batch_size), torch.tensor(seq_length).repeat(batch_size)], dim=-1 ) model_kwargs["attention_mask"][update_indices[:, 0], update_indices[:, 1]] = model_kwargs[ "attention_mask" ].new_ones((batch_size, 1)) else: # update decoder attention mask if "decoder_attention_mask" in model_kwargs: batch_size = model_kwargs["decoder_attention_mask"].shape[0] update_indices = torch.stack( [torch.arange(batch_size), torch.tensor(seq_length).repeat(batch_size)], dim=-1 ) model_kwargs["decoder_attention_mask"][update_indices[:, 0], update_indices[:, 1]] = model_kwargs[ "decoder_attention_mask" ].new_ones((batch_size, 1)) return model_kwargs @staticmethod def _expand_inputs_for_generation( expand_size: int = 1, is_encoder_decoder: bool = False, input_ids: Optional[torch.LongTensor] = None, **model_kwargs, ) -> Tuple[torch.LongTensor, Dict[str, Any]]: """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]""" def _expand_dict_for_generation(dict_to_expand): for key in dict_to_expand: if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor): if len(dict_to_expand[key].shape) == 2: dict_to_expand[key] = ( dict_to_expand[key].repeat(1, expand_size).view(-1, dict_to_expand[key].shape[1]) ) elif len(dict_to_expand[key].shape) <= 1: dict_to_expand[key] = dict_to_expand[key].repeat(expand_size) else: dict_to_expand[key] = torch.concat( [tensor.unsqueeze(0).repeat(expand_size, 1, 1) for tensor in dict_to_expand[key]] ) return dict_to_expand if input_ids is not None: # Manual repeat interleave input_ids = input_ids.repeat(1, expand_size).view(-1, input_ids.shape[1]) model_kwargs = _expand_dict_for_generation(model_kwargs) if is_encoder_decoder: if model_kwargs.get("encoder_outputs") is None: raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) return input_ids, model_kwargs @torch.no_grad() def generate( self, inputs: Optional[torch.Tensor] = None, generation_config: Optional[GenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, synced_gpus: Optional[bool] = None, is_traced_inference: bool = False, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: r""" Generates sequences of token ids for models with a language modeling head. <Tip warning={true}> Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the model's default generation configuration. You can override any `generation_config` by passing the corresponding parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. For an overview of generation strategies and code examples, check out the [following guide](../generation_strategies). </Tip> Parameters: inputs (`Optional[torch.Tensor]`, defaults to `None`): The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of `input_ids`, `input_values`, `input_features`, or `pixel_values`. generation_config (`Optional[GenerationConfig]`, defaults to `None`): The generation configuration to be used as base parametrization for the generation call. `**kwargs` passed to generate matching the attributes of `generation_config` will override them. If `generation_config` is not provided, the default will be used, which had the following loading priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s default values, whose documentation should be checked to parameterize generation. logits_processor (`Optional[LogitsProcessorList]`, defaults to `None`): Custom logits processors that complement the default logits processors built from arguments and generation config. If a logit processor is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users. stopping_criteria (`Optional[StoppingCriteriaList]`, defaults to `None`): Custom stopping criteria that complement the default stopping criteria built from arguments and a generation config. If a stopping criteria is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users. prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): If provided, this function constraints the beam search to allowed tokens only at each step. If not provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful for constrained generation conditioned on the prefix, as described in [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904). synced_gpus (`Optional[bool]`, defaults to `None`): Whether to continue running the while loop until max_length. Unless overridden this flag will be set to `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished generating before other GPUs. Otherwise it'll be set to `False`. is_traced_inference (`bool`, defaults to `False`): Whether the decoder is traced or using XLA lazy tensor. If the decoder is traced, next tokens and the beam scores are computed inside the decoder. kwargs: Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. Return: [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible [`~utils.ModelOutput`] types are: - [`~generation.GreedySearchDecoderOnlyOutput`], - [`~generation.SampleDecoderOnlyOutput`], - [`~generation.BeamSearchDecoderOnlyOutput`], - [`~generation.BeamSampleDecoderOnlyOutput`] If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible [`~utils.ModelOutput`] types are: - [`~generation.GreedySearchEncoderDecoderOutput`], - [`~generation.SampleEncoderDecoderOutput`], - [`~generation.BeamSearchEncoderDecoderOutput`], - [`~generation.BeamSampleEncoderDecoderOutput`] """ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() # priority: `generation_config` argument > `model.generation_config` (the default generation config) if generation_config is None: # legacy: users may modify the model configuration to control generation -- update the generation config # model attribute accordingly, if it was created from the model config if self.generation_config._from_model_config: new_generation_config = GenerationConfig.from_model_config(self.config) if new_generation_config != self.generation_config: warnings.warn( "You have modified the pretrained model configuration to control generation. This is a" " deprecated strategy to control generation and will be removed soon, in a future version." " Please use a generation configuration file (see" " https://huggingface.co/docs/transformers/main_classes/text_generation)" ) self.generation_config = new_generation_config generation_config = self.generation_config generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs generation_config.validate() self._validate_model_kwargs(model_kwargs.copy()) # 2. Set generation parameters if not already defined logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) requires_attention_mask = "encoder_outputs" not in model_kwargs kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None # 3. Define model inputs # inputs_tensor has to be defined # model_input_name is defined if model-specific keyword input is passed # otherwise model_input_name is None # all model-specific keyword inputs are removed from `model_kwargs` inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( inputs, generation_config.bos_token_id, model_kwargs ) batch_size = inputs_tensor.shape[0] # 4. Define other model kwargs model_kwargs["output_attentions"] = generation_config.output_attentions model_kwargs["output_hidden_states"] = generation_config.output_hidden_states if generation_config.use_cache and not is_traced_inference: warnings.warn("use_cache is not supported for generation on Neuron devices, switching to use_cache=False.") model_kwargs["use_cache"] = False else: model_kwargs["use_cache"] = generation_config.use_cache accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) requires_attention_mask = "encoder_outputs" not in model_kwargs and not is_traced_inference if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id ) device = inputs_tensor.device self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device) # decoder-only models should use left-padding for generation if not self.config.is_encoder_decoder: if ( generation_config.pad_token_id is not None and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0 ): logger.warning( "A decoder-only architecture is being used, but right-padding was detected! For correct " "generation results, please set `padding_side='left'` when initializing the tokenizer." ) if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs and not is_traced_inference: # if model is encoder decoder encoder_outputs are created # and added to `model_kwargs` model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( inputs_tensor, model_kwargs, model_input_name ) # 5. Prepare `input_ids` which will be used for auto-regressive generation if self.config.is_encoder_decoder: input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( batch_size=batch_size, model_input_name=model_input_name, model_kwargs=model_kwargs, decoder_start_token_id=generation_config._decoder_start_token_tensor, device=inputs_tensor.device, ) else: input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") # 6. Prepare `max_length` depending on other stopping criteria. input_ids_seq_length = input_ids.shape[-1] has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None if has_default_max_length and generation_config.max_new_tokens is None: warnings.warn( f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" " recommend using `max_new_tokens` to control the maximum length of the generation.", UserWarning, ) elif generation_config.max_new_tokens is not None: generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length if not has_default_max_length: logger.warning( f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " "Please refer to the documentation for more information. " "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" ) if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: raise ValueError( f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" f" the maximum length ({generation_config.max_length})" ) if input_ids_seq_length >= generation_config.max_length: input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" logger.warning( f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" " increasing `max_new_tokens`." ) # Pad to max_length input_ids = torch.cat( [ input_ids, ( torch.ones( (batch_size, (generation_config.max_length - input_ids_seq_length)), ) .long() .to(input_ids.device) ) * generation_config.pad_token_id, ], 1, ) # For decoder only models, pad decoder attention mask in addition to prompts if ( "attention_mask" in model_kwargs and model_kwargs.get("use_cache", False) is False and not self.config.is_encoder_decoder ): model_kwargs["attention_mask"] = torch.cat( [ model_kwargs["attention_mask"], torch.zeros((batch_size, (generation_config.max_length - input_ids_seq_length))) .long() .to(model_kwargs["attention_mask"].device), ], 1, ) # 7. determine generation mode is_constraint_gen_mode = ( generation_config.constraints is not None or generation_config.force_words_ids is not None ) is_contrastive_search_gen_mode = ( (generation_config.num_beams == 1) and generation_config.top_k is not None and generation_config.top_k > 1 and generation_config.do_sample is False and generation_config.penalty_alpha is not None and generation_config.penalty_alpha > 0 ) is_greedy_gen_mode = ( (generation_config.num_beams == 1) and (generation_config.num_beam_groups == 1) and generation_config.do_sample is False and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) is_beam_gen_mode = ( (generation_config.num_beams > 1) and (generation_config.num_beam_groups == 1) and generation_config.do_sample is False and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) if generation_config.num_beam_groups > generation_config.num_beams: raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") if hasattr(self, "device") and self.device.type != input_ids.device.type: warnings.warn( "You are calling .generate() with the `input_ids` being on a device type different" f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." " Please make sure that you have put `input_ids` to the" f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" " running `.generate()`.", UserWarning, ) # 8. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_seq_length, encoder_input_ids=inputs_tensor, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, logits_processor=logits_processor, ) # 9. prepare stopping criteria stopping_criteria = self._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria ) if is_greedy_gen_mode: if generation_config.num_return_sequences > 1: raise ValueError( "num_return_sequences has to be 1 when doing greedy search, " f"but is {generation_config.num_return_sequences}." ) # 11. run greedy search return self.greedy_search( input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, seq_length=input_ids_seq_length, is_traced_inference=is_traced_inference, **model_kwargs, ) elif is_beam_gen_mode: if generation_config.num_return_sequences > generation_config.num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") # 11. prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size, num_beams=generation_config.num_beams, device="cpu", length_penalty=generation_config.length_penalty, do_early_stopping=generation_config.early_stopping, num_beam_hyps_to_keep=generation_config.num_return_sequences, max_length=generation_config.max_length, ) # 12. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) # 13. run beam search return self.beam_search( input_ids, beam_scorer, logits_processor=logits_processor, stopping_criteria=stopping_criteria, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, seq_length=input_ids_seq_length, is_traced_inference=is_traced_inference, **model_kwargs, ) else: raise ValueError("Only greedy search and beam search are supported on Neuron.") def greedy_search( self, input_ids: torch.LongTensor, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, seq_length: Optional[int] = None, is_traced_inference: bool = False, **model_kwargs, ) -> Union[GreedySearchOutput, torch.LongTensor]: r""" Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. <Tip warning={true}> In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following guide](../generation_strategies). </Tip> Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. logits_processor (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. stopping_criteria (`StoppingCriteriaList`, *optional*): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. max_length (`int`, *optional*, defaults to 20): **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated tokens. The maximum length of the sequence to be generated. pad_token_id (`int`, *optional*): The id of the *padding* token. eos_token_id (`Union[int, List[int]]`, *optional*): The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. output_hidden_states (`bool`, *optional*, defaults to `False`): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) seq_length (`Optional[int]`, defaults to `False`): Length of current input_ids sequence is_traced_inference (`bool`, defaults to `False`): Whether the decoder is traced or using XLA lazy tensor. If the decoder is traced, next tokens and the beam scores are computed inside the decoder. model_kwargs: Additional model specific keyword arguments will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. Return: [`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a [`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. Examples: ```python >>> from transformers import AutoTokenizer >>> from optimum.neuron import NeuronModelForSeq2SeqLM >>> tokenizer = AutoTokenizer.from_pretrained("t5-small") >>> input_shapes = {"batch_size": 1, "sequence_length": 128, "num_beams": 1} >>> model = NeuronModelForSeq2SeqLM.from_pretrained("t5-small", export=True, dynamic_batch_size=False, **input_shapes) >>> input_prompt = "translate English to German: Lets eat good food." >>> inputs = tokenizer(input_prompt, return_tensors="pt") >>> outputs = model.greedy_search(input_ids) >>> results = [tokenizer.decode(t, skip_special_tokens=True) for t in outputs] ``` """ # init values if logits_processor is not None and is_traced_inference: logger.warning( "`logits_processor` will not be neglected because in `optimum-neuron`, `next_tokens` is computed inside the compiled decoder. If you want us to support custom logits_processor during the compilation, please file an issue to https://github.com/huggingface/optimum-neuron." ) elif logits_processor is None: logits_processor = LogitsProcessorList() use_cache = model_kwargs.pop("use_cache", False) stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples scores = None if return_dict_in_generate and output_scores: if is_traced_inference: logger.warning( "`output_scores` will be neglected because currently we do not trace `next_token_scores` for greedy search (we do only in beam search). If you want us to support the option during the compilation, please file an issue to https://github.com/huggingface/optimum-neuron." ) else: scores = () decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) # keep track of which sequences are already finished unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) this_peer_finished = False # used by synced_gpus only while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then if this_peer_finished_flag.item() == 0.0: break # prepare model inputs if use_cache: # From max_length-sized input_ids, select first # seq_length - 1 values. if model_kwargs.get("past_key_values") is None: input_ids_ = input_ids[:, :seq_length] else: update_indices = torch.stack( [torch.arange(input_ids.size(0)), torch.tensor(seq_length - 1).repeat(input_ids.size(0))], dim=-1, ) input_ids_ = input_ids[update_indices[:, 0], update_indices[:, 1], None] model_inputs = self.prepare_inputs_for_generation(input_ids_, **model_kwargs) else: model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # forward pass to get next token outputs = self( **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need if not is_traced_inference: if not use_cache: one_hot = ( torch.cat( [ torch.tensor([0]).repeat(1, seq_length - 1), torch.tensor([1]).repeat(1, 1), torch.tensor([0]).repeat(1, input_ids.size(1) - seq_length), ], dim=1, ) .to(device=outputs.logits.device) .float() ) next_token_logits = torch.matmul(one_hot, outputs.logits) next_token_logits = next_token_logits.squeeze(1) else: next_token_logits = outputs.logits[:, -1, :] # pre-process distribution # Move to cpu to handle arbitrary logits_processor next_tokens_scores = logits_processor(input_ids.to("cpu")[:, :seq_length], next_token_logits.to("cpu")) next_tokens_scores = next_tokens_scores.to(input_ids.device) # argmax next_tokens = torch.argmax(next_tokens_scores, dim=-1) if return_dict_in_generate and output_scores: scores += (next_tokens_scores,) else: next_tokens = outputs[0] # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step batch_size, _ = input_ids.shape update_indices = torch.stack( [torch.arange(batch_size), torch.tensor(seq_length).repeat(batch_size)], dim=-1 ) input_ids[update_indices[:, 0], update_indices[:, 1]] = next_tokens[:] model_kwargs = self._update_model_kwargs_for_xla_generation( outputs=outputs, model_kwargs=model_kwargs, batch_size=batch_size, is_encoder_decoder=self.config.is_encoder_decoder, max_length=stopping_criteria.max_length, seq_length=seq_length, use_cache=use_cache, ) seq_length += 1 # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) ) if not is_traced_inference: xm.mark_step() # stop when each sentence is finished, or if we exceed the maximum length stop_criterion_1 = unfinished_sequences.max() == 0 if isinstance(stopping_criteria, list): if len(stopping_criteria) == 1: stopping_criteria = stopping_criteria[0] # Cases that can be handled in XLA without requiring # non-padded input_ids if isinstance(stopping_criteria, MaxLengthCriteria): stop_criterion_2 = seq_length >= stopping_criteria.max_length elif isinstance(stopping_criteria, MaxTimeCriteria): stop_criterion_2 = stopping_criteria(input_ids, scores) else: # Other cases will be handled on CPU batch_size, _ = input_ids.shape mask = torch.cat( [torch.ones(batch_size, seq_length), torch.zeros(batch_size, input_ids.shape[1] - seq_length)], dim=1, ).bool() input_ids_cpu = torch.masked_select(input_ids, mask).reshape((batch_size, seq_length)).to("cpu") scores_cpu = scores.to("cpu") if torch.is_tensor(scores) else scores stop_criterion_2 = stopping_criteria(input_ids_cpu, scores_cpu) if stop_criterion_1 or stop_criterion_2: this_peer_finished = True if this_peer_finished and not synced_gpus: break if return_dict_in_generate: if self.config.is_encoder_decoder: return GreedySearchEncoderDecoderOutput( sequences=input_ids, scores=scores, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, ) else: return GreedySearchDecoderOnlyOutput( sequences=input_ids, scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) else: return input_ids def beam_search( self, input_ids: torch.LongTensor, beam_scorer: BeamScorer, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: Optional[bool] = False, seq_length: Optional[int] = None, is_traced_inference: bool = False, **model_kwargs, ) -> Union[BeamSearchOutput, torch.LongTensor]: r""" Generates sequences of token ids for models with a language modeling head using **beam search decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. <Tip warning={true}> In most cases, you do not need to call [`~generation.GenerationMixin.beam_search`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following guide](../generation_strategies). </Tip> Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. beam_scorer (`BeamScorer`): An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. logits_processor (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. stopping_criteria (`StoppingCriteriaList`, *optional*): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. max_length (`int`, *optional*, defaults to 20): **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated tokens. The maximum length of the sequence to be generated. pad_token_id (`int`, *optional*): The id of the *padding* token. eos_token_id (`Union[int, List[int]]`, *optional*): The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. output_hidden_states (`bool`, *optional*, defaults to `False`): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) seq_length (`Optional[int]`, defaults to `False`): Length of current input_ids sequence is_traced_inference (`bool`, defaults to `False`): Whether the decoder is traced or using XLA lazy tensor. If the decoder is traced, next tokens and the beam scores are computed inside the decoder. model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. Return: [`generation.BeamSearchDecoderOnlyOutput`], [`~generation.BeamSearchEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a [`~generation.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`~generation.BeamSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. Examples: ```python >>> from transformers import AutoTokenizer >>> from optimum.neuron import NeuronModelForSeq2SeqLM >>> tokenizer = AutoTokenizer.from_pretrained("t5-small") >>> input_shapes = {"batch_size": 1, "sequence_length": 128, "num_beams": 4} >>> model = NeuronModelForSeq2SeqLM.from_pretrained("t5-small", export=True, dynamic_batch_size=False, **input_shapes) >>> input_prompt = "translate English to German: Lets eat good food." >>> inputs = tokenizer(input_prompt, return_tensors="pt") >>> # add encoder_outputs to model keyword arguments >>> model_kwargs = { ... "encoder_outputs": model.get_encoder()( ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True ... ) ... } >>> # instantiate beam scorer >>> beam_scorer = BeamSearchScorer( ... batch_size=1, ... num_beams=num_beams, ... device=model.device, ... ) >>> outputs = model.beam_search(input_ids, beam_scorer) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ``` """ # init values if logits_processor is not None and is_traced_inference: logger.warning( "`logits_processor` will be neglected because in `optimum-neuron`, `next_tokens` is computed inside the compiled decoder. If you want us to support custom logits_processor during the compilation, please file an issue to https://github.com/huggingface/optimum-neuron." ) elif logits_processor is None: logits_processor = LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) if len(stopping_criteria) == 0: warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.generation_config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape # Overwrite cur_len cur_len = seq_length if num_beams * batch_size != batch_beam_size: raise ValueError( f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None beam_indices = ( tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None ) decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens # of the first beam are considered to avoid sampling the exact same tokens across all beams. beam_scores_device = "cpu" beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=beam_scores_device) beam_scores[:, 1:] = -1e9 beam_scores = beam_scores.view((batch_size * num_beams,)) this_peer_finished = False # used by synced_gpus only while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then if this_peer_finished_flag.item() == 0.0: break # prepare model inputs if model_kwargs["use_cache"]: # From max_length-sized input_ids, select first # cur_len - 1 values. update_indices = torch.stack( [torch.arange(input_ids.size(0)), torch.tensor(cur_len - 1).repeat(input_ids.size(0))], dim=-1 ) input_ids_ = input_ids[update_indices[:, 0], update_indices[:, 1], None] model_inputs = self.prepare_inputs_for_generation(input_ids_, **model_kwargs) else: model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) if is_traced_inference: outputs = self( **model_inputs, beam_scores=beam_scores, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) next_token_scores = outputs.next_token_scores next_tokens = outputs.next_tokens next_indices = outputs.next_indices if return_dict_in_generate and output_scores: scores += (next_token_scores,) else: outputs = self( **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) if synced_gpus and this_peer_finished: cur_len = cur_len + 1 continue # don't waste resources running the code we don't need if not model_kwargs["use_cache"]: one_hot = ( torch.cat( [ torch.tensor([0]).repeat(1, cur_len - 1), torch.tensor([1]).repeat(1, 1), torch.tensor([0]).repeat(1, input_ids.size(1) - cur_len), ], dim=1, ) .to(device=outputs.logits.device) .float() ) next_token_logits = torch.matmul(one_hot, outputs.logits) next_token_logits = next_token_logits.squeeze(1) else: next_token_logits = outputs.logits[:, -1, :] # Manually compute log softmax # log_softmax(vi) = vi - max(vi) - log(sum(exp(vi - max(vi)))) logit_max, _ = torch.max(next_token_logits, dim=-1, keepdim=True) logsumexp = torch.log(torch.exp(next_token_logits - logit_max).sum(dim=-1, keepdim=True)) next_token_scores = next_token_logits - logit_max - logsumexp # (batch_size * num_beams, vocab_size) xm.mark_step() # We don't want to change every single logit processor, so # we perform this processing on CPU. input_ids_ = input_ids.to("cpu")[:, :cur_len] next_token_scores_ = next_token_scores.to("cpu") next_token_scores_processed = logits_processor(input_ids_, next_token_scores_) next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) # reshape for beam search vocab_size = next_token_scores.shape[-1] next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) next_token_scores = next_token_scores * 1 # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search) next_token_scores, next_tokens = torch.topk( next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True ) next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") next_tokens = next_tokens % vocab_size if return_dict_in_generate and output_scores: scores += (next_token_scores_processed,) # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # stateless beam_outputs = beam_scorer.process( input_ids.to("cpu")[:, :cur_len], next_token_scores.to("cpu"), next_tokens.to("cpu"), next_indices.to("cpu"), pad_token_id=pad_token_id, eos_token_id=eos_token_id, beam_indices=beam_indices, ) beam_scores = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] update_indices = torch.stack( [torch.arange(batch_beam_size), torch.tensor(cur_len - 1).repeat(batch_beam_size)], dim=-1 ) update_indices_2 = torch.stack( [torch.arange(batch_beam_size), torch.tensor(cur_len).repeat(batch_beam_size)], dim=-1 ) # First select beam_indices device = input_ids.device beam_idx_device = beam_idx.to(device=input_ids.device) input_ids[:, :] = input_ids[beam_idx_device.long(), :] # Then append new tokens if is_traced_inference: # int64 is not natively supported by inf2 and has been cast down to int32 input_ids[update_indices_2[:, 0], update_indices_2[:, 1], None] = ( beam_next_tokens.unsqueeze(-1).to(device).to(torch.long) ) else: input_ids[update_indices_2[:, 0], update_indices_2[:, 1], None] = beam_next_tokens.unsqueeze(-1).to( device ) input_ids = input_ids * 1 # Hack to materialize tensor # update generated ids, model inputs, and length for next step model_kwargs = self._update_model_kwargs_for_xla_generation( outputs=outputs, model_kwargs=model_kwargs, batch_size=batch_beam_size, is_encoder_decoder=self.config.is_encoder_decoder, max_length=stopping_criteria.max_length, seq_length=cur_len, use_cache=model_kwargs["use_cache"], ) if is_traced_inference: self._reorder_cache(beam_idx.to(torch.int64)) elif model_kwargs["past_key_values"] is not None: model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) if return_dict_in_generate and output_scores: beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) # increase cur_len cur_len = cur_len + 1 # stop when each sentence is finished, or if we exceed the maximum length stop_criterion_1 = beam_scorer.is_done if isinstance(stopping_criteria, list): if len(stopping_criteria) == 1: stopping_criteria = stopping_criteria[0] # Cases that can be handled in XLA without requiring # non-padded input_ids if isinstance(stopping_criteria, MaxLengthCriteria): stop_criterion_2 = cur_len >= stopping_criteria.max_length elif isinstance(stopping_criteria, MaxTimeCriteria): stop_criterion_2 = stopping_criteria(input_ids, scores) else: # Other cases will be handled on CPU batch_size, _ = input_ids.shape input_ids_cpu = input_ids.to("cpu") mask = torch.cat( [torch.ones(batch_size, cur_len), torch.zeros(batch_size, input_ids.shape[1] - cur_len)], dim=1 ).bool() input_ids_cpu = torch.masked_select(input_ids_cpu, mask).reshape((batch_size, cur_len)) scores_cpu = scores.to("cpu") if torch.is_tensor(scores) else scores stop_criterion_2 = stopping_criteria(input_ids_cpu, scores_cpu) # TODO: validate with @JingyaHuang if stop_criterion_1 or torch.all(stop_criterion_2): if not synced_gpus: break else: this_peer_finished = True sequence_outputs = beam_scorer.finalize( input_ids.to("cpu"), beam_scores.to("cpu"), next_tokens.to("cpu"), next_indices.to("cpu"), pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, beam_indices=beam_indices, ) for k, v in sequence_outputs.items(): if type(v) is torch.Tensor: sequence_outputs[k] = sequence_outputs[k].to(input_ids.device) if return_dict_in_generate: if not output_scores: sequence_outputs["sequence_scores"] = None if self.config.is_encoder_decoder: return BeamSearchEncoderDecoderOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, beam_indices=sequence_outputs["beam_indices"], encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, ) else: return BeamSearchDecoderOnlyOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, beam_indices=sequence_outputs["beam_indices"], attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) else: return sequence_outputs["sequences"]