def generate_beam()

in XLM/src/model/transformer.py [0:0]


    def generate_beam(self, src_enc, src_len, tgt_lang_id, beam_size, length_penalty, early_stopping, max_len=200):
        """
        Decode a sentence given initial start.
        `x`:
            - LongTensor(bs, slen)
                <EOS> W1 W2 W3 <EOS> <PAD>
                <EOS> W1 W2 W3   W4  <EOS>
        `lengths`:
            - LongTensor(bs) [5, 6]
        `positions`:
            - False, for regular "arange" positions (LM)
            - True, to reset positions from the new generation (MT)
        `langs`:
            - must be None if the model only supports one language
            - lang_id if only one language is involved (LM)
            - (lang_id1, lang_id2) if two languages are involved (MT)
        """
        if isinstance(max_len, int):
            max_lengths = src_len.clone().fill_(max_len)
            global_max_len = max_len
        else:
            max_lengths = max_len
            global_max_len = int(max_lengths.max())

        # check inputs
        assert src_enc.size(0) == src_len.size(0)
        assert beam_size >= 1

        # batch size / number of words
        bs = len(src_len)
        n_words = self.n_words

        # expand to beam size the source latent representations / source lengths
        src_enc = src_enc.unsqueeze(1).expand(
            (bs, beam_size) + src_enc.shape[1:]).contiguous().view((bs * beam_size,) + src_enc.shape[1:])
        src_len = src_len.unsqueeze(1).expand(
            bs, beam_size).contiguous().view(-1)

        # generated sentences (batch with beam current hypotheses)
        generated = src_len.new(global_max_len, bs * beam_size)  # upcoming output
        generated.fill_(self.pad_index)  # fill upcoming output with <PAD>
        # we use <EOS> for <BOS> everywhere
        generated[0].fill_(self.eos_index)

        # generated hypotheses
        generated_hyps = [
            BeamHypotheses(beam_size, global_max_len, length_penalty, early_stopping)
            for _ in range(bs)
        ]

        # positions
        positions = src_len.new(global_max_len).long()
        positions = torch.arange(global_max_len, out=positions).unsqueeze(
            1).expand_as(generated)

        # language IDs
        langs = positions.clone().fill_(tgt_lang_id)

        # scores for each sentence in the beam
        beam_scores = src_enc.new(bs, beam_size).float().fill_(0)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)

        # current position
        cur_len = 1

        # cache compute states
        self.cache = {'slen': 0}

        # done sentences
        done = [False for _ in range(bs)]

        while cur_len < global_max_len:

            # compute word scores
            tensor = self.forward(
                'fwd',
                x=generated[:cur_len],
                lengths=src_len.new(bs * beam_size).fill_(cur_len),
                positions=positions[:cur_len],
                langs=langs[:cur_len],
                causal=True,
                src_enc=src_enc,
                src_len=src_len,
                use_cache=True
            )
            assert tensor.size() == (1, bs * beam_size, self.dim)
            tensor = tensor.data[-1, :, :].type_as(src_enc)  # (bs * beam_size, dim)
            scores = self.pred_layer.get_scores(tensor)      # (bs * beam_size, n_words)
            scores = F.log_softmax(scores.float(), dim=-1)   # (bs * beam_size, n_words)
            assert scores.size() == (bs * beam_size, n_words)

            # select next words with scores
            _scores = scores + beam_scores[:, None].expand_as(scores)  # (bs * beam_size, n_words)
            _scores = _scores.view(bs, beam_size * n_words)            # (bs, beam_size * n_words)

            next_scores, next_words = torch.topk(
                _scores, 2 * beam_size, dim=1, largest=True, sorted=True)
            assert next_scores.size() == next_words.size() == (bs, 2 * beam_size)

            # next batch beam content
            # list of (bs * beam_size) tuple(next hypothesis score, next word, current position in the batch)
            next_batch_beam = []

            # for each sentence
            for sent_id in range(bs):

                # if we are done with this sentence
                done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(
                    next_scores[sent_id].max().item())
                if done[sent_id]:
                    next_batch_beam.extend(
                        [(0, self.pad_index, 0)] * beam_size)  # pad the batch
                    continue

                # next sentence beam content
                next_sent_beam = []

                # next words for this sentence
                for idx, value in zip(next_words[sent_id], next_scores[sent_id]):

                    # get beam and word IDs
                    beam_id = idx // n_words
                    word_id = idx % n_words

                    # end of sentence, or next word
                    if word_id == self.eos_index or cur_len + 1 == global_max_len:
                        generated_hyps[sent_id].add(
                            generated[:cur_len, sent_id * beam_size + beam_id].clone(), value.item())
                    else:
                        next_sent_beam.append(
                            (value, word_id, sent_id * beam_size + beam_id))

                    # the beam for next step is full
                    if len(next_sent_beam) == beam_size:
                        break

                # update next beam content
                assert len(next_sent_beam) == (0 if cur_len + 1 == global_max_len else beam_size)
                if len(next_sent_beam) == 0:
                    next_sent_beam = [(0, self.pad_index, 0)] * beam_size  # pad the batch
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == beam_size * (sent_id + 1)

            # sanity check / prepare next batch
            assert len(next_batch_beam) == bs * beam_size
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_words = generated.new([x[1] for x in next_batch_beam])
            beam_idx = src_len.new([x[2] for x in next_batch_beam])

            # re-order batch and internal states
            generated = generated[:, beam_idx]
            generated[cur_len] = beam_words
            for k in self.cache.keys():
                if k != 'slen':
                    self.cache[k] = (self.cache[k][0][beam_idx],
                                     self.cache[k][1][beam_idx])

            # update current length
            cur_len = cur_len + 1

            # stop when we are done with each sentence
            if all(done):
                break

        # visualize hypotheses
        # print([len(x) for x in generated_hyps], cur_len)
        # globals().update( locals() );
        # !import code; code.interact(local=vars())
        # for ii in range(bs):
        #     for ss, ww in sorted(generated_hyps[ii].hyp, key=lambda x: x[0], reverse=True):
        #         print("%.3f " % ss + " ".join(self.dico[x] for x in ww.tolist()))
        #     print("")

        # select the best hypotheses
        tgt_len = src_len.new(bs)
        best = []

        for i, hypotheses in enumerate(generated_hyps):
            sorted_hyps = [h[1] for h in sorted(hypotheses.hyp, key=lambda x: x[0], reverse=True)]
            tgt_len[i] = max([len(hyp) for hyp in sorted_hyps]) + 1  # +1 for the <EOS> symbol
            best.append(sorted_hyps)

        # generate target batch
        decoded = src_len.new(tgt_len.max().item(), beam_size, bs).fill_(self.pad_index)
        for i, hypo_list in enumerate(best):
            for hyp_index, hypo in enumerate(hypo_list):
                decoded[:len(hypo), hyp_index, i] = hypo
                decoded[len(hypo), hyp_index, i] = self.eos_index

        # sanity check
        assert (decoded == self.eos_index).sum() == 2 * beam_size * bs

        return decoded, tgt_len
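
The BeamHypotheses container used above is defined elsewhere in transformer.py and is not shown in this section. Based only on how it is used here (constructor arguments, add(), is_done(), and the .hyp list of (score, hypothesis) pairs), a minimal sketch of the expected behaviour with length-normalized scores could look like the following; it is an illustration of the interface, not necessarily the exact implementation.

class BeamHypotheses(object):

    def __init__(self, n_hyp, max_len, length_penalty, early_stopping):
        # keep at most `n_hyp` finished hypotheses per sentence
        self.max_len = max_len - 1  # ignore the initial <EOS>
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.n_hyp = n_hyp
        self.hyp = []               # list of (score, LongTensor) pairs
        self.worst_score = 1e9

    def __len__(self):
        return len(self.hyp)

    def add(self, hyp, sum_logprobs):
        # length-normalized score of a finished hypothesis
        score = sum_logprobs / len(hyp) ** self.length_penalty
        if len(self) < self.n_hyp or score > self.worst_score:
            self.hyp.append((score, hyp))
            if len(self) > self.n_hyp:
                # drop the worst hypothesis to keep the list at `n_hyp`
                sorted_scores = sorted((s, idx) for idx, (s, _) in enumerate(self.hyp))
                del self.hyp[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs):
        # a sentence is done when `n_hyp` hypotheses are finished and (unless
        # early_stopping) no remaining beam can beat the worst finished score
        if len(self) < self.n_hyp:
            return False
        if self.early_stopping:
            return True
        return self.worst_score >= best_sum_logprobs / self.max_len ** self.length_penalty

A typical call runs the encoder first, transposes its output to (bs, slen, dim), then beam-decodes. The sketch below is illustrative only: `encoder`, `decoder`, `x1`, `len1`, `lang1_id` and `lang2_id` are hypothetical names for trained model instances and a source batch, not identifiers from this file.

# hypothetical usage sketch
enc = encoder('fwd', x=x1, lengths=len1, langs=x1.clone().fill_(lang1_id), causal=False)
enc = enc.transpose(0, 1)  # (bs, slen, dim)
decoded, tgt_len = decoder.generate_beam(
    enc, len1, lang2_id,
    beam_size=5, length_penalty=1.0, early_stopping=False, max_len=200
)
# `decoded` has shape (tgt_len.max(), beam_size, bs); decoded[:, j, i] is the
# j-th best hypothesis for sentence i, starting and ending with <EOS>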