optimum/neuron/generation/token_selector.py
import copy
import logging
from typing import TYPE_CHECKING, List, Optional

import torch
from transformers.generation import (
    GenerationConfig,
    GenerationMixin,
    LogitsProcessorList,
    StoppingCriteriaList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)
from transformers.generation.utils import GenerationMode

from .logits_process import FusedLogitsWarper


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer


logger = logging.getLogger(__name__)


class TokenSelector:
    """Implements the token selection logic corresponding to a generation configuration.

    This class combines and uses the logits processors and stopping criteria implemented in
    the `transformers` library.

    The algorithm to select these objects is heavily inspired by the `transformers`
    `GenerationMixin.generate()` method, but the actual token selection methods are specific.

    The reason why this class does not inherit from `GenerationMixin` is that it does not
    include the code to produce the token logits.
    Separating the production of the token logits from the token selection allows this class
    to be used with different generation paradigms, either synchronously, using a single `TokenSelector`
    in `GenerationMixin.generate()`, or asynchronously, using multiple `TokenSelector` instances inside
    an inference endpoint.

    The constructor of this class should not be called directly: instances should be obtained by
    calling `TokenSelector.create()`.
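
    Example (an illustrative sketch, not an excerpt from the library's tests: `model` stands for any
    `GenerationMixin` able to produce next-token logits, e.g. a Neuron-compiled decoder, and the
    tensor shapes are assumptions, not prescriptions):

    ```python
    selector = TokenSelector.create(input_ids, generation_config, model, max_seq_length=2048)
    logits = model(input_ids).logits[:, -1, :]  # next-token logits of shape (batch_size, vocab_size)
    next_tokens = selector.select(input_ids, logits)
    ```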
"""

    def __init__(
        self,
        mode: GenerationMode,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        eos_token_ids: List[int],
        pad_token_id: int,
        logits_warper: Optional[FusedLogitsWarper] = None,
        seed: Optional[int] = 0,
    ):
        self.mode = mode
        self.logits_processor = logits_processor
        self.stopping_criteria = stopping_criteria
        self.eos_token_ids = eos_token_ids
        self.pad_token_id = pad_token_id
        self.logits_warper = logits_warper
        self.generator = torch.Generator()
        self.generator.manual_seed(seed)

    @classmethod
    def create(
        cls,
        input_ids: torch.Tensor,
        generation_config: GenerationConfig,
        model: GenerationMixin,
        max_seq_length: int,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        tokenizer: Optional["PreTrainedTokenizer"] = None,
        seed: Optional[int] = 0,
    ) -> "TokenSelector":
r"""Creates the `TokenSelector` for a specific generation configuration.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
generation_config (`~transformers.generation.GenerationConfig`, *optional*):
The generation configuration to parametrize the token selection.
model (`~transformers.generation.GenerationMixin`):
The model provides the internal helpers allowing to select the logits processors and stopping criterias.
max_seq_length (`int`):
The maximum number of input + generated tokens for this model. It depends on the model compilation parameters.
stopping_criteria (`Optional[transformers.generation.StoppingCriteriaList], defaults to `None`):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config
tokenizer (`Optional[transformers.PreTrainedTokenizer]`, default to `None`):
A tokenizer used when stop strings are passed to generate.
seed(`Optional[int]`):
The optional seed for sampling. Defaults to zero.
Return:
`torch.LongTensor`: A `torch.LongTensor` containing the selected tokens.
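
        Example (an illustrative sketch; the sampling parameters below are arbitrary placeholder
        values, not defaults of this library):

        ```python
        selector = TokenSelector.create(
            input_ids=input_ids,
            generation_config=GenerationConfig(max_new_tokens=20, do_sample=True, top_k=50),
            model=model,
            max_seq_length=2048,
        )
        ```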
"""
        generation_config.validate()
        generation_config = copy.deepcopy(generation_config)
        model._prepare_special_tokens(generation_config)
        unsupported_generation_flags = [
            "output_attentions",
            "output_hidden_states",
            "output_scores",
            "return_dict_in_generate",
        ]
        for flag in unsupported_generation_flags:
            if getattr(generation_config, flag, False):
                raise ValueError(f"{flag} is not supported for generation.")

        if generation_config.max_new_tokens is not None:
            logger.warning(
                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length` (="
                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                "Please refer to the documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
            )
            generation_config.max_length = generation_config.max_new_tokens + input_ids.shape[-1]
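            # Illustrative arithmetic (the values are assumptions, not defaults): a 16-token prompt
            # with max_new_tokens=20 results in max_length == 16 + 20 == 36.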

        min_length = generation_config.min_length
        if min_length > max_seq_length:
            raise ValueError(
                f"The minimum generation length ({min_length}) exceeds the model maximum sequence length ({max_seq_length})"
            )
        max_length = generation_config.max_length
        if max_length > max_seq_length:
            logger.warning(
                f"Adjusting the maximum generation length ({max_length}) to the model maximum sequence length ({max_seq_length})"
            )
            generation_config.max_length = max_seq_length

        # Instantiate transformers library processors and criteria
        logits_processor = model._get_logits_processor(
            generation_config,
            input_ids_seq_length=input_ids.shape[-1],
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=None,
            logits_processor=LogitsProcessorList(),
        )
        if stopping_criteria is None:
            stopping_criteria = StoppingCriteriaList()
        stopping_criteria = model._get_stopping_criteria(
            generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer
        )

        # The generation requires special tokens
        eos_token_id = generation_config.eos_token_id
        # This is not supposed to happen for any of the models we support
        assert eos_token_id is not None
        eos_token_ids = eos_token_id if isinstance(eos_token_id, list) else [eos_token_id]
        if generation_config.pad_token_id is None:
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_ids[0]} for open-ended generation.")
            generation_config.pad_token_id = eos_token_ids[0]

        generation_mode = generation_config.get_generation_mode()
        if generation_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE]:
            raise ValueError(f"Unsupported generation mode: {generation_mode}. Only greedy search and sampling are supported.")

        logits_warper = None
        if generation_mode == GenerationMode.SAMPLE:
            # Remove the transformers TopK, TopP and Temperature processors
            logits_processor = LogitsProcessorList(
                [
                    p
                    for p in logits_processor
                    if not isinstance(p, (TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper))
                ]
            )
            # We use a fused logits warper instead
            logits_warper = FusedLogitsWarper.from_config(generation_config)
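            # Note: as used by `_sample()` below, the fused warper is expected to return both the
            # filtered scores and the vocabulary indices they correspond to, so that sampled
            # positions can be mapped back to actual token ids.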

        return cls(
            mode=generation_mode,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            logits_warper=logits_warper,
            eos_token_ids=eos_token_ids,
            pad_token_id=generation_config.pad_token_id,
            seed=seed,
        )

    def select(self, input_ids: torch.LongTensor, logits: torch.Tensor) -> torch.LongTensor:
        """Select the next tokens from the candidate logits.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation (not used in all generation modes).
            logits (`torch.Tensor` of shape `(batch_size, vocab_size)`):
                The next-token logits returned by the model.

        Return:
            `torch.LongTensor`: A `torch.LongTensor` containing the selected tokens.
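
        Example (an illustrative decoding step; `selector` and `logits` are assumed to come from
        `TokenSelector.create()` and from the model forward pass, respectively):

        ```python
        next_tokens = selector.select(input_ids, logits)
        input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
        ```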
"""
        scores = self.logits_processor(input_ids, logits)
        if self.mode == GenerationMode.SAMPLE:
            return self._sample(scores)
        else:
            return torch.argmax(scores, dim=-1)

    def _sample(self, scores: torch.Tensor) -> torch.LongTensor:
        # Get [batch_size, kept] scores and indices instead of [batch_size, vocab_size] scores
        scores, next_token_indices = self.logits_warper(scores)
        # Sample from the filtered distribution
        probs = torch.nn.functional.softmax(scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
        # Convert the filtered tokens back to actual vocabulary tokens
        next_tokens = torch.gather(next_token_indices, 1, next_tokens)
        return next_tokens.squeeze(1)