in pytorch_translate/vocab_reduction.py [0:0]
def forward(self, src_tokens, encoder_output=None, decoder_input_tokens=None):
    assert self.dst_dict.pad() == 0, (
        f"VocabReduction only works correctly when the padding ID is 0 "
        "(to ensure its position in possible_translation_tokens is also 0), "
        f"instead of {self.dst_dict.pad()}."
    )
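
    # Seed the candidate list with the padding ID so that pad is always in the
    # reduced vocabulary and, being 0, ends up at index 0 after the sorted
    # torch.unique below.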
    vocab_list = [src_tokens.new_tensor([self.dst_dict.pad()])]
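
    # During training, every token fed to the decoder (teacher forcing) must
    # be representable in the reduced vocabulary.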
    if decoder_input_tokens is not None:
        flat_decoder_input_tokens = decoder_input_tokens.view(-1)
        vocab_list.append(flat_decoder_input_tokens)
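
    # Look up precomputed translation candidates for each source token; the
    # candidates table is indexed by source token ID.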
    if self.translation_candidates is not None:
        reduced_vocab = self.translation_candidates.index_select(
            dim=0, index=src_tokens.view(-1)
        ).view(-1)
        vocab_list.append(reduced_vocab)
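
    # Optionally include the first num_top_words IDs outright; in a
    # frequency-sorted target dictionary these cover the special symbols and
    # the most common words.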
    if (
        self.vocab_reduction_params is not None
        and self.vocab_reduction_params["num_top_words"] > 0
    ):
        top_words = torch.arange(
            self.vocab_reduction_params["num_top_words"],
            device=vocab_list[0].device,
        ).long()
        vocab_list.append(top_words)

    # Get the bag of words predicted by the word predictor.
    if self.predictor is not None:
        assert encoder_output is not None
        pred_output = self.predictor(encoder_output)
        # [batch, k]
        topk_indices = self.predictor.get_topk_predicted_tokens(
            pred_output, src_tokens, log_probs=True
        )
        # Flatten the indices for the entire batch: [batch * k].
        topk_indices = topk_indices.view(-1)
        vocab_list.append(topk_indices.detach())
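
    # Merge every candidate source into one flat tensor, then deduplicate and
    # sort it below.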
    all_translation_tokens = torch.cat(vocab_list, dim=0)
    possible_translation_tokens = torch.unique(
        all_translation_tokens,
        # Sorting helps ensure that the padding ID (0) remains in position 0.
        sorted=True,
        # The decoder_input_tokens used here are very close to the target
        # tokens that we also need to map to the reduced vocab space later
        # on, except that decoder_input_tokens have <eos> prepended, while
        # the targets have <eos> at the end of the sentence. This prevents
        # us from directly using the inverse indices that torch.unique can
        # return (see the target-remapping sketch after this function).
        return_inverse=False,
    ).type_as(src_tokens)

    # Pad to a multiple of 8 to ensure training with fp16 will activate
    # NVIDIA Tensor Cores.
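    # For example, a reduced vocabulary of 1237 entries gets 3 extra pad IDs
    # appended to reach 1240, a multiple of 8.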
    len_mod_eight = possible_translation_tokens.shape[0] % 8
    if self.training and self.fp16 and len_mod_eight != 0:
        possible_translation_tokens = torch.cat(
            [
                possible_translation_tokens,
                possible_translation_tokens.new_tensor(
                    [self.dst_dict.pad()] * (8 - len_mod_eight)
                ),
            ]
        )

    return possible_translation_tokens
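
The returned possible_translation_tokens tensor is typically consumed in two
places downstream: the decoder's output projection, which only needs the
selected rows of the full softmax weight matrix, and the loss computation,
which needs the gold targets remapped from full-vocabulary IDs to positions
inside possible_translation_tokens (the remapping the torch.unique comment
above refers to). The sketch below illustrates both steps; it is not the
project's actual implementation. decoder_state, output_embed_weight, and
full_vocab_size are hypothetical names, and the remapping assumes every real
target ID made it into the candidate set, which the decoder_input_tokens
logic above is designed to guarantee during training.

# Hypothetical downstream usage; not part of vocab_reduction.py.
import torch


def reduced_output_logits(
    decoder_state, output_embed_weight, possible_translation_tokens
):
    # Keep only the output-embedding rows we may predict:
    # [reduced_vocab_size, embed_dim] instead of [full_vocab_size, embed_dim].
    reduced_weight = output_embed_weight.index_select(
        dim=0, index=possible_translation_tokens
    )
    # decoder_state: [batch, seq_len, embed_dim] -> logits over the reduced vocab.
    return torch.matmul(decoder_state, reduced_weight.t())


def remap_targets(targets, possible_translation_tokens, full_vocab_size):
    # Build a full-vocab -> reduced-vocab index table, then look the targets up.
    # IDs absent from possible_translation_tokens fall back to 0, the padding
    # position, which the loss is expected to ignore.
    mapping = possible_translation_tokens.new_zeros(full_vocab_size)
    mapping[possible_translation_tokens] = torch.arange(
        possible_translation_tokens.numel(),
        device=possible_translation_tokens.device,
    )
    return mapping[targets]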