optimum/graphcore/generation/utils.py (1,179 lines of code) (raw):

# Copyright 2022 The HuggingFace Team. All rights reserved. # Copyright (c) 2022 Graphcore Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import contextlib import copy import json import os import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import poptorch import torch from torch import nn from transformers.generation.stopping_criteria import validate_stopping_criteria from transformers.generation.utils import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, BeamSampleOutput, BeamScorer, BeamSearchDecoderOnlyOutput, BeamSearchEncoderDecoderOutput, BeamSearchOutput, GenerationMixin, GreedySearchDecoderOnlyOutput, GreedySearchEncoderDecoderOutput, GreedySearchOutput, LogitsProcessorList, SampleDecoderOnlyOutput, SampleEncoderDecoderOutput, SampleOutput, StoppingCriteriaList, ) from transformers.modeling_outputs import BaseModelOutput, ModelOutput from transformers.utils.versions import require_version from optimum.utils import logging from .logits_process import IPULogitsProcessors from .on_device_generation import ( OnDeviceGenerationModel, OnDeviceGenerationModelOutput, ) logger = logging.get_logger(__name__) MODELS_SUPPORTING_KV_CACHE = set() def supports_kv_cache(pipelined_cls): MODELS_SUPPORTING_KV_CACHE.add(pipelined_cls) return pipelined_cls def assert_poptorch_supports_cond(context: Optional[str] = None): context = context or "" require_version("poptorch>=3.3", "Require poptorch>=3.3 for `poptorch.cond`. " + context) if not hasattr(poptorch, "cond"): raise AttributeError( "`poptorch.cond` appears to be missing, perhaps you are using a candidate release " "which does not support it yet? " + context ) @contextlib.contextmanager def graph_profile_dir_append(append: str): if poplar_engine_options_original := os.getenv("POPLAR_ENGINE_OPTIONS"): poplar_engine_options_modified = json.loads(poplar_engine_options_original) if autoreport_directory := poplar_engine_options_modified.get("autoReport.directory"): poplar_engine_options_modified["autoReport.directory"] = autoreport_directory + append os.environ["POPLAR_ENGINE_OPTIONS"] = json.dumps(poplar_engine_options_modified) try: yield finally: if poplar_engine_options_original: os.environ["POPLAR_ENGINE_OPTIONS"] = poplar_engine_options_original class _IndexedInputLinear(nn.Module): """ Wrapper layer for `Linear` that performs a `dynamic_slice` on the input before executing the linear. The intended use is as an optimized replacement of the LM Head in the Decoder for text generation inference when KV caching is disabled. The slice is performed on the position `self._generation_step` of the input tensor, where `self._generation_step` is a PyTorch buffer. """ def __init__(self, linear_layer): super().__init__() self.wrapped_linear = linear_layer self.register_buffer("_generation_step", torch.tensor([0], dtype=torch.int32), persistent=False) def forward(self, x): x = poptorch.dynamic_slice(x, 1, self._generation_step, 1, 1) return self.wrapped_linear(x) class DecoderWrapper(nn.Module): """ Fast wrapper for decoder part of text generation models. Updates the appropriate buffers for the modules which need to know the current generation step. Only returns the logits from the last generated token to reduce IO costs. """ def __init__(self, pipelined_model): super().__init__() self.pipelined_model = pipelined_model # With KV caching, some modules may need to know the current decoding step and beam indices. # Getting this information to them can either be done by copying it into buffers, or # by subclassing the entire decoder model just to change the forward signatures and passing these # as arguments. For now, go with the former, but it's not set in stone. self._modules_with_attributes_in_buffers = { attr: [module for module in self.pipelined_model.modules() if hasattr(module, attr)] for attr in ["_beam_idx", "_generation_step"] } def register_encoder_output_buffers(self, output_buffers: Dict[str, torch.Tensor]): for name in sorted(output_buffers): self.register_buffer(name, output_buffers[name], persistent=False) def _get_buffered_outputs(self) -> Dict: kwargs = {} if hasattr(self, "encoder_last_hidden_state"): kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=self.encoder_last_hidden_state) if hasattr(self, "encoder_attention_mask"): kwargs["attention_mask"] = self.encoder_attention_mask return kwargs def forward(self, t, beam_idx=None, **model_inputs): """ Args: t : (`torch.Tensor(int)`) Tensor with single int representing the current length of the sequence being generated beam_idx: (`torch.LongTensor` of shape `(batch_size * num_beams,)`): Beam indices indicating to which beam the tokens were added, required for reordering the on-device KV cache. model_inputs : Regular model_inputs passed to the wrapped model. Returns: The output logits at position `t` only """ for module in self._modules_with_attributes_in_buffers["_generation_step"]: module._generation_step.copy_(t) # When generation is done on host, the beam_idx has to be provided as an input. # When generation is done on device, the beam_idx is stored in a separate buffer. if beam_idx is None: if hasattr(self.pipelined_model, "generation_strategy") and hasattr( self.pipelined_model.generation_strategy, "_cached_beam_idx" ): beam_idx = self.pipelined_model.generation_strategy._cached_beam_idx.int() for module in self._modules_with_attributes_in_buffers["_beam_idx"]: if beam_idx is None: raise ValueError( "A module registered a `beam_idx` buffer, but the pipelined model is not called with such, " "or the on device beam search did not register `_cached_beam_idx`. For the first case, " "`beam_idx` can be provided to the model via `prepare_inputs_for_generation`." ) module._beam_idx.copy_(beam_idx) # Run the decoder kwargs = self._get_buffered_outputs() outputs = self.pipelined_model(**model_inputs, **kwargs) if isinstance(outputs, ModelOutput) and not isinstance(outputs, OnDeviceGenerationModelOutput): outputs = type(outputs)( logits=outputs.logits, ) return outputs class IPUGenerationMixin(GenerationMixin): """ Enable optimization for encoder-decoder text generation where the encoder outputs are cached on the Decoder device using buffers. """ _use_cond_encoder = False _use_encoder_output_buffer = False kv_cache_enabled = False @property def encoder_output_buffer_enabled(self) -> bool: return self.config.is_encoder_decoder and self._use_encoder_output_buffer and not self._use_cond_encoder @property def cond_encoder_enabled(self) -> bool: return self.config.is_encoder_decoder and self._use_cond_encoder def _pad_tensors_to_max_len(self, tensor: torch.Tensor, max_length: int, pad_token_id: int) -> torch.Tensor: return nn.functional.pad(tensor, (0, max_length - tensor.shape[1]), "constant", pad_token_id) def _ensure_generation_step_progression(self, generation_step): if not self.kv_cache_enabled: return if not hasattr(self, "_previous_generation_step"): self._previous_generation_step = generation_step return if generation_step <= self._previous_generation_step and generation_step != 0: raise ValueError("`generation_step` must increase, or begin from 0.") self._previous_generation_step = generation_step def _call_generate( self, *args, generation_step: int, on_device_generation_model_ctr: Optional[Callable[[torch.nn.Module], torch.nn.Module]] = None, **kwargs, ): self._ensure_generation_step_progression(generation_step) t = self._get_generation_step_tensor(generation_step, ascending=on_device_generation_model_ctr is not None) if not hasattr(self, "poptorch_decoder"): generation_model = self if on_device_generation_model_ctr is not None: generation_model = on_device_generation_model_ctr(self) decoder_wrapper = DecoderWrapper(generation_model.eval()) if os.getenv("DEBUG_RUN_DECODER_ON_CPU", False): self.poptorch_decoder = decoder_wrapper else: decoder_ipu_config = getattr(self, "decoder_ipu_config", self.ipu_config) decoder_options = decoder_ipu_config.to_options(for_inference=True) if self.encoder_output_buffer_enabled: require_version( "poptorch>=3.3", "Updatable encoder output buffer optimization only available in poptorch>=3.3" ) if decoder_ipu_config.inference_replication_factor > 1: raise ValueError("Replication is not supported when `use_encoder_output_buffer=True`.") named_buffers = {} encoder_last_hidden_state = kwargs.pop("encoder_outputs")["last_hidden_state"] if encoder_last_hidden_state is not None: named_buffers["encoder_last_hidden_state"] = encoder_last_hidden_state attention_mask = kwargs.pop("attention_mask", None) if attention_mask is not None: named_buffers["encoder_attention_mask"] = attention_mask.half() if not named_buffers: raise ValueError( "Found `encoder_output_buffer_enabled=True`, but encoder outputs missing when calling the model." ) decoder_wrapper.register_encoder_output_buffers(named_buffers) decoder_options.updatableNamedBuffers(list(named_buffers.keys())) if self.cond_encoder_enabled and decoder_ipu_config.inference_replication_factor > 1: decoder_options.broadcastBuffers(False) self.poptorch_decoder = poptorch.inferenceModel(decoder_wrapper, decoder_options) elif self.encoder_output_buffer_enabled: encoder_last_hidden_state = kwargs.pop("encoder_outputs")["last_hidden_state"] attention_mask = kwargs.pop("attention_mask", None) if generation_step == 0: self.poptorch_decoder.encoder_last_hidden_state.copy_(encoder_last_hidden_state) if attention_mask is not None: self.poptorch_decoder.encoder_attention_mask.copy_(attention_mask.half()) if self.poptorch_decoder.isCompiled() and not self.poptorch_decoder.isAttachedToDevice(): self.poptorch_decoder.attachToDevice() self.poptorch_decoder.copyNamedBuffersToDevice() # This will trigger a compile first time it's ran with graph_profile_dir_append("/decoder" if self.config.is_encoder_decoder else ""): return self.poptorch_decoder(*args, t=t, **kwargs) def _prepare_encoder_decoder_kwargs_for_generation( self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None ) -> Dict[str, Any]: # 1. get encoder encoder = self.get_encoder() # 2. prepare encoder args and encoder kwargs from model kwargs irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] encoder_kwargs = { argument: value for argument, value in model_kwargs.items() if not any(argument.startswith(p) for p in irrelevant_prefix) } # 3. make sure that encoder returns `ModelOutput` model_input_name = model_input_name if model_input_name is not None else self.main_input_name encoder_kwargs["return_dict"] = True encoder_kwargs[model_input_name] = inputs_tensor if self.cond_encoder_enabled: # The encoder and decoder are being run on the same IPU. # We make a simplifying assumption and only provide the inputs to the encoder. model_kwargs[model_input_name] = inputs_tensor # For minimal changes to the generation path, we put dummy encoder outputs here # and drop them before calling the model. model_kwargs["encoder_outputs"] = BaseModelOutput( last_hidden_state=torch.zeros(inputs_tensor.shape[0], 1, dtype=encoder.dtype) ) return model_kwargs if not hasattr(self, "poptorch_encoder"): # Use split encoder ipu_config for encoder/decoder models if os.getenv("DEBUG_RUN_ENCODER_ON_CPU", False): self.poptorch_encoder = encoder.eval() else: self.poptorch_encoder = poptorch.inferenceModel( encoder.eval(), self.encoder_ipu_config.to_options(for_inference=True) ) with graph_profile_dir_append("/encoder"): model_kwargs["encoder_outputs"]: ModelOutput = self.poptorch_encoder(**encoder_kwargs) return model_kwargs @staticmethod def _expand_inputs_for_generation( expand_size: int = 1, is_encoder_decoder: bool = False, input_ids: Optional[torch.LongTensor] = None, **model_kwargs, ) -> Tuple[torch.LongTensor, Dict[str, Any]]: # Change: if we are running the encoder and decoder on the same IPU, `model_kwargs` # will contain `input_features`. We do not want these to be expanded by e.g. num_beams. input_features = model_kwargs.pop("input_features", None) input_ids, model_kwargs = GenerationMixin._expand_inputs_for_generation( expand_size=expand_size, is_encoder_decoder=is_encoder_decoder, input_ids=input_ids, **model_kwargs ) if input_features is not None: model_kwargs["input_features"] = input_features return input_ids, model_kwargs def detachFromDevice(self): if hasattr(self, "poptorch_encoder"): self.poptorch_encoder.detachFromDevice() if hasattr(self, "poptorch_decoder"): self.poptorch_decoder.detachFromDevice() def destroy(self): if hasattr(self, "poptorch_encoder"): self.poptorch_encoder.destroy() delattr(self, "poptorch_encoder") if hasattr(self, "poptorch_decoder"): self.poptorch_decoder.destroy() delattr(self, "poptorch_decoder") def _get_generation_step_tensor(self, generation_step, ascending=False): # Returns a 1 dimensional tensor of the form [device_iterations * replication factor] # with all elements equal to generation_step. # This ensures the dimensions are as expected by any parallelism options. decoder_ipu_config = getattr(self, "decoder_ipu_config", self.ipu_config) per_replica = ( torch.arange(decoder_ipu_config.inference_device_iterations) + generation_step if ascending else torch.ones(decoder_ipu_config.inference_device_iterations) * generation_step ) return per_replica.repeat_interleave(decoder_ipu_config.inference_replication_factor) def _populate_parallelize_kwargs_with_generation_config(self, **kwargs): if self.generation_config is None: return kwargs for kwarg in ["num_beams", "max_length"]: if kwarg not in kwargs: kwarg_value = getattr(self.generation_config, kwarg) kwargs[kwarg] = kwarg_value logger.info(f"Setting parallelize kwarg `{kwarg}` to value in generation_config ({kwarg_value}).") return kwargs def _validate_kv_cache(self, use_cache, num_beams=1, max_length=128): first_call = not hasattr(self, "_poptorch_decoder") if use_cache and self.__class__ not in MODELS_SUPPORTING_KV_CACHE: if first_call: logger.warn( f"{self.__class__} does not support KV caching, but `use_cache=True`. " "Overriding to `use_cache=False`. If your believe your pipelined model " "supports static KV caching, please decorate it using `supports_kv_cache`." ) use_cache = False if not use_cache or not first_call: return use_cache model_has_kv_cache_initialized = any(getattr(m, "kv_cache_initialized", False) for m in self.modules()) if use_cache and not model_has_kv_cache_initialized: raise ValueError( f"{self.__class__.__name__} supports KV caching and `use_cache=True`, but no KV caches have been initialized. " f"Please pass `use_cache=True` to the `parallelize` method of {self.__class__.__name__}." ) self.kv_cache_enabled = use_cache and model_has_kv_cache_initialized if not self.kv_cache_enabled: return use_cache module_with_cache = next(m for m in self.modules() if getattr(m, "kv_cache_initialized", False)) cache_shape = module_with_cache._k_cache.shape cache_num_beams = module_with_cache._num_beams cache_max_length = cache_shape[2] generic_kwarg_msg = ( "KV caches are created with `kwargs` that are directly provided to `parallelize`, or where such " "kwargs are missing, we optionally retrieve values from the `model.generation_config`. " "On the other hand, `model.generate()` will determine generation kwargs in the priority of " "`kwargs` > `kwargs['generation_config']` > `model.generation_config`. " "Mismatches between the two flows can be reconciled by ensuring that the kwargs provided to `parallelize` " "match the `kwargs` and / or `kwargs['generation_config']` passed to `model.generate()`." ) if cache_num_beams != num_beams: raise ValueError( f"KV caches were created with num_beams={cache_num_beams}, but `model.generate()` is being called " f"with {num_beams=}." f"\n{generic_kwarg_msg}" ) if cache_max_length != max_length: raise ValueError( f"KV caches were created with max_length={cache_max_length}, but `model.generate()` is being called " f"with {max_length=}." f"\n{generic_kwarg_msg}" ) return use_cache def change_lm_head_to_indexed_input_linear(self, restore: bool): """Changes the LM head with the faster _IndexedInputLinear layer. Args: restore: whether to restore the LM head to the original version or not. """ if restore: lm_head = self.get_output_embeddings() if lm_head.__class__ == _IndexedInputLinear: self.set_output_embeddings(lm_head.wrapped_linear) else: self.set_output_embeddings(_IndexedInputLinear(self.get_output_embeddings())) # Modified from https://github.com/huggingface/transformers/blob/v4.20.1/src/transformers/generation_utils.py#L1532 def greedy_search( self, input_ids: torch.LongTensor, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, **model_kwargs, ) -> Union[GreedySearchOutput, torch.LongTensor]: r""" Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. logits_processor (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. stopping_criteria (`StoppingCriteriaList`, *optional*): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. max_length (`int`, *optional*, defaults to 20): **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated tokens. The maximum length of the sequence to be generated. pad_token_id (`int`, *optional*): The id of the *padding* token. eos_token_id (`Union[int, List[int]]`, *optional*): The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. output_hidden_states (`bool`, *optional*, defaults to `False`): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. model_kwargs: Additional model specific keyword arguments will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. Return: [`~generation_utils.GreedySearchDecoderOnlyOutput`], [`~generation_utils.GreedySearchEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a [`~generation_utils.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`~generation_utils.GreedySearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. Examples: ```python >>> from transformers import ( ... AutoTokenizer, ... AutoModelForCausalLM, ... LogitsProcessorList, ... MinLengthLogitsProcessor, ... StoppingCriteriaList, ... MaxLengthCriteria, ... ) >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") >>> model = AutoModelForCausalLM.from_pretrained("gpt2") >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token >>> model.config.pad_token_id = model.config.eos_token_id >>> input_prompt = "It might be possible to" >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids >>> # instantiate logits processors >>> logits_processor = LogitsProcessorList( ... [ ... MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id), ... ] ... ) >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) >>> outputs = model.greedy_search( ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria ... ) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ["It might be possible to get a better understanding of the nature of the problem, but it's not"] ```""" # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) else: max_length = stopping_criteria.max_length pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) use_cache = model_kwargs.get("use_cache", False) use_cache = self._validate_kv_cache(use_cache, num_beams=1, max_length=max_length) # Change: intercept to optionally run the entire generation loop on device if self.on_device_generation_steps > 0: return self._on_device_greedy_search( input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria, pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=max_length, return_dict_in_generate=return_dict_in_generate, **model_kwargs, ) # keep track of which sequences are already finished unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) cur_len = input_ids.shape[-1] while True: # Change: remove synced_gpu code # Change: add input max_length padding if not use_cache: input_ids = self._pad_tensors_to_max_len(input_ids, stopping_criteria.max_length, pad_token_id) # Change: For a seq2seq model such as BART, the "attention_mask" is the encoder/cross attention mask and it does not require padding. if not self.config.is_encoder_decoder: model_kwargs["attention_mask"] = self._pad_tensors_to_max_len( model_kwargs["attention_mask"], stopping_criteria.max_length, 0 ) # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # forward pass to get next token outputs = self._call_generate( generation_step=cur_len - 1, **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) # Change: Remove padding and restore to actual length if not use_cache: input_ids = input_ids[:, :cur_len] if not self.config.is_encoder_decoder: model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :cur_len] # Change: remove synced_gpu code next_token_logits = outputs.logits[:, -1, :] # pre-process distribution next_tokens_scores = logits_processor(input_ids, next_token_logits) # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: scores += (next_tokens_scores,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # argmax next_tokens = torch.argmax(next_tokens_scores, dim=-1) # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) cur_len = cur_len + 1 # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) ) # stop when each sentence is finished, or if we exceed the maximum length if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): # Change: remove synced_gpu code break # End of while True if return_dict_in_generate: if self.config.is_encoder_decoder: return GreedySearchEncoderDecoderOutput( sequences=input_ids, scores=scores, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, ) else: return GreedySearchDecoderOnlyOutput( sequences=input_ids, scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) else: return input_ids def beam_search( self, input_ids: torch.LongTensor, beam_scorer: BeamScorer, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, **model_kwargs, ) -> Union[BeamSearchOutput, torch.LongTensor]: r""" Generates sequences of token ids for models with a language modeling head using **beam search decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. beam_scorer (`BeamScorer`): An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. logits_processor (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. stopping_criteria (`StoppingCriteriaList`, *optional*): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. max_length (`int`, *optional*, defaults to 20): **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated tokens. The maximum length of the sequence to be generated. pad_token_id (`int`, *optional*): The id of the *padding* token. eos_token_id (`Union[int, List[int]]`, *optional*): The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. output_hidden_states (`bool`, *optional*, defaults to `False`): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. Return: [`generation_utilsBeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. Examples: ```python >>> from transformers import ( ... AutoTokenizer, ... AutoModelForSeq2SeqLM, ... LogitsProcessorList, ... MinLengthLogitsProcessor, ... BeamSearchScorer, ... ) >>> import torch >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") >>> encoder_input_str = "translate English to German: How old are you?" >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids >>> # lets run beam search using 3 beams >>> num_beams = 3 >>> # define decoder start token ids >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) >>> input_ids = input_ids * model.config.decoder_start_token_id >>> # add encoder_outputs to model keyword arguments >>> model_kwargs = { ... "encoder_outputs": model.get_encoder()( ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True ... ) ... } >>> # instantiate beam scorer >>> beam_scorer = BeamSearchScorer( ... batch_size=1, ... num_beams=num_beams, ... device=model.device, ... ) >>> # instantiate logits processors >>> logits_processor = LogitsProcessorList( ... [ ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), ... ] ... ) >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Wie alt bist du?'] ```""" # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) else: max_length = stopping_criteria.max_length if len(stopping_criteria) == 0: warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.generation_config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape if num_beams * batch_size != batch_beam_size: raise ValueError( f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None beam_indices = ( tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None ) decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) use_cache = model_kwargs.get("use_cache", False) use_cache = self._validate_kv_cache(use_cache, num_beams=num_beams, max_length=max_length) # Change: intercept to optionally run the entire generation loop on device if self.on_device_generation_steps > 0: return self._on_device_beam_search( input_ids, beam_scorer=beam_scorer, logits_processor=logits_processor, stopping_criteria=stopping_criteria, pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=max_length, return_dict_in_generate=return_dict_in_generate, **model_kwargs, ) beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) beam_scores[:, 1:] = -1e9 beam_scores = beam_scores.view((batch_size * num_beams,)) while True: # Change: remove synced_gpu code # Change: add input max_length padding if not use_cache: input_ids = self._pad_tensors_to_max_len(input_ids, stopping_criteria.max_length, pad_token_id) # Change: For a seq2seq model such as BART, the "attention_mask" is the encoder/cross attention mask and it does not require padding. if not self.config.is_encoder_decoder: model_kwargs["attention_mask"] = self._pad_tensors_to_max_len( model_kwargs["attention_mask"], stopping_criteria.max_length, 0 ) model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) outputs = self._call_generate( generation_step=cur_len - 1, **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) # Change: Remove padding and restore to actual length if not use_cache: input_ids = input_ids[:, :cur_len] if not self.config.is_encoder_decoder: model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :cur_len] # Change: remove synced_gpu code # Change: cast to float on cpu next_token_logits = outputs.logits[:, -1, :].float() # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` # cannot be generated both before and after the `nn.functional.log_softmax` operation. next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 ) # (batch_size * num_beams, vocab_size) next_token_scores_processed = logits_processor(input_ids, next_token_scores) next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: scores += (next_token_scores_processed,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # reshape for beam search vocab_size = next_token_scores.shape[-1] next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) next_token_scores, next_tokens = torch.topk( next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True ) next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") next_tokens = next_tokens % vocab_size # stateless beam_outputs = beam_scorer.process( input_ids, next_token_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, beam_indices=beam_indices, ) beam_scores = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) if model_kwargs["past_key_values"] is not None: model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) # Change: add beam_idx to model_kwargs so KV caching can be made aware of it on device model_kwargs["beam_idx"] = beam_idx if return_dict_in_generate and output_scores: beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) # increase cur_len cur_len = cur_len + 1 if beam_scorer.is_done or stopping_criteria(input_ids, scores): break sequence_outputs = beam_scorer.finalize( input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, beam_indices=beam_indices, ) if return_dict_in_generate: if not output_scores: sequence_outputs["sequence_scores"] = None if self.config.is_encoder_decoder: return BeamSearchEncoderDecoderOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, beam_indices=sequence_outputs["beam_indices"], encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, ) else: return BeamSearchDecoderOnlyOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, beam_indices=sequence_outputs["beam_indices"], attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) else: return sequence_outputs["sequences"] def sample( self, input_ids: torch.LongTensor, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, **model_kwargs, ) -> Union[SampleOutput, torch.LongTensor]: r""" Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. logits_processor (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. stopping_criteria (`StoppingCriteriaList`, *optional*): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. logits_warper (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used to warp the prediction score distribution of the language modeling head applied before multinomial sampling at each generation step. max_length (`int`, *optional*, defaults to 20): **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated tokens. The maximum length of the sequence to be generated. pad_token_id (`int`, *optional*): The id of the *padding* token. eos_token_id (`Union[int, List[int]]`, *optional*): The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. output_hidden_states (`bool`, *optional*, defaults to `False`): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. Return: [`~generation_utils.SampleDecoderOnlyOutput`], [`~generation_utils.SampleEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a [`~generation_utils.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`~generation_utils.SampleEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. Examples: ```python >>> from transformers import ( ... AutoTokenizer, ... AutoModelForCausalLM, ... LogitsProcessorList, ... MinLengthLogitsProcessor, ... TopKLogitsWarper, ... TemperatureLogitsWarper, ... StoppingCriteriaList, ... MaxLengthCriteria, ... ) >>> import torch >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") >>> model = AutoModelForCausalLM.from_pretrained("gpt2") >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token >>> model.config.pad_token_id = model.config.eos_token_id >>> input_prompt = "Today is a beautiful day, and" >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids >>> # instantiate logits processors >>> logits_processor = LogitsProcessorList( ... [ ... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id), ... ] ... ) >>> # instantiate logits processors >>> logits_warper = LogitsProcessorList( ... [ ... TopKLogitsWarper(50), ... TemperatureLogitsWarper(0.7), ... ] ... ) >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT >>> outputs = model.sample( ... input_ids, ... logits_processor=logits_processor, ... logits_warper=logits_warper, ... stopping_criteria=stopping_criteria, ... ) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] ```""" # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) else: max_length = stopping_criteria.max_length logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) use_cache = model_kwargs.get("use_cache", False) use_cache = self._validate_kv_cache(use_cache, num_beams=1, max_length=max_length) # Change: intercept to optionally run the entire generation loop on device if self.on_device_generation_steps > 0: return self._on_device_sample() # keep track of which sequences are already finished unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) cur_len = input_ids.shape[-1] # auto-regressive generation while True: # Change: remove synced_gpu code # Change: add input max_length padding if not use_cache: input_ids = self._pad_tensors_to_max_len(input_ids, stopping_criteria.max_length, pad_token_id) # Change: For a seq2seq model such as BART, the "attention_mask" is the encoder/cross attention mask and it does not require padding. if not self.config.is_encoder_decoder: model_kwargs["attention_mask"] = self._pad_tensors_to_max_len( model_kwargs["attention_mask"], stopping_criteria.max_length, 0 ) # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # forward pass to get next token outputs = self._call_generate( generation_step=cur_len - 1, **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) # Change: Remove padding and restore to actual length if not use_cache: input_ids = input_ids[:, :cur_len] if not self.config.is_encoder_decoder: model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :cur_len] # Change: remove synced_gpu code # Change: cast to float on cpu next_token_logits = outputs.logits[:, -1, :].float() # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) next_token_scores = logits_warper(input_ids, next_token_scores) # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: scores += (next_token_scores,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # sample probs = nn.functional.softmax(next_token_scores.float(), dim=-1) next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) cur_len = cur_len + 1 # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) ) # stop when each sentence is finished, or if we exceed the maximum length if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): break if return_dict_in_generate: if self.config.is_encoder_decoder: return SampleEncoderDecoderOutput( sequences=input_ids, scores=scores, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, ) else: return SampleDecoderOnlyOutput( sequences=input_ids, scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) else: return input_ids def beam_sample( self, input_ids: torch.LongTensor, beam_scorer: BeamScorer, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, **model_kwargs, ) -> Union[BeamSampleOutput, torch.LongTensor]: r""" Generates sequences of token ids for models with a language modeling head using **beam search multinomial sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. beam_scorer (`BeamScorer`): A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. logits_processor (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. stopping_criteria (`StoppingCriteriaList`, *optional*): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. logits_warper (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used to warp the prediction score distribution of the language modeling head applied before multinomial sampling at each generation step. max_length (`int`, *optional*, defaults to 20): **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated tokens. The maximum length of the sequence to be generated. pad_token_id (`int`, *optional*): The id of the *padding* token. eos_token_id (`Union[int, List[int]]`, *optional*): The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. output_hidden_states (`bool`, *optional*, defaults to `False`): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. Return: [`~generation_utils.BeamSampleDecoderOnlyOutput`], [`~generation_utils.BeamSampleEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a [`~generation_utils.BeamSampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`~generation_utils.BeamSampleEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. Examples: ```python >>> from transformers import ( ... AutoTokenizer, ... AutoModelForSeq2SeqLM, ... LogitsProcessorList, ... MinLengthLogitsProcessor, ... TopKLogitsWarper, ... TemperatureLogitsWarper, ... BeamSearchScorer, ... ) >>> import torch >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") >>> encoder_input_str = "translate English to German: How old are you?" >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids >>> # lets run beam search using 3 beams >>> num_beams = 3 >>> # define decoder start token ids >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) >>> input_ids = input_ids * model.config.decoder_start_token_id >>> # add encoder_outputs to model keyword arguments >>> model_kwargs = { ... "encoder_outputs": model.get_encoder()( ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True ... ) ... } >>> # instantiate beam scorer >>> beam_scorer = BeamSearchScorer( ... batch_size=1, ... max_length=model.config.max_length, ... num_beams=num_beams, ... device=model.device, ... ) >>> # instantiate logits processors >>> logits_processor = LogitsProcessorList( ... [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)] ... ) >>> # instantiate logits processors >>> logits_warper = LogitsProcessorList( ... [ ... TopKLogitsWarper(50), ... TemperatureLogitsWarper(0.7), ... ] ... ) >>> outputs = model.beam_sample( ... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs ... ) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Wie alt bist du?'] ```""" # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) else: max_length = stopping_criteria.max_length pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.generation_config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None beam_indices = ( tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None ) decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) use_cache = model_kwargs.get("use_cache", False) use_cache = self._validate_kv_cache(use_cache, num_beams=num_beams, max_length=max_length) # Change: intercept to optionally run the entire generation loop on device if self.on_device_generation_steps > 0: return self._on_device_beam_sample() beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) beam_scores = beam_scores.view((batch_size * num_beams,)) while True: # Change: remove synced_gpu code # Change: add input max_length padding if not use_cache: input_ids = self._pad_tensors_to_max_len(input_ids, stopping_criteria.max_length, pad_token_id) # Change: For a seq2seq model such as BART, the "attention_mask" is the encoder/cross attention mask and it does not require padding. if not self.config.is_encoder_decoder: model_kwargs["attention_mask"] = self._pad_tensors_to_max_len( model_kwargs["attention_mask"], stopping_criteria.max_length, 0 ) model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) outputs = self._call_generate( generation_step=cur_len - 1, **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) # Change: Remove padding and restore to actual length if not use_cache: input_ids = input_ids[:, :cur_len] if not self.config.is_encoder_decoder: model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :cur_len] # Change: remove synced_gpu code # Change: cast to float on cpu next_token_logits = outputs.logits[:, -1, :].float() # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` # cannot be generated both before and after the `nn.functional.log_softmax` operation. next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 ) # (batch_size * num_beams, vocab_size) next_token_scores_processed = logits_processor(input_ids, next_token_scores) next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) next_token_scores = logits_warper(input_ids, next_token_scores) # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: scores += (logits_warper(input_ids, next_token_scores_processed),) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # reshape for beam search vocab_size = next_token_scores.shape[-1] next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) probs = nn.functional.softmax(next_token_scores, dim=-1) next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) next_token_scores = torch.gather(next_token_scores, -1, next_tokens) next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) next_tokens = torch.gather(next_tokens, -1, _indices) next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") next_tokens = next_tokens % vocab_size # stateless beam_outputs = beam_scorer.process( input_ids, next_token_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, beam_indices=beam_indices, ) beam_scores = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) if model_kwargs["past_key_values"] is not None: model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) # Change: add beam_idx to model_kwargs so KV caching can be made aware of it on device model_kwargs["beam_idx"] = beam_idx if return_dict_in_generate and output_scores: beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) # increase cur_len cur_len = cur_len + 1 if beam_scorer.is_done or stopping_criteria(input_ids, scores): break sequence_outputs = beam_scorer.finalize( input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, beam_indices=beam_indices, ) if return_dict_in_generate: if not output_scores: sequence_outputs["sequence_scores"] = None if self.config.is_encoder_decoder: return BeamSampleEncoderDecoderOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, beam_indices=sequence_outputs["beam_indices"], encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, ) else: return BeamSampleDecoderOnlyOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, beam_indices=sequence_outputs["beam_indices"], attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) else: return sequence_outputs["sequences"] on_device_generation_steps: int = 0 def set_on_device_generation_steps(self, value: Optional[int] = 0): self.on_device_generation_steps = value if value == 0: del self.on_device_generation_steps def _adapt_logits_processor_for_on_device_generation(self, logits_processor: LogitsProcessorList, vocab_size: int): adapted_processors = [] for processor in logits_processor: ipu_processor_cls = IPULogitsProcessors.get(processor.__class__, None) if ipu_processor_cls is None: raise NotImplementedError(f"{processor.__class__.__name__} is not supported yet to run on IPU.") try: ipu_processor = ipu_processor_cls.from_model(processor, vocab_size) except AttributeError: ipu_processor = copy.deepcopy(processor) ipu_processor.__class__ = ipu_processor_cls adapted_processors.append(ipu_processor) return LogitsProcessorList(adapted_processors) def _adapt_stopping_criteria_for_on_device_generation( self, stopping_criteria: StoppingCriteriaList, on_device_generation_steps: int ): adapted_stopping_criteria = [] for stopping_criterion in stopping_criteria: if hasattr(stopping_criterion, "max_length"): max_length = stopping_criterion.max_length new_max_length = max_length - on_device_generation_steps logger.debug( f"Temporarily adapting `max_length` from {max_length} to {new_max_length} for on device generation." ) stopping_criterion = copy.deepcopy(stopping_criterion) stopping_criterion.max_length = new_max_length adapted_stopping_criteria.append(stopping_criterion) return StoppingCriteriaList(adapted_stopping_criteria) def _prepare_inputs_for_on_device_generation(self, model_inputs, on_device_generation_steps, batch_size): """ A model-agnostic version of `prepare_inputs_for_generation` whose main purpose is to duplicate decoder inputs by `on_device_generation_steps=inference_device_iterations` and perform additional input validation. Since we are duplicating tensors, we restrict duplication to `torch.Tensor` and the exceptional case of `encoder_outputs.last_hidden_state`. """ adapted_model_inputs = {} for k, v in model_inputs.items(): if k in ("attention_mask", "encoder_outputs") and self.encoder_output_buffer_enabled: # These inputs will copied onto device via buffers, so we don't need to duplicate them. adapted_model_inputs[k] = v continue if k == "beam_idx": # With on-device generation, beam_idx at each step is handled through buffers. continue if k == "input_features" and self.cond_encoder_enabled: v = v.repeat(on_device_generation_steps, *(1 for _ in range(v.ndim - 1))) elif torch.is_tensor(v): if v.shape[0] != batch_size: raise ValueError(f"Unexpected size in dim 0 for {k}, expected {batch_size}.") v = v.repeat(on_device_generation_steps, *(1 for _ in range(v.ndim - 1))) elif k == "encoder_outputs": v_type = type(v) if not isinstance(v, BaseModelOutput): raise ValueError( "Expected `encoder_outputs` to be an instance of `BaseModelOutput`, " f"received {v_type}." ) v = v.last_hidden_state v = v.repeat(on_device_generation_steps, *(1 for _ in range(v.ndim - 1))) v = v_type(last_hidden_state=v) elif v is None: pass elif isinstance(v, (int, float, str, bool)): pass else: raise TypeError( f"Unexpected type {type(v)} received for decoder input {k}. On device generation enforces " "stricter input validation to minimise unexpected errors. Improvements are always welcome." ) adapted_model_inputs[k] = v return adapted_model_inputs def _on_device_greedy_search( self, input_ids: torch.Tensor, logits_processor: LogitsProcessorList, stopping_criteria: StoppingCriteriaList, pad_token_id: int, eos_token_id: int, max_length: int, return_dict_in_generate: Optional[bool] = False, **model_kwargs, ): if not model_kwargs.get("use_cache", False): raise NotImplementedError("On device greedy search assumes `use_cache=True`.") if return_dict_in_generate: raise NotImplementedError("On device greedy search assumes `return_dict_in_generate=False`.") batch_size, context_length = input_ids.shape vocab_size = self.get_output_embeddings().out_features if context_length > 1: raise ValueError("Context length (input_ids.shape[-1]) > 1 is not supported yet.") if (max_length - context_length) % self.on_device_generation_steps != 0: logger.debug( "`max_length - context_length` does not evenly divide `on_device_generation_steps` " f"({max_length - context_length} vs {self.on_device_generation_steps}). Generation will be done " f"{self.on_device_generation_steps} tokens at a time and stop short of `max_length` so as not to exceed it." ) decoder_ipu_config = getattr(self, "decoder_ipu_config", self.ipu_config) if decoder_ipu_config.inference_device_iterations not in (1, self.on_device_generation_steps): raise ValueError( "On device generation expects `inference_device_iterations=1` or " "`inference_device_iterations=on_device_generation_steps`, " f"received {self.ipu_config.inference_device_iterations}. " "For on device generation, `inference_device_iterations` will be set to " f"`on_device_generation_steps={self.on_device_generation_steps}`." ) if hasattr(self, "decoder_ipu_config"): self.decoder_ipu_config.inference_device_iterations = self.on_device_generation_steps else: self.ipu_config.inference_device_iterations = self.on_device_generation_steps logits_processor = self._adapt_logits_processor_for_on_device_generation(logits_processor, vocab_size) stopping_criteria = self._adapt_stopping_criteria_for_on_device_generation( stopping_criteria, self.on_device_generation_steps ) # This function only has to be called at the beginning of generation since # all necessary state should be stored in buffers on device. model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # A model-agnostic version of above mainly to duplicate inputs for device iterations. model_inputs = self._prepare_inputs_for_on_device_generation( model_inputs, self.on_device_generation_steps, batch_size ) per_replica_batch_size = batch_size // decoder_ipu_config.inference_replication_factor def on_device_generation_model_ctr(inst): return OnDeviceGenerationModel( inst, batch_size=per_replica_batch_size, max_length=max_length, pad_token_id=pad_token_id, eos_token_id=eos_token_id, logits_processor=logits_processor, num_beams=1, use_cache=True, ) generation_step = 0 while True: output = self._call_generate( generation_step=generation_step, # NB: equal to `cur_len - 1` since context_length=1 on_device_generation_model_ctr=on_device_generation_model_ctr, **model_inputs, return_dict=True, output_attentions=False, output_hidden_states=False, ) next_tokens = output.generated_tokens.view(self.on_device_generation_steps, batch_size).T done = torch.all( output.done.view(self.on_device_generation_steps, decoder_ipu_config.inference_replication_factor), dim=-1, ) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens], dim=-1) generation_step += self.on_device_generation_steps # stop when each sentence is finished, or if we exceed the maximum length if torch.any(done) or stopping_criteria(input_ids, ()): break return input_ids def _on_device_beam_search( self, input_ids: torch.Tensor, beam_scorer: BeamScorer, logits_processor: LogitsProcessorList, stopping_criteria: StoppingCriteriaList, pad_token_id: int, eos_token_id: int, max_length: int, return_dict_in_generate: Optional[bool] = False, **model_kwargs, ): if not model_kwargs.get("use_cache", False): raise NotImplementedError("On device beam search assumes `use_cache=True`.") if return_dict_in_generate: raise NotImplementedError("On device beam search assumes `return_dict_in_generate=False`.") batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams batch_beam_size, context_length = input_ids.shape vocab_size = self.get_output_embeddings().out_features if context_length > 1: raise ValueError("Context length (input_ids.shape[-1]) > 1 is not supported yet.") if (max_length - context_length) % self.on_device_generation_steps != 0: logger.debug( "`max_length - context_length` does not evenly divide `on_device_generation_steps` " f"({max_length - context_length} vs {self.on_device_generation_steps}). Generation will be done " f"{self.on_device_generation_steps} tokens at a time and stop short of `max_length` so as not to exceed it." ) decoder_ipu_config = getattr(self, "decoder_ipu_config", self.ipu_config) if decoder_ipu_config.inference_device_iterations not in (1, self.on_device_generation_steps): raise ValueError( "On device generation expects `inference_device_iterations=1` or " "`inference_device_iterations=on_device_generation_steps`, " f"received {self.ipu_config.inference_device_iterations}. " "For on device generation, `inference_device_iterations` will be set to " f"`on_device_generation_steps={self.on_device_generation_steps}`." ) if hasattr(self, "decoder_ipu_config"): self.decoder_ipu_config.inference_device_iterations = self.on_device_generation_steps else: self.ipu_config.inference_device_iterations = self.on_device_generation_steps logits_processor = self._adapt_logits_processor_for_on_device_generation(logits_processor, vocab_size) stopping_criteria = self._adapt_stopping_criteria_for_on_device_generation( stopping_criteria, self.on_device_generation_steps ) # This function only has to be called at the beginning of generation since # all necessary state should be stored in buffers on device. model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # A model-agnostic version of above mainly to duplicate inputs for device iterations. model_inputs = self._prepare_inputs_for_on_device_generation( model_inputs, self.on_device_generation_steps, batch_beam_size ) per_replica_batch_size = batch_size // decoder_ipu_config.inference_replication_factor def on_device_generation_model_ctr(inst): return OnDeviceGenerationModel( inst, batch_size=per_replica_batch_size, max_length=max_length, pad_token_id=pad_token_id, eos_token_id=eos_token_id, logits_processor=logits_processor, num_beams=num_beams, use_cache=True, length_penalty=beam_scorer.length_penalty, early_stopping=beam_scorer.do_early_stopping, ) generation_step = 0 while True: output = self._call_generate( generation_step=generation_step, # NB: equal to `cur_len - 1` since context_length=1 on_device_generation_model_ctr=on_device_generation_model_ctr, **model_inputs, return_dict=True, output_attentions=False, output_hidden_states=False, ) done = torch.all( output.done.view(self.on_device_generation_steps, decoder_ipu_config.inference_replication_factor), dim=-1, ) generation_step += self.on_device_generation_steps first_done = torch.argmax(done.int()) input_ids = output.generated_tokens[ first_done * batch_size : (first_done + 1) * batch_size, : context_length + generation_step ].to(input_ids.dtype) if torch.any(done) or stopping_criteria(input_ids, ()): break return input_ids def _on_device_sample(self): raise NotImplementedError("On device sampling is not supported.") def _on_device_beam_sample(self): raise NotImplementedError("On device beam sampling is not supported.")