def forward()

in pytorch_translate/ensemble_export.py [0:0]
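
(The method below relies on ensemble_export.py's module-level imports; at minimum numpy as np, torch, torch.nn.functional as F, and, presumably, from fairseq import utils for set_incremental_state/get_incremental_state.)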


    def forward(self, input_token, target_token, timestep, *inputs):
        """
        Decoder step inputs correspond one-to-one to encoder outputs.
        """
        log_probs_per_model = []
        state_outputs = []

        next_state_input = len(self.models)
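        # next_state_input is a cursor into the flat `inputs` tuple: the first
        # len(self.models) entries are per-model encoder outputs, followed by
        # the optional possible_translation_tokens and then the per-model
        # state tensors consumed below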

        # the underlying assumption is that every model has the same vocab_reduction_module
        vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
        if vocab_reduction_module is not None:
            possible_translation_tokens = inputs[len(self.models)]
            next_state_input += 1
        else:
            possible_translation_tokens = None

        for i, model in enumerate(self.models):
            encoder_output = inputs[i]
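            # unpack this model's per-layer (hidden, cell) states and its
            # input-feed vector from the flat inputs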
            prev_hiddens = []
            prev_cells = []

            for _ in range(len(model.decoder.layers)):
                prev_hiddens.append(inputs[next_state_input])
                prev_cells.append(inputs[next_state_input + 1])
                next_state_input += 2
            prev_input_feed = inputs[next_state_input].view(1, -1)
            next_state_input += 1

            if (
                self.enable_precompute_reduced_weights
                and hasattr(model.decoder, "_precompute_reduced_weights")
                and possible_translation_tokens is not None
            ):
                # (output_projection_w, output_projection_b)
                reduced_output_weights = inputs[next_state_input : next_state_input + 2]
                next_state_input += 2
            else:
                reduced_output_weights = None

            # no batching, we only care about the "max" length
            src_length_int = int(encoder_output.size()[0])
            src_length = torch.LongTensor(np.array([src_length_int]))

            # notional, not actually used for decoder computation
            src_tokens = torch.LongTensor(np.array([[0] * src_length_int]))
            src_embeddings = encoder_output.new_zeros(encoder_output.shape)

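            # assemble the tuple the decoder expects as encoder_out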
            encoder_out = (
                encoder_output,
                prev_hiddens,
                prev_cells,
                src_length,
                src_tokens,
                src_embeddings,
            )

            # store cached states, use evaluation mode
            model.decoder._is_incremental_eval = True
            model.eval()

            # empty dict the decoder will use to cache state for this step
            incremental_state = {}

            # cache previous state inputs
            utils.set_incremental_state(
                model.decoder,
                incremental_state,
                "cached_state",
                (prev_hiddens, prev_cells, prev_input_feed),
            )

            decoder_output = model.decoder(
                input_token.view(1, 1),
                encoder_out,
                incremental_state=incremental_state,
                possible_translation_tokens=possible_translation_tokens,
            )
            logits, _, _ = decoder_output

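            # logits are 1 x 1 x vocab (no batching); normalize over dim 2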
            log_probs = F.log_softmax(logits, dim=2)

            log_probs_per_model.append(log_probs)

            (next_hiddens, next_cells, next_input_feed) = utils.get_incremental_state(
                model.decoder, incremental_state, "cached_state"
            )

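            # re-flatten the updated state in the same order it was read from
            # inputs: (hidden, cell) per layer, then the input feed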
            for h, c in zip(next_hiddens, next_cells):
                state_outputs.extend([h, c])
            state_outputs.append(next_input_feed)

            if reduced_output_weights is not None:
                state_outputs.extend(reduced_output_weights)

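        # ensemble by averaging per-model log-probs (a geometric mean in
        # probability space, up to normalization)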
        average_log_probs = torch.mean(
            torch.cat(log_probs_per_model, dim=0), dim=0, keepdim=True
        )

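        # target_token holds full-vocab ids; map each to its position in the
        # reduced output distribution (ids outside possible_translation_tokens
        # fall back to index self.unk_token)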
        if possible_translation_tokens is not None:
            reduced_indices = torch.zeros(self.vocab_size).long().fill_(self.unk_token)
            # ONNX-exportable arange (ATen op)
            possible_translation_token_range = torch._dim_arange(
                like=possible_translation_tokens, dim=0
            )
            reduced_indices[
                possible_translation_tokens
            ] = possible_translation_token_range
            reduced_index = reduced_indices.index_select(dim=0, index=target_token)
            score = average_log_probs.view((-1,)).index_select(
                dim=0, index=reduced_index
            )
        else:
            score = average_log_probs.view((-1,)).index_select(
                dim=0, index=target_token
            )

        word_reward = self.word_rewards.index_select(0, target_token)
        score += word_reward

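        # record input/output names in positional order for graph export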
        self.input_names = ["prev_token", "target_token", "timestep"]
        for i in range(len(self.models)):
            self.input_names.append(f"fixed_input_{i}")

        if possible_translation_tokens is not None:
            self.input_names.append("possible_translation_tokens")

        outputs = [score]
        self.output_names = ["score"]

        for i in range(len(self.models)):
            self.output_names.append(f"fixed_input_{i}")
            outputs.append(inputs[i])

        if possible_translation_tokens is not None:
            self.output_names.append("possible_translation_tokens")
            outputs.append(possible_translation_tokens)

        for i, state in enumerate(state_outputs):
            outputs.append(state)
            self.output_names.append(f"state_output_{i}")
            self.input_names.append(f"state_input_{i}")

        return tuple(outputs)
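
As a sanity check on the flat argument layout described above, here is a minimal, self-contained sketch of how many tensors forward() expects after timestep (expected_num_inputs is a hypothetical helper, not part of the source):

def expected_num_inputs(n_models, n_layers, vocab_reduction, precomputed_weights=False):
    n = n_models                          # one encoder output per model
    if vocab_reduction:
        n += 1                            # possible_translation_tokens
    n += n_models * (2 * n_layers + 1)    # (hidden, cell) per layer + input feed
    if vocab_reduction and precomputed_weights:
        n += n_models * 2                 # reduced (output_projection_w, output_projection_b)
    return n

# two models, three decoder layers each, with vocab reduction:
assert expected_num_inputs(2, 3, vocab_reduction=True) == 17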