def forward()

in pytorch_translate/vocab_reduction.py


    def forward(self, src_tokens, encoder_output=None, decoder_input_tokens=None):
        assert self.dst_dict.pad() == 0, (
            f"VocabReduction only works correctly when the padding ID is 0 "
            "(to ensure its position in possible_translation_tokens is also 0), "
            f"instead of {self.dst_dict.pad()}."
        )
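        # Start the candidate list with the padding ID so that, after the
        # sorted torch.unique call below, it remains at index 0 of the output.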
        vocab_list = [src_tokens.new_tensor([self.dst_dict.pad()])]

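        # During training, include the (flattened) decoder input tokens so that
        # every reference token is guaranteed to appear in the reduced vocab.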
        if decoder_input_tokens is not None:
            flat_decoder_input_tokens = decoder_input_tokens.view(-1)
            vocab_list.append(flat_decoder_input_tokens)

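        # Look up the precomputed lexical translation candidates for every
        # source token in the batch (rows of translation_candidates are
        # indexed by source token ID).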
        if self.translation_candidates is not None:
            reduced_vocab = self.translation_candidates.index_select(
                dim=0, index=src_tokens.view(-1)
            ).view(-1)
            vocab_list.append(reduced_vocab)
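
        # Always include the first num_top_words target IDs, i.e. the most
        # frequent target words assuming a frequency-sorted dictionary.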
        if (
            self.vocab_reduction_params is not None
            and self.vocab_reduction_params["num_top_words"] > 0
        ):
            top_words = torch.arange(
                self.vocab_reduction_params["num_top_words"],
                device=vocab_list[0].device,
            ).long()
            vocab_list.append(top_words)

        # Get bag of words predicted by word predictor
        if self.predictor is not None:
            assert encoder_output is not None
            pred_output = self.predictor(encoder_output)
            # [batch, k]
            topk_indices = self.predictor.get_topk_predicted_tokens(
                pred_output, src_tokens, log_probs=True
            )
            # flatten indices for the entire batch: [batch * k]
            topk_indices = topk_indices.view(-1)
            vocab_list.append(topk_indices.detach())

        all_translation_tokens = torch.cat(vocab_list, dim=0)
        possible_translation_tokens = torch.unique(
            all_translation_tokens,
            # Sorting helps ensure that the padding ID (0) remains in position 0.
            sorted=True,
            # The decoder_input_tokens used here are very close to the targets
            # tokens that we also need to map to the reduced vocab space later
            # on, except that decoder_input_tokens have <eos> prepended, while
            # the targets will have <eos> at the end of the sentence. This
            # prevents us from being able to directly use the inverse indices
            # that torch.unique can return.
            return_inverse=False,
        ).type_as(src_tokens)

        # Pad to a multiple of 8 to ensure training with fp16 will activate
        # NVIDIA Tensor Cores.
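        # (For example, a reduced vocab of 1003 entries gets 5 extra padding
        # IDs appended to reach 1008.)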
        len_mod_eight = possible_translation_tokens.shape[0] % 8
        if self.training and self.fp16 and len_mod_eight != 0:
            possible_translation_tokens = torch.cat(
                [
                    possible_translation_tokens,
                    possible_translation_tokens.new_tensor(
                        [self.dst_dict.pad()] * (8 - len_mod_eight)
                    ),
                ]
            )

        return possible_translation_tokens
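
For context, here is a minimal standalone sketch (not part of pytorch_translate; remap_targets and full_vocab_size are illustrative names) of how the returned possible_translation_tokens can be used to map full-vocabulary target IDs to positions within the reduced vocab, which is the separate mapping step alluded to in the comment on torch.unique above:

    import torch

    def remap_targets(targets, possible_translation_tokens, full_vocab_size):
        # Lookup table: full-vocabulary ID -> position in the reduced vocab.
        # Unselected IDs fall back to 0 (the padding slot); during training
        # this should not occur because the decoder input tokens were added
        # to the reduced vocab above.
        mapping = torch.zeros(
            full_vocab_size,
            dtype=torch.long,
            device=possible_translation_tokens.device,
        )
        mapping[possible_translation_tokens] = torch.arange(
            possible_translation_tokens.numel(),
            device=possible_translation_tokens.device,
        )
        return mapping[targets]

    # Toy example: reduced vocab {0, 4, 7, 9} inside a full vocab of size 12.
    reduced = torch.tensor([0, 4, 7, 9])
    targets = torch.tensor([7, 9, 0, 4])
    print(remap_targets(targets, reduced, 12))  # tensor([2, 3, 0, 1])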