in pytorch_translate/ensemble_export.py
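# Imports assumed from the top of the file (not shown in this excerpt);
# `utils` is taken to be fairseq's utils module, which provides
# set_incremental_state / get_incremental_state.
import numpy as np
import torch
import torch.nn.functional as F
from fairseq import utils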
def forward(self, input_token, target_token, timestep, *inputs):
"""
Decoder step inputs correspond one-to-one to encoder outputs.
"""
log_probs_per_model = []
state_outputs = []
next_state_input = len(self.models)
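    # inputs[0:len(self.models)] are the per-model encoder outputs; the
    # flattened recurrent state (and optional extras) follows them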
    # underlying assumption: every model shares the same vocab_reduction_module
vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
if vocab_reduction_module is not None:
possible_translation_tokens = inputs[len(self.models)]
next_state_input += 1
else:
possible_translation_tokens = None
for i, model in enumerate(self.models):
encoder_output = inputs[i]
prev_hiddens = []
prev_cells = []
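        # rebuild this model's per-layer LSTM state from the flat input
        # list, consumed in (hidden, cell) pairs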
for _ in range(len(model.decoder.layers)):
prev_hiddens.append(inputs[next_state_input])
prev_cells.append(inputs[next_state_input + 1])
next_state_input += 2
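        # the input feed arrives flattened; restore its (batch=1, dim) shape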
prev_input_feed = inputs[next_state_input].view(1, -1)
next_state_input += 1
if (
self.enable_precompute_reduced_weights
and hasattr(model.decoder, "_precompute_reduced_weights")
and possible_translation_tokens is not None
):
# (output_projection_w, output_projection_b)
reduced_output_weights = inputs[next_state_input : next_state_input + 2]
next_state_input += 2
else:
reduced_output_weights = None
        # no batching; we only care about the "max" length
src_length_int = int(encoder_output.size()[0])
src_length = torch.LongTensor(np.array([src_length_int]))
# notional, not actually used for decoder computation
src_tokens = torch.LongTensor(np.array([[0] * src_length_int]))
src_embeddings = encoder_output.new_zeros(encoder_output.shape)
encoder_out = (
encoder_output,
prev_hiddens,
prev_cells,
src_length,
src_tokens,
src_embeddings,
)
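        # tuple layout matches what the recurrent decoder expects as
        # encoder_out (src_tokens / src_embeddings are dummies here)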
# store cached states, use evaluation mode
model.decoder._is_incremental_eval = True
model.eval()
        # fresh incremental_state dict; the decoder caches its recurrent
        # state in it under "cached_state"
incremental_state = {}
# cache previous state inputs
utils.set_incremental_state(
model.decoder,
incremental_state,
"cached_state",
(prev_hiddens, prev_cells, prev_input_feed),
)
decoder_output = model.decoder(
input_token.view(1, 1),
encoder_out,
incremental_state=incremental_state,
possible_translation_tokens=possible_translation_tokens,
)
logits, _, _ = decoder_output
log_probs = F.log_softmax(logits, dim=2)
log_probs_per_model.append(log_probs)
(next_hiddens, next_cells, next_input_feed) = utils.get_incremental_state(
model.decoder, incremental_state, "cached_state"
)
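        # re-flatten the updated state in the same order it was consumed:
        # (hidden, cell) per layer, then input feed, then reduced weights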
for h, c in zip(next_hiddens, next_cells):
state_outputs.extend([h, c])
state_outputs.append(next_input_feed)
if reduced_output_weights is not None:
state_outputs.extend(reduced_output_weights)
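    # ensemble average in log space, equivalent to scoring with the
    # (unnormalized) geometric mean of the models' distributions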
average_log_probs = torch.mean(
torch.cat(log_probs_per_model, dim=0), dim=0, keepdim=True
)
if possible_translation_tokens is not None:
reduced_indices = torch.zeros(self.vocab_size).long().fill_(self.unk_token)
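        # every full-vocab id maps to unk by default; ids present in the
        # reduced set get their position within the reduced output below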
# ONNX-exportable arange (ATen op)
possible_translation_token_range = torch._dim_arange(
like=possible_translation_tokens, dim=0
)
reduced_indices[
possible_translation_tokens
] = possible_translation_token_range
reduced_index = reduced_indices.index_select(dim=0, index=target_token)
score = average_log_probs.view((-1,)).index_select(
dim=0, index=reduced_index
)
else:
score = average_log_probs.view((-1,)).index_select(
dim=0, index=target_token
)
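    # add the configured per-token word reward for the scored token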
word_reward = self.word_rewards.index_select(0, target_token)
score += word_reward
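    # record input/output names so the export can bind tensors by name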
self.input_names = ["prev_token", "target_token", "timestep"]
for i in range(len(self.models)):
self.input_names.append(f"fixed_input_{i}")
if possible_translation_tokens is not None:
self.input_names.append("possible_translation_tokens")
outputs = [score]
self.output_names = ["score"]
for i in range(len(self.models)):
self.output_names.append(f"fixed_input_{i}")
outputs.append(inputs[i])
if possible_translation_tokens is not None:
self.output_names.append("possible_translation_tokens")
outputs.append(possible_translation_tokens)
for i, state in enumerate(state_outputs):
outputs.append(state)
self.output_names.append(f"state_output_{i}")
self.input_names.append(f"state_input_{i}")
return tuple(outputs)