def get_outputs()

in pytorch_translate/ensemble_export.py


    def get_outputs(self, src_tokens, encoder_futures):
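        """Collect per-model encoder outputs and initial decoder states.

        Waits on the given encoder futures, gathers each model's encoder
        output, the optional vocab-reduction candidates, and each model's
        initial decoder states, records their names in self.output_names,
        and returns everything as a flat tuple in the same order as the
        recorded names.
        """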
        outputs = []
        output_names = []
        states = []

        possible_translation_tokens = None

        # Underlying assumption: every model in the ensemble shares the
        # same vocab_reduction_module, so we only query the first one.
        if hasattr(self.models[0].decoder, "vocab_reduction_module"):
            vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
            if vocab_reduction_module is not None:
                # Candidate target-vocab token ids based on src_tokens only
                # (no decoder input is available yet at this point).
                possible_translation_tokens = vocab_reduction_module(
                    src_tokens=src_tokens, decoder_input_tokens=None
                )

        # Precompute reduced decoder weight matrices.
        # Once we have possible_translation_tokens, we need to gather rows
        # out of each output_projection_{w,b} tensor for the decoders to
        # use. We do it here because these reduced matrices are used on each
        # step of the beam search, and this turns out to be a relatively
        # expensive operation.
        reduced_weights = {}
        for i, model in enumerate(self.models):
            if (
                self.enable_precompute_reduced_weights
                and hasattr(model.decoder, "_precompute_reduced_weights")
                and possible_translation_tokens is not None
            ):
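                # _fork schedules _precompute_reduced_weights and returns a
                # Future; it is _wait-ed on in the loop below, so (under
                # TorchScript) it can overlap with the encoder computation.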
                reduced_weights[i] = torch.jit._fork(
                    model.decoder._precompute_reduced_weights,
                    possible_translation_tokens,
                )

        # XXX: This loop is where we _wait() on each encoder's Future.
        # If you're trying to add more ops, they should probably not
        # go in this loop!
        for i, (model, future) in enumerate(zip(self.models, encoder_futures)):
            encoder_out = torch.jit._wait(future)
            # "primary" encoder output (vector representations per source token)
            encoder_outputs = encoder_out[0]
            outputs.append(encoder_outputs)
            output_names.append(f"encoder_output_{i}")
            if hasattr(model.decoder, "_init_prev_states"):
                # Initial decoder states (e.g. recurrent/attention state)
                # derived from this model's encoder output.
                states.extend(model.decoder._init_prev_states(encoder_out))
            if (
                self.enable_precompute_reduced_weights
                and hasattr(model.decoder, "_precompute_reduced_weights")
                and possible_translation_tokens is not None
            ):
                # Collect the reduced projection weights forked above; they
                # are exported as part of this model's initial states.
                states.extend(torch.jit._wait(reduced_weights[i]))

        if possible_translation_tokens is not None:
            outputs.append(possible_translation_tokens)
            output_names.append("possible_translation_tokens")

        for i, state in enumerate(states):
            outputs.append(state)
            output_names.append(f"initial_state_{i}")

        self.output_names = output_names

        return tuple(outputs)
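
A minimal sketch of how the encoder futures passed to this method might be produced. The run_encoders helper and the (src_tokens, src_lengths) encoder call signature are illustrative assumptions, not the module's actual API; the pattern itself (fork each encoder, then let get_outputs overlap vocab reduction and weight precomputation with the encoder waits) is what the code above relies on.

    import torch

    def run_encoders(models, src_tokens, src_lengths):
        # Fork one asynchronous encoder run per model; get_outputs() then
        # overlaps vocab reduction / weight precomputation with these runs
        # before _wait-ing on each Future.
        return [
            torch.jit._fork(model.encoder, src_tokens, src_lengths)
            for model in models
        ]

    # encoder_futures = run_encoders(self.models, src_tokens, src_lengths)
    # outputs = self.get_outputs(src_tokens, encoder_futures)
    # self.output_names is aligned 1:1 with the returned tuple.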