in src/pixparse/utils/ocr_utils.py [0:0]
def get_next_token(next_token_logits, use_sample: bool = True, temperature: float = 5):
    """
    Choose the next token given a logits distribution.

    Args:
        next_token_logits: Logits for the next token. The last dimension is treated
            as the vocabulary. The sampling path uses ``torch.multinomial``, which
            accepts a 1-D ``(vocab,)`` or 2-D ``(batch, vocab)`` tensor.
        use_sample: If True, sample from the temperature-scaled softmax distribution.
            If False, pick the token with the highest logit (greedy decoding).
        temperature: Divisor applied to the logits before the softmax when
            ``use_sample=True``. Values above 1 flatten the distribution and values
            below 1 sharpen it. Ignored when ``use_sample=False``.

    Returns:
        A tuple ``(next_token_id, probs)``:
            next_token_id: Chosen token indices with a trailing singleton dimension,
                e.g. shape ``(batch, 1)`` for 2-D input.
            probs: When sampling, the softmax probabilities the token was drawn from
                (same shape as ``next_token_logits``). When greedy, a tensor of ones
                with the same shape; this is a placeholder, not a real distribution.

    Raises:
        ValueError: If ``use_sample`` is True and ``temperature`` is not positive.
    """
    if use_sample:
        # A zero temperature divides by zero (NaN probabilities that make
        # torch.multinomial fail), and a negative one silently inverts the
        # ranking of the logits. Reject both up front with a clear message.
        if temperature <= 0:
            raise ValueError(
                f"temperature must be positive when sampling, got {temperature}"
            )
        probs = nn.functional.softmax(next_token_logits / temperature, dim=-1)
        # multinomial returns (batch, 1) for 2-D input and (1,) for 1-D input;
        # flatten and add a trailing dim so the result is always (N, 1).
        next_token_id = (
            torch.multinomial(probs, num_samples=1).reshape(-1).unsqueeze(-1)
        )
    else:
        # Reduce over the last (vocabulary) dimension, matching the softmax above.
        next_token_id = next_token_logits.argmax(dim=-1).unsqueeze(-1)
        # Greedy decoding has no sampling distribution; return ones as a placeholder.
        probs = torch.ones_like(next_token_logits)
    return next_token_id, probs