def _topk_decode()

in custom/sequence_generator.py


    def _topk_decode(self, logits, topk, topp):
        """WARNING!!! This can modify the `self.pad` position of `logits` in place."""
        if topk == 1 and topp == 0:  # greedy decoding
            logits[:, self.pad] = -math.inf  # never pick pad, as in the fairseq code
            pred_tok = logits.argmax(dim=1, keepdim=True)

        else:
            if topk > 1:
                # top-k sampling: mask out pad, keep only the k largest
                # logits, then sample from the renormalized distribution
                logits[:, self.pad] = -1e10  # never select pad
                logits = top_k_logits(logits, topk)
                pred_tok = torch.softmax(logits, dim=-1).multinomial(1)
            else:
                # top-p (nucleus) sampling: keep the smallest set of tokens
                # whose cumulative probability exceeds `topp`
                assert topp > 0.0
                filtered_probs, bookkeep_idx = self._sample_topp(
                    torch.softmax(logits, dim=-1), sampling_topp=topp
                )
                selected = filtered_probs.multinomial(1)
                # map positions in the filtered set back to vocabulary ids
                pred_tok = torch.gather(bookkeep_idx, index=selected, dim=1)

        return pred_tok
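
The helper top_k_logits is defined elsewhere in the repo and not shown here. As a rough sketch of what its call site implies (standard top-k filtering; the actual implementation may differ), assuming a 2-D (batch, vocab) logits tensor:

    import torch

    def top_k_logits(logits, k):
        # Sketch, not the repo's implementation: mask every logit below the
        # k-th largest in each row to -inf so softmax assigns it zero mass.
        topk_vals, _ = torch.topk(logits, k, dim=-1)
        kth_largest = topk_vals[..., -1, None]  # k-th largest logit per row
        return logits.masked_fill(logits < kth_largest, float("-inf"))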
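
Likewise, self._sample_topp is not shown. Judging from how its outputs are used (an index is sampled from filtered_probs, then gathered out of bookkeep_idx), a minimal nucleus-filtering sketch along these lines would fit; the repo's version, like fairseq's, may instead truncate the sorted tensors to the nucleus width:

    import torch

    def sample_topp(probs, sampling_topp):
        # Sketch of the assumed `_sample_topp` method body. Sort the
        # probabilities in descending order and zero out every token whose
        # preceding cumulative mass already exceeds `sampling_topp`; the
        # highest-probability token is always kept. Returns the filtered
        # probabilities plus the original vocabulary indices, matching the
        # (filtered_probs, bookkeep_idx) pair used above. multinomial()
        # accepts unnormalized weights, so no renormalization is needed.
        sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
        cumulative = sorted_probs.cumsum(dim=-1)
        drop = (cumulative - sorted_probs) > sampling_topp
        return sorted_probs.masked_fill(drop, 0.0), sorted_idx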