in pytorch_translate/ensemble_export.py [0:0]
def get_outputs(self, src_tokens, encoder_futures):
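    """Gather exported graph outputs for the model ensemble.

    Waits on the forked encoder computations in `encoder_futures` (one
    future per model), collects each model's encoder output, the optional
    vocab-reduction index tensor, and the flattened initial decoder
    states, and records matching names in `self.output_names`.
    """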
    outputs = []
    output_names = []
    states = []

    possible_translation_tokens = None
    # Underlying assumption: each model shares the same vocab_reduction_module.
    if hasattr(self.models[0].decoder, "vocab_reduction_module"):
        vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
        if vocab_reduction_module is not None:
            possible_translation_tokens = vocab_reduction_module(
                src_tokens=src_tokens, decoder_input_tokens=None
            )
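
    # If set, `possible_translation_tokens` is the reduced candidate set of
    # target vocab ids for this source sentence (presumably a 1-D tensor of
    # token indices); None means the full output vocabulary is used.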
    # Precompute reduced decoder weight matrices.
    # Once we have possible_translation_tokens, we need to gather rows
    # out of each output_projection_{w,b} tensor for the decoders to
    # use. We do it here because these reduced matrices are used on each
    # step of the beam search, and this turns out to be a relatively
    # expensive operation.
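    # Conceptually, the reduction gathers rows by index, e.g. (a hedged
    # sketch, not the decoder's actual code):
    #   reduced_w = output_projection_w.index_select(
    #       0, possible_translation_tokens
    #   )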
    reduced_weights = {}
    for i, model in enumerate(self.models):
        if (
            self.enable_precompute_reduced_weights
            and hasattr(model.decoder, "_precompute_reduced_weights")
            and possible_translation_tokens is not None
        ):
            reduced_weights[i] = torch.jit._fork(
                model.decoder._precompute_reduced_weights,
                possible_translation_tokens,
            )
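
    # torch.jit._fork runs the reduction asynchronously and returns a future
    # (_fork/_wait are the internal spellings of what newer PyTorch exposes
    # as torch.jit.fork/wait), so it overlaps with the encoder waits below.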
    # XXX: This loop is where we wait() for each encoder's output to be
    # ready. If you're trying to add more ops, they should probably not
    # go in this loop!
    for i, (model, future) in enumerate(zip(self.models, encoder_futures)):
        encoder_out = torch.jit._wait(future)

        # "primary" encoder output (vector representations per source token)
        encoder_outputs = encoder_out[0]
        outputs.append(encoder_outputs)
        output_names.append(f"encoder_output_{i}")

        if hasattr(model.decoder, "_init_prev_states"):
            states.extend(model.decoder._init_prev_states(encoder_out))
        if (
            self.enable_precompute_reduced_weights
            and hasattr(model.decoder, "_precompute_reduced_weights")
            and possible_translation_tokens is not None
        ):
            states.extend(torch.jit._wait(reduced_weights[i]))
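
    # All futures have been consumed above; flatten everything that remains
    # into a stable, named output tuple.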
    if possible_translation_tokens is not None:
        outputs.append(possible_translation_tokens)
        output_names.append("possible_translation_tokens")

    for i, state in enumerate(states):
        outputs.append(state)
        output_names.append(f"initial_state_{i}")

    self.output_names = output_names
    return tuple(outputs)
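
A minimal, self-contained sketch of the fork/wait overlap pattern this method
relies on, written against the public torch.jit.fork / torch.jit.wait API;
toy_encoder and run_ensemble are hypothetical names for illustration, not
part of pytorch_translate:

import torch

@torch.jit.script
def toy_encoder(src: torch.Tensor) -> torch.Tensor:
    # Stand-in for a real encoder forward pass.
    return src * 2.0

def run_ensemble(sources):
    # Fork each "encoder" asynchronously; each call returns a future
    # immediately, so all encoders run concurrently.
    futures = [torch.jit.fork(toy_encoder, src) for src in sources]
    # Batch the blocking waits in one loop, mirroring get_outputs above.
    return [torch.jit.wait(f) for f in futures]

# Example: run_ensemble([torch.ones(3), torch.zeros(3)])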