def get_next_token()

in src/pixparse/utils/ocr_utils.py [0:0]


def get_next_token(next_token_logits, use_sample: bool = True, temperature: float = 5):
    """
    Choose the next token given a logits distribution.

    Args:
        next_token_logits: Logits for the next token, with the vocabulary on the
            last dimension. Presumably shaped (batch, vocab_size); a 1-D
            (vocab_size,) tensor is also accepted and treated as batch size 1.
        use_sample: If True, samples from the temperature-scaled softmax
            distribution. If False, greedily picks the token with the highest logit.
        temperature: Divisor applied to the logits before softmax when
            use_sample=True. Must be > 0 in that mode. It is ignored when
            use_sample=False.

    Returns:
        A tuple ``(next_token_id, probs)``:
          - next_token_id: integer tensor of shape (N, 1) holding the chosen
            token index for each row.
          - probs: when sampling, the softmax of the temperature-scaled logits
            (same shape as the input). In greedy mode this is a tensor of ones
            with the input's shape, dtype and device. It is NOT a normalized
            distribution; it is kept that way for backward compatibility.

    Raises:
        ValueError: If use_sample is True and temperature is not positive.
    """
    if use_sample:
        # A zero temperature yields inf/nan probabilities, and a negative one
        # silently inverts the ranking (favouring unlikely tokens), so reject both.
        if temperature <= 0:
            raise ValueError(
                f"temperature must be positive when sampling, got {temperature}"
            )
        relevant_logits = next_token_logits / temperature
        probs = nn.functional.softmax(relevant_logits, dim=-1)
        # multinomial returns (batch, 1) for 2-D input and (1,) for 1-D input;
        # normalize both to (N, 1).
        next_token_id = torch.multinomial(probs, num_samples=1).reshape(-1, 1)
    else:
        # Reduce over the same (last) axis the sampling branch uses, and
        # normalize to (N, 1) so both branches return the same shape.
        next_token_id = next_token_logits.argmax(dim=-1).reshape(-1, 1)
        probs = torch.ones_like(next_token_logits)
    return next_token_id, probs