optimum/neuron/generation/logits_process.py:

from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from transformers import GenerationConfig


@dataclass
class FusedLogitsWarper:
    """
    A class that performs top-k then top-p filtering, optionally applying a temperature.

    Top-k filtering only keeps the `k` tokens with the best scores.

    Top-p filtering only keeps the top tokens whose cumulative probability is above `p`.

    The filtered tokens are returned as a list of indices, along with the corresponding subset of
    the original logits.

    If only top-k filtering is active, the filtered tokens are sorted in descending order.

    If top-p filtering is active, the filtered tokens are sorted in ascending order.

    Args:
        temperature (`float`):
            Strictly positive float value used to modulate the logits distribution. A value smaller than `1` decreases
            randomness (and vice versa), with `0` being equivalent to shifting all probability mass to the most likely
            token.
        top_k (`int`):
            The number of highest probability vocabulary tokens to keep for top-k filtering.
        top_p (`float`):
            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
            higher are kept for generation.
    """

    temperature: float = 1.0
    top_k: int = 0
    top_p: float = 1.0

    @classmethod
    def from_config(cls, generation_config: GenerationConfig) -> "FusedLogitsWarper":
        """Instantiate a fused warper from a generation configuration.

        Args:
            generation_config (`~transformers.generation.GenerationConfig`):
                The generation configuration to be used as base parametrization for the fused warper.

        Returns:
            a `FusedLogitsWarper`.

        Raises:
            ValueError: if `do_sample` is set but neither top-k nor top-p is configured.
        """
        if generation_config.do_sample and generation_config.top_k == 0 and generation_config.top_p == 1.0:
            raise ValueError("Multinomial sampling requires at least top-k or top-p to be specified.")
        return cls(generation_config.temperature, generation_config.top_k, generation_config.top_p)

    def __call__(self, logits: torch.FloatTensor) -> Tuple[torch.FloatTensor, Optional[torch.LongTensor]]:
        if self.temperature != 1.0:
            logits = logits / self.temperature

        do_top_k = self.top_k > 0 and self.top_k < logits.shape[-1]
        do_top_p = self.top_p < 1.0 and self.top_p > 0.0

        if not do_top_k and not do_top_p:
            return logits, None

        if do_top_k:
            sorted_logits, sorted_indices = torch.topk(logits, self.top_k)
        else:
            # Warning: not applying top-k filtering leads to this very slow sort operation
            sorted_logits, sorted_indices = torch.sort(logits)

        if do_top_p:
            if do_top_k:
                # Logits have been sorted in descending order, so we need to flip them
                sorted_logits = torch.flip(sorted_logits, [-1])
                sorted_indices = torch.flip(sorted_indices, [-1])
            # We always keep the best logits and those whose cumulative probability is strictly higher than 1 - top_p
            cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
            keep_mask = cum_probs > (1 - self.top_p)
            keep_mask[:, -1] = True
            # Set rejected logits to -inf so that they are ignored in downstream comparisons
            sorted_logits[~keep_mask] = float("-Inf")
            # Clip the sorted logits and indices to the maximum number of kept tokens in the batch
            # to speed up downstream ops
            keep_by_batch = torch.sum(keep_mask, dim=-1)
            keep = torch.amax(keep_by_batch)
            sorted_logits = sorted_logits[:, -keep:]
            sorted_indices = sorted_indices[:, -keep:]

        return sorted_logits, sorted_indices
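

# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal example, assuming the caller is responsible for sampling: the warper
# returns the filtered logits together with their original vocabulary indices, so a
# position sampled from the filtered distribution must be mapped back through those
# indices. The batch size, vocabulary size and sampling parameters below are
# arbitrary placeholders, not values taken from the repository.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, vocab_size = 2, 32
    logits = torch.randn(batch_size, vocab_size)

    warper = FusedLogitsWarper(temperature=0.8, top_k=10, top_p=0.9)
    filtered_logits, filtered_indices = warper(logits)

    probs = filtered_logits.softmax(dim=-1)  # rejected (-inf) logits get zero probability
    if filtered_indices is None:
        # No filtering was applied: sample directly over the full vocabulary.
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
    else:
        # Sample a position in the filtered set, then map it back to a vocabulary id.
        positions = torch.multinomial(probs, num_samples=1)
        next_tokens = torch.gather(filtered_indices, 1, positions).squeeze(-1)
    print(next_tokens)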