in inference/generate.py [0:0]
def sample(logits, temperature: float = 1.0):
"""
Samples a token from the logits using temperature scaling.
Args:
logits (torch.Tensor): The logits tensor for token predictions.
temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.
Returns:
torch.Tensor: The sampled token.
"""
logits = logits / max(temperature, 1e-5)
probs = torch.softmax(logits, dim=-1)
return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)