models/utils.py
import torch


def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
    """
    Apply top-k and/or nucleus (top-p) filtering to logits.

    Args:
        logits: logits distribution of shape (batch size, vocabulary size).
        top_k: keep only the top_k tokens with the highest logits (top-k filtering); 0 disables it.
        top_p: keep the smallest set of tokens whose cumulative probability exceeds top_p (nucleus filtering); 1.0 disables it.
        filter_value: value assigned to the logits of filtered-out tokens.
    """
    top_k = min(top_k, logits.size(-1))  # Safety: top_k cannot exceed the vocabulary size
    if top_k > 0:
        # Remove all tokens with a logit below the k-th largest logit
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(indices_to_remove, filter_value)
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        # Remove tokens with cumulative probability above top_p
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the first token above the threshold is also kept,
        # and always keep at least the highest-probability token
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        # Map the mask from sorted order back to the original vocabulary order
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, filter_value)
    return logits
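
For context, a minimal sketch of how a helper like this is typically called during autoregressive sampling. The `model`, `input_ids`, and hyperparameter names below are illustrative assumptions (an HF-style model whose forward pass returns `.logits`), not names taken from this module.

import torch

@torch.no_grad()
def sample_next_token(model, input_ids, temperature=1.0, top_k=50, top_p=0.9):
    # Logits for the last position, shape (batch, vocab) -- assumes an HF-style output object
    logits = model(input_ids).logits[:, -1, :]
    logits = logits / temperature                                   # optional temperature scaling
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    probs = torch.softmax(logits, dim=-1)                           # filtered logits -> probabilities (masked tokens get 0)
    return torch.multinomial(probs, num_samples=1)                  # one sampled token id per batch element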