generate() — method excerpt from pytorch_translate/ensemble_export.py


    def generate(self, models, src_tokens, src_lengths, prefix_tokens=None):
        """Decode a batch via iterative refinement with a single model.

        Runs up to ``self.max_iter + 1`` refinement passes. After each pass,
        sentences whose output has stopped changing are "terminated": their
        hypotheses are written into the per-sentence result lists and they are
        removed from the active batch, which shrinks as decoding proceeds.

        Args:
            models: list of models; exactly one is supported (no ensembling yet).
            src_tokens: source token tensor of shape (bsz, src_len).
            src_lengths: per-sentence source lengths.
            prefix_tokens: unused; kept for interface compatibility.

        Returns:
            Tuple of four lists, each of length ``bsz`` and indexed by the
            original batch position: finalized tokens, scores, attentions,
            and alignments.
        """
        # TODO: model ensemble
        assert len(models) == 1, "only support single model"
        model = models[0]
        bsz, src_len = src_tokens.size()
        # Maps positions in the (shrinking) active batch back to the original
        # sentence indices, so results land in the right output slot.
        sent_idxs = torch.arange(bsz)
        # encoding
        encoder_out = model.encoder(src_tokens, src_lengths)

        # initialize buffers (very model specific, with length prediction or not)
        prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
        prev_output_tokens = prev_decoder_out.output_tokens.clone()

        # Scalar placeholders; each slot is overwritten once its sentence
        # terminates. Using tensors keeps the lists TorchScript-friendly.
        finalized_tokens_list = [torch.tensor(0) for _ in range(bsz)]
        finalized_scores_list = [torch.tensor(0) for _ in range(bsz)]
        finalized_attns_list = [torch.tensor(0) for _ in range(bsz)]
        finalized_alignments_list = [torch.tensor(0) for _ in range(bsz)]

        for step in range(self.max_iter + 1):
            prev_decoder_out = prev_decoder_out._replace(
                step=step, max_step=self.max_iter + 1
            )
            decoder_out = model.forward_decoder(
                prev_decoder_out,
                encoder_out,
                eos_penalty=self.eos_penalty,
                max_ratio=self.max_ratio,
                decoding_format=self.decoding_format,
            )
            # A sentence is flagged terminated when this pass produced the
            # same output as the previous one (a fixed point / loop).
            terminated, output_tokens, output_scores, output_attn = is_a_loop(
                self.pad,
                prev_output_tokens,
                decoder_out.output_tokens,
                decoder_out.output_scores,
                decoder_out.attn,
            )
            decoder_out = decoder_out._replace(
                output_tokens=output_tokens,
                output_scores=output_scores,
                attn=output_attn,
            )

            # On the final iteration, force every remaining sentence to
            # terminate so all slots get filled.
            terminated = last_step(step, self.max_iter, terminated)
            # Compute the keep-mask once and reuse it below (the original
            # recomputed ~terminated in two places).
            not_terminated = ~terminated

            # collect finalized sentences
            finalized_idxs = sent_idxs[terminated]
            finalized_tokens = decoder_out.output_tokens[terminated]
            finalized_scores = decoder_out.output_scores[terminated]
            finalized_attn = (
                None if decoder_out.attn is None else decoder_out.attn[terminated]
            )
            finalized_tokens_list = finalize_hypos_loop_tokens(
                finalized_tokens_list,
                finalized_idxs,
                self.pad,
                finalized_tokens,
                finalized_scores,
            )
            finalized_scores_list = finalize_hypos_loop_scores(
                finalized_scores_list,
                finalized_idxs,
                self.pad,
                finalized_tokens,
                finalized_scores,
            )
            finalized_attns_list, finalized_alignments_list = finalize_hypos_loop_attns(
                finalized_attns_list,
                finalized_alignments_list,
                finalized_idxs,
                self.pad,
                finalized_tokens,
                finalized_scores,
                finalized_attn,
            )

            # For the next step: drop terminated sentences from the active
            # decoder state and encoder output.
            prev_decoder_out = decoder_out._replace(
                output_tokens=script_skip_tensor(
                    decoder_out.output_tokens, not_terminated
                ),
                output_scores=script_skip_tensor(
                    decoder_out.output_scores, not_terminated
                ),
                # NOTE(review): attn is carried over unfiltered (not masked by
                # not_terminated) — presumably unused across steps; confirm.
                attn=decoder_out.attn,
                step=decoder_out.step,
                max_step=decoder_out.max_step,
            )
            encoder_out = EncoderOut(
                encoder_out=script_skip_tensor(
                    encoder_out.encoder_out, not_terminated
                ),
                encoder_padding_mask=None,
                encoder_embedding=script_skip_tensor(
                    encoder_out.encoder_embedding, not_terminated
                ),
                encoder_states=None,
                src_tokens=None,
                src_lengths=None,
            )
            sent_idxs = script_skip_tensor(sent_idxs, not_terminated)

            prev_output_tokens = prev_decoder_out.output_tokens.clone()

        return (
            finalized_tokens_list,
            finalized_scores_list,
            finalized_attns_list,
            finalized_alignments_list,
        )