in optimum/tpu/generation/logits_process.py [0:0]
def __call__(self, logits: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]:
    """Warp sampling logits with temperature, top-k and top-p filtering.

    Args:
        logits: `[batch_size, vocab_size]` scores for the next token.

    Returns:
        A `(sorted_logits, sorted_indices)` pair. Rejected entries are set
        to ``-inf`` and, when top-p is active, the tensors are clipped to
        the widest per-batch kept width to speed up downstream ops.
    """
    if self.temperature != 1.0:
        logits = logits / self.temperature

    vocab_size = logits.shape[-1]
    apply_top_k = 0 < self.top_k < vocab_size
    apply_top_p = 0.0 < self.top_p < 1.0

    if apply_top_k:
        # topk returns values in descending order (best first).
        values, indices = torch.topk(logits, self.top_k)
    else:
        # NOTE: without top-k filtering we fall back to sorting the whole
        # vocabulary (ascending), which is a very slow operation.
        values, indices = torch.sort(logits)

    if apply_top_p:
        if apply_top_k:
            # topk produced a descending order; flip to ascending so the
            # cumulative-sum trick below works.
            values = values.flip(-1)
            indices = indices.flip(-1)
        # Keep the best logit plus every token whose cumulative probability
        # is strictly higher than (1 - top_p).
        cumulative = values.softmax(dim=-1).cumsum(dim=-1)
        keep = cumulative > (1 - self.top_p)
        keep[:, -1] = True
        # Rejected logits become -inf so downstream comparisons ignore them.
        values[~keep] = float("-Inf")
        # Clip to the largest kept count across the batch to shrink the
        # [batch_size, vocab_size] tensors for downstream ops.
        kept_per_batch = keep.sum(dim=-1)
        widest = torch.amax(kept_per_batch)
        values = values[:, -widest:]
        indices = indices[:, -widest:]

    return values, indices