def forward()

in sockeye/beam_search_pt.py


    def forward(self,
                source: pt.Tensor,
                source_length: pt.Tensor,
                restrict_lexicon: Optional[lexicon.TopKLexicon],
                max_output_lengths: pt.Tensor) -> SearchResult:
        """
        Translates multiple sentences using beam search.

        :param source: Source ids. Shape: (batch_size, bucket_key, num_factors).
        :param source_length: Valid source lengths. Shape: (batch_size,).
        :param restrict_lexicon: Lexicon to use for vocabulary restriction.
        :param max_output_lengths: Tensor of maximum output lengths per input in source.
                Shape: (batch_size,). Dtype: int32.
        :return: SearchResult.
        """
        batch_size = source.size()[0]
        logger.debug("beam_search batch size: %d", batch_size)

        # Maximum beam search iterations (determined by longest input with eos)
        max_iterations = int(max_output_lengths.max().item())
        logger.debug("max beam search iterations: %d", max_iterations)

        if self._sample is not None:
            utils.check_condition(restrict_lexicon is None, "restricted lexicon not available when sampling.")

        # General data structure: batch_size * beam_size blocks in total;
        # a full beam for each sentence, followed by the next beam-block for the next sentence and so on

        # best word indices (also act as input). Shape: (batch * beam, num_target_factors)
        best_word_indices = pt.full((batch_size * self.beam_size, self.num_target_factors),
                                    fill_value=self.bos_id, device=self.device, dtype=pt.int32)

        # offset for hypothesis indices in batch decoding
        offset = pt.arange(0, batch_size * self.beam_size, self.beam_size,
                           dtype=pt.int32, device=self.device).repeat_interleave(self.beam_size)
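        # Illustrative example (not from the original source): with batch_size=2 and beam_size=3,
        # offset == [0, 0, 0, 3, 3, 3], so adding per-sentence hypothesis ranks to it yields flat
        # row indices into the (batch_size * beam_size) layout described above.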

        # locations of each batch item when first dimension is (batch * beam)
        batch_indices = pt.arange(0, batch_size * self.beam_size, self.beam_size, dtype=pt.int64, device=self.device)
        first_step_mask = pt.full((batch_size * self.beam_size, 1), fill_value=onp.inf, device=self.device, dtype=self.dtype)
        first_step_mask[batch_indices] = 0.0
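        # Illustrative: with batch_size=2 and beam_size=3, batch_indices == [0, 3] and
        # first_step_mask == [[0], [inf], [inf], [0], [inf], [inf]], so on the first step only the
        # first row of each beam block can be selected by topk (see step (3) below).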

        # Best word and hypotheses indices across beam search steps from topk operation.
        best_hyp_indices_list = []  # type: List[pt.Tensor]
        best_word_indices_list = []  # type: List[pt.Tensor]

        lengths = pt.zeros(batch_size * self.beam_size, device=self.device, dtype=pt.int32)
        finished = pt.zeros(batch_size * self.beam_size, device=self.device, dtype=pt.bool)

        # Extending max_output_lengths to shape (batch_size * beam_size,)
        max_output_lengths = max_output_lengths.repeat_interleave(self.beam_size, dim=0)
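        # e.g. max_output_lengths == [10, 12] with beam_size == 2 becomes [10, 10, 12, 12]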

        # scores_accumulated: chosen smallest scores in scores (ascending).
        scores_accumulated = pt.zeros(batch_size * self.beam_size, 1, device=self.device, dtype=self.dtype)
        # Accumulated (greedily chosen) factor scores. Factor scores are not normalized by length.
        # TODO: Consider joint tensor for all target factors
        # Embedded in a list to efficiently assign return values and avoid if-branching
        factor_scores_accumulated = [pt.zeros(batch_size * self.beam_size, self.num_target_factors - 1,
                                              device=self.device, dtype=self.dtype)]

        output_vocab_size = self.output_vocab_size

        # If using a top-k lexicon, select param rows for logit computation that correspond to the
        # target vocab for this sentence.
        vocab_slice_ids = None  # type: Optional[pt.Tensor]
        if restrict_lexicon:
            source_words = source[:, :, 0]
            vocab_slice_ids, output_vocab_size = _get_vocab_slice_ids(restrict_lexicon, source_words, self.eos_id,
                                                                      beam_size=1)

        pad_dist = pt.full((1, output_vocab_size), fill_value=onp.inf, device=self.device, dtype=self.dtype)
        pad_dist[0, 0] = 0  # [0, inf, inf, ...]
        eos_dist = pt.full((1, output_vocab_size),
                           fill_value=onp.inf, device=self.device, dtype=self.dtype)
        eos_dist[:, C.EOS_ID] = 0
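        # Both masks are consumed by _update_scores in step (2) below: pad_dist keeps finished rows
        # at constant cost (only column 0, the PAD id, is finite), while eos_dist presumably forces
        # EOS once a hypothesis reaches its maximum output length.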

        # (0) encode source sentence, returns a list
        model_states, estimated_reference_lengths = self._inference.encode_and_initialize(source, source_length)
        # repeat states to beam_size
        if self._traced_repeat_states is None:
            logger.debug("Tracing repeat_states")
            self._traced_repeat_states = pt.jit.trace(self._repeat_states, model_states, strict=False)
        model_states = self._traced_repeat_states(*model_states)
        # repeat estimated_reference_lengths to shape (batch_size * beam_size)
        estimated_reference_lengths = estimated_reference_lengths.repeat_interleave(self.beam_size, dim=0)
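        # After repeating, each state row and reference length appears beam_size times, matching the
        # (batch_size * beam_size) hypothesis layout, e.g. [s0, s0, s0, s1, s1, s1] for batch_size=2
        # and beam_size=3 (assuming _repeat_states repeats along the batch dimension).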

        t = 1  # defined before the loop so the debug log below is valid even if max_iterations == 0
        for t in range(1, max_iterations + 1):  # max_iterations + 1 required to get correct results
            # (1) obtain next predictions and advance models' state
            # target_dists: (batch_size * beam_size, target_vocab_size)
            # target_factors: (batch_size * beam_size, num_secondary_factors, 2),
            # where last dimension holds indices and scores
            target_dists, model_states, target_factors = self._inference.decode_step(best_word_indices,
                                                                                     model_states,
                                                                                     vocab_slice_ids)

            # (2) Produces the accumulated cost of target words in each row.
            # There is special treatment for finished rows.
            # Finished rows are inf everywhere except column zero, which holds the accumulated model score
            scores, lengths = self._update_scores(target_dists, finished, scores_accumulated,
                                                  lengths, max_output_lengths, pad_dist, eos_dist)
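            # Illustrative: a finished row with accumulated score 4.2 becomes [4.2, inf, inf, ...],
            # so topk can only extend it with PAD (column 0) at no additional cost.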

            # (3) Get beam_size winning hypotheses for each sentence block separately. Only look as
            # far as the active beam size for each sentence.
            if self._sample is not None:
                best_hyp_indices, best_word_indices, scores_accumulated = self._sample(scores, target_dists, finished)
            else:
                # On the first timestep, all hypotheses have identical histories, so force topk() to choose extensions
                # of the first row only by setting all other rows to inf
                if t == 1:
                    scores += first_step_mask

                if self._traced_top is None:
                    logger.debug("Tracing _top")
                    self._traced_top = pt.jit.trace(self._top, (scores, offset))
                best_hyp_indices, best_word_indices, scores_accumulated = self._traced_top(scores, offset)
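            # As used below: best_hyp_indices are flat row indices into the (batch * beam) layout,
            # best_word_indices index the (possibly restricted) target vocabulary, and
            # scores_accumulated holds the accumulated scores of the selected hypotheses.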

            # Map from restricted to full vocab ids if needed
            if restrict_lexicon:
                best_word_indices = vocab_slice_ids.index_select(0, best_word_indices)
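                # e.g. with vocab_slice_ids == [0, 3, 5, ...], a restricted index of 2 maps back to
                # full-vocabulary id 5 (index_select gathers vocab_slice_ids[best_word_indices])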

            # (4) Normalize the scores of newly finished hypotheses. Note that after this until the
            # next call to topk(), hypotheses may not be in sorted order.
            _sort_inputs = [best_hyp_indices, best_word_indices, finished, scores_accumulated, lengths,
                            estimated_reference_lengths]
            if self.num_target_factors > 1:
                _sort_inputs += [target_factors, *factor_scores_accumulated]
            if self._traced_sort_norm_and_update_finished is None:
                self._traced_sort_norm_and_update_finished = pt.jit.trace(self._sort_norm_and_update_finished,
                                                                          _sort_inputs)
            best_word_indices, finished, \
            (scores_accumulated, *factor_scores_accumulated), \
            lengths, estimated_reference_lengths = self._traced_sort_norm_and_update_finished(*_sort_inputs)

            # Collect best hypotheses, best word indices
            best_word_indices_list.append(best_word_indices)
            best_hyp_indices_list.append(best_hyp_indices)

            if self._should_stop(finished, batch_size):
                break

            # (5) update models' state with winning hypotheses (ascending)
            if self._traced_sort_states is None:
                logger.debug("Tracing sort_states")
                self._traced_sort_states = pt.jit.trace(self._sort_states, (best_hyp_indices, *model_states))
            model_states = self._traced_sort_states(best_hyp_indices, *model_states)
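            # Gathering state rows by best_hyp_indices keeps each decoder state aligned with the
            # hypothesis that now occupies its beam slot (assuming _sort_states indexes the batch
            # dimension of every state tensor).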

        logger.debug("Finished after %d out of %d steps.", t, max_iterations)

        # (6) Sort the hypotheses within each sentence (normalization for finished hyps may have unsorted them).
        folded_accumulated_scores = scores_accumulated.reshape(batch_size, self.beam_size)
        indices = folded_accumulated_scores.argsort(dim=1, descending=False).reshape(-1)
        # scores_accumulated has a single column here, so the floor division by its column count (1)
        # is a no-op; indices are the per-sentence ranks, converted to int before adding the offset.
        best_hyp_indices = indices.div(1, rounding_mode='floor').int() + offset
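        # Illustrative: for batch_size=2 and beam_size=3, per-sentence ranks such as [2, 0, 1, 1, 0, 2]
        # plus offset [0, 0, 0, 3, 3, 3] give flat row indices [2, 0, 1, 4, 3, 5].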
        scores_accumulated = scores_accumulated.index_select(0, best_hyp_indices)
        if self.num_target_factors > 1:
            accumulated_factor_scores = factor_scores_accumulated[0].index_select(0, best_hyp_indices)
            # (batch*beam, num_target_factors)
            scores_accumulated = pt.cat((scores_accumulated, accumulated_factor_scores), dim=1)

        best_hyp_indices_list.append(best_hyp_indices)
        lengths = lengths.index_select(0, best_hyp_indices)
        all_best_hyp_indices = pt.stack(best_hyp_indices_list, dim=1)
        all_best_word_indices = pt.stack(best_word_indices_list, dim=2)
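        # Resulting shapes: all_best_hyp_indices is (batch * beam, num_steps + 1), the extra column
        # coming from the final append above, and all_best_word_indices is
        # (batch * beam, num_target_factors, num_steps).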

        return SearchResult(best_hyp_indices=all_best_hyp_indices,
                            best_word_indices=all_best_word_indices,
                            accumulated_scores=scores_accumulated,
                            lengths=lengths,
                            estimated_reference_lengths=estimated_reference_lengths)