Excerpt: the `__call__` method from optimum/tpu/generation/logits_process.py

    def __call__(self, logits: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """Apply temperature scaling and top-k / top-p filtering to logits.

        Args:
            logits: `[batch_size, vocab_size]` raw scores.

        Returns:
            A `(filtered_logits, filtered_indices)` pair: the surviving logits
            (rejected entries set to `-inf`) and their original vocabulary
            indices. Ordering is descending when only top-k applies, ascending
            when top-p filtering ran; the last dimension is clipped to the
            widest per-batch keep count when top-p is active.
        """
        if self.temperature != 1.0:
            logits = logits / self.temperature

        vocab_size = logits.shape[-1]
        use_top_k = 0 < self.top_k < vocab_size
        use_top_p = 0.0 < self.top_p < 1.0

        if use_top_k:
            # topk returns values in descending order
            filtered_logits, filtered_indices = torch.topk(logits, self.top_k)
        else:
            # Warning: not applying top-k filtering leads to this very slow sort operation
            filtered_logits, filtered_indices = torch.sort(logits)

        if use_top_p:
            if use_top_k:
                # topk sorted in descending order; flip to ascending so the
                # cumulative-probability pass below works the same either way
                filtered_logits = torch.flip(filtered_logits, [-1])
                filtered_indices = torch.flip(filtered_indices, [-1])
            # We always keep the best logits and those whose cumulative probability is strictly higher than top_p
            cumulative_probs = filtered_logits.softmax(dim=-1).cumsum(dim=-1)
            retain_mask = cumulative_probs > (1 - self.top_p)
            retain_mask[:, -1] = True
            # Set rejected logits to -inf so that they are ignored in downstream comparisons
            filtered_logits[~retain_mask] = float("-Inf")
            # Clip the [batch_size, vocab_size] logits tensor to speed-up downstream ops
            max_keep = torch.amax(torch.sum(retain_mask, dim=-1))
            filtered_logits = filtered_logits[:, -max_keep:]
            filtered_indices = filtered_indices[:, -max_keep:]

        return filtered_logits, filtered_indices