in custom/sequence_generator.py
import math  # module-level imports required by this method
import torch

def _topk_decode(self, logits, topk, topp):
    """Pick the next token by greedy, top-k, or top-p (nucleus) sampling.

    WARNING: this can modify the `self.pad` position of `logits` in place.
    """
    if topk == 1 and topp == 0:  # greedy decoding
        # Mask out pad so argmax can never select it (as in the fairseq code).
        logits[:, self.pad] = -math.inf
        pred_tok = logits.argmax(dim=1, keepdim=True)
    elif topk > 1:  # top-k sampling
        logits[:, self.pad] = -1e10  # never select pad
        # Drop everything outside the k highest logits, then sample.
        logits = top_k_logits(logits, topk)
        pred_tok = torch.softmax(logits, -1).multinomial(1)
    else:  # top-p (nucleus) sampling; note pad is NOT masked on this path
        assert topp > 0.0
        # Keep the smallest set of tokens whose cumulative probability
        # exceeds `topp`, sample within it, then map back to vocab ids.
        filtered_probs, bookkeep_idx = self._sample_topp(
            torch.softmax(logits, dim=1), sampling_topp=topp
        )
        selected = filtered_probs.multinomial(1)
        pred_tok = torch.gather(bookkeep_idx, index=selected, dim=1)
    return pred_tok
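
The method calls two helpers that are not shown in this excerpt: `top_k_logits` and `self._sample_topp`. Below is a minimal sketch of both, under the assumption that they follow the common implementations: `top_k_logits` in the style of the widely used GPT-2 sampling recipe, and `_sample_topp` mirroring fairseq's `SequenceGenerator._sample_topp`, which returns the truncated nucleus probabilities together with the original vocabulary indices (the `bookkeep_idx` consumed by `torch.gather` above). Treat these as illustrative sketches, not the repo's actual code.

import torch

def top_k_logits(logits, k):
    # Sketch: keep only the k largest logits per row; push the rest to a
    # large negative value so softmax assigns them ~zero probability.
    values, _ = torch.topk(logits, k, dim=-1)
    min_values = values[:, -1].unsqueeze(-1)
    return torch.where(logits < min_values,
                       torch.full_like(logits, -1e10),
                       logits)

def _sample_topp(self, probs, sampling_topp):
    # Sketch after fairseq (a method on the generator in context): sort
    # probabilities in descending order and keep the smallest prefix whose
    # cumulative mass exceeds sampling_topp.
    sorted_probs, sorted_indices = probs.sort(descending=True)
    cumsum_probs = sorted_probs.cumsum(dim=1)
    mask = cumsum_probs.lt(sampling_topp)
    # Also include the first token that pushes the mass past sampling_topp.
    cumsum_mask = mask.cumsum(dim=1)
    last_included = cumsum_mask[:, -1:].clamp(0, mask.size(1) - 1)
    mask = mask.scatter_(1, last_included, 1)
    # Truncate to the widest nucleus in the batch; zero out entries beyond
    # each row's nucleus so multinomial() can never pick them.
    max_dim = last_included.max()
    truncated_mask = mask[:, : max_dim + 1]
    truncated_probs = sorted_probs[:, : max_dim + 1]
    truncated_indices = sorted_indices[:, : max_dim + 1]
    return truncated_probs.masked_fill(~truncated_mask, 0), truncated_indices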