in sockeye/beam_search_pt.py [0:0]
def forward(self,
            source: pt.Tensor,
            source_length: pt.Tensor,
            restrict_lexicon: Optional[lexicon.TopKLexicon],
            max_output_lengths: pt.Tensor) -> SearchResult:
    """
    Translates multiple sentences using beam search.

    :param source: Source ids. Shape: (batch_size, bucket_key, num_factors).
    :param source_length: Valid source lengths. Shape: (batch_size,).
    :param restrict_lexicon: Lexicon to use for vocabulary restriction.
    :param max_output_lengths: Tensor of maximum output lengths per input in source.
           Shape: (batch_size,). Dtype: int32.
    :return: SearchResult.
    """
    batch_size = source.size()[0]
    logger.debug("beam_search batch size: %d", batch_size)

    # Maximum beam search iterations (determined by longest input with eos)
    max_iterations = int(max_output_lengths.max().item())
    logger.debug("max beam search iterations: %d", max_iterations)

    if self._sample is not None:
        utils.check_condition(restrict_lexicon is None, "restricted lexicon not available when sampling.")

    # General data structure: batch_size * beam_size blocks in total;
    # a full beam for each sentence, followed by the next beam-block for the next sentence, and so on.
    # best_word_indices (also acts as the next decode_step input). Shape: (batch*beam, num_target_factors)
    best_word_indices = pt.full((batch_size * self.beam_size, self.num_target_factors),
                                fill_value=self.bos_id, device=self.device, dtype=pt.int32)

    # offset for hypothesis indices in batch decoding
    offset = pt.arange(0, batch_size * self.beam_size, self.beam_size,
                       dtype=pt.int32, device=self.device).repeat_interleave(self.beam_size)
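    # e.g. batch_size=2, beam_size=3 -> offset = [0, 0, 0, 3, 3, 3]; adding offset to per-sentence
    # hypothesis indices yields row indices into the flat (batch * beam) layout.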

    # locations of each batch item when first dimension is (batch * beam)
    batch_indices = pt.arange(0, batch_size * self.beam_size, self.beam_size, dtype=pt.int64, device=self.device)
    first_step_mask = pt.full((batch_size * self.beam_size, 1), fill_value=onp.inf, device=self.device,
                              dtype=self.dtype)
    first_step_mask[batch_indices] = 0.0
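    # first_step_mask is inf for every row except the first hypothesis of each sentence (the rows
    # selected by batch_indices), so the very first topk cannot select duplicate hypotheses.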

    # Best word and hypotheses indices across beam search steps from topk operation.
    best_hyp_indices_list = []  # type: List[pt.Tensor]
    best_word_indices_list = []  # type: List[pt.Tensor]

    lengths = pt.zeros(batch_size * self.beam_size, device=self.device, dtype=pt.int32)
    finished = pt.zeros(batch_size * self.beam_size, device=self.device, dtype=pt.bool)

    # Extending max_output_lengths to shape (batch_size * beam_size,)
    max_output_lengths = max_output_lengths.repeat_interleave(self.beam_size, dim=0)

    # scores_accumulated: chosen smallest scores in scores (ascending).
    scores_accumulated = pt.zeros(batch_size * self.beam_size, 1, device=self.device, dtype=self.dtype)

    # Accumulated (greedily chosen) factor scores. Factor scores are not normalized by length.
    # TODO: Consider joint tensor for all target factors
    # Embedded in a list to efficiently assign return values and avoid if-branching
    factor_scores_accumulated = [pt.zeros(batch_size * self.beam_size, self.num_target_factors - 1,
                                          device=self.device, dtype=self.dtype)]

    output_vocab_size = self.output_vocab_size

    # If using a top-k lexicon, select param rows for logit computation that correspond to the
    # target vocab for this sentence.
    vocab_slice_ids = None  # type: Optional[pt.Tensor]
    if restrict_lexicon:
        source_words = source[:, :, 0]
        vocab_slice_ids, output_vocab_size = _get_vocab_slice_ids(restrict_lexicon, source_words, self.eos_id,
                                                                  beam_size=1)

    pad_dist = pt.full((1, output_vocab_size), fill_value=onp.inf, device=self.device, dtype=self.dtype)
    pad_dist[0, 0] = 0  # [0, inf, inf, ...]
    eos_dist = pt.full((1, output_vocab_size),
                       fill_value=onp.inf, device=self.device, dtype=self.dtype)
    eos_dist[:, C.EOS_ID] = 0
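    # pad_dist lets finished hypotheses be carried forward at no extra cost (only column 0, the pad
    # id, is finite); eos_dist forces </s> for hypotheses that have reached their maximum output
    # length. Both are applied inside self._update_scores below.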

    # (0) encode source sentence, returns a list
    model_states, estimated_reference_lengths = self._inference.encode_and_initialize(source, source_length)
    # repeat states to beam_size
    if self._traced_repeat_states is None:
        logger.debug("Tracing repeat_states")
        self._traced_repeat_states = pt.jit.trace(self._repeat_states, model_states, strict=False)
    model_states = self._traced_repeat_states(*model_states)
    # repeat estimated_reference_lengths to shape (batch_size * beam_size)
    estimated_reference_lengths = estimated_reference_lengths.repeat_interleave(self.beam_size, dim=0)
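    # From here on, every per-hypothesis quantity (scores, lengths, finished, model states, ...) uses
    # the flat (batch_size * beam_size, ...) layout established above.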

    t = 1
    for t in range(1, max_iterations + 1):  # max_iterations + 1 required to get correct results

        # (1) obtain next predictions and advance models' state
        # target_dists: (batch_size * beam_size, target_vocab_size)
        # target_factors: (batch_size * beam_size, num_secondary_factors, 2),
        # where the last dimension holds indices and scores
        target_dists, model_states, target_factors = self._inference.decode_step(best_word_indices,
                                                                                 model_states,
                                                                                 vocab_slice_ids)

        # (2) Produce the accumulated cost of target words in each row.
        # Finished rows get special treatment: they are inf everywhere except
        # column zero, which holds the accumulated model score.
        scores, lengths = self._update_scores(target_dists, finished, scores_accumulated,
                                              lengths, max_output_lengths, pad_dist, eos_dist)
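        # scores now holds, per row, the accumulated cost of extending that hypothesis with each
        # vocabulary item. Shape: (batch_size * beam_size, output_vocab_size).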

        # (3) Get beam_size winning hypotheses for each sentence block separately. Only look as
        # far as the active beam size for each sentence.
        if self._sample is not None:
            best_hyp_indices, best_word_indices, scores_accumulated = self._sample(scores, target_dists, finished)
        else:
            # On the first timestep, all hypotheses have identical histories, so force topk() to choose extensions
            # of the first row only by setting all other rows to inf
            if t == 1:
                scores += first_step_mask

            if self._traced_top is None:
                logger.debug("Tracing _top")
                self._traced_top = pt.jit.trace(self._top, (scores, offset))
            best_hyp_indices, best_word_indices, scores_accumulated = self._traced_top(scores, offset)
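        # Both branches yield, per surviving hypothesis: the row it extends (best_hyp_indices), the
        # chosen word id (best_word_indices), and its new accumulated score. self._top performs a
        # per-sentence topk over the flattened (beam_size * vocab) scores and uses offset to map the
        # winners back to rows of the flat layout.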

        # Map from restricted to full vocab ids if needed
        if restrict_lexicon:
            best_word_indices = vocab_slice_ids.index_select(0, best_word_indices)

        # (4) Normalize the scores of newly finished hypotheses. Note that after this until the
        # next call to topk(), hypotheses may not be in sorted order.
        _sort_inputs = [best_hyp_indices, best_word_indices, finished, scores_accumulated, lengths,
                        estimated_reference_lengths]
        if self.num_target_factors > 1:
            _sort_inputs += [target_factors, *factor_scores_accumulated]
        if self._traced_sort_norm_and_update_finished is None:
            self._traced_sort_norm_and_update_finished = pt.jit.trace(self._sort_norm_and_update_finished,
                                                                      _sort_inputs)
        best_word_indices, finished, \
            (scores_accumulated, *factor_scores_accumulated), \
            lengths, estimated_reference_lengths = self._traced_sort_norm_and_update_finished(*_sort_inputs)
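        # self._sort_norm_and_update_finished reorders all per-hypothesis tensors according to
        # best_hyp_indices, length-normalizes the scores of hypotheses that just finished, and
        # refreshes the finished mask.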

        # Collect best hypotheses, best word indices
        best_word_indices_list.append(best_word_indices)
        best_hyp_indices_list.append(best_hyp_indices)

        if self._should_stop(finished, batch_size):
            break

        # (5) update models' state with winning hypotheses (ascending)
        if self._traced_sort_states is None:
            logger.debug("Tracing sort_states")
            self._traced_sort_states = pt.jit.trace(self._sort_states, (best_hyp_indices, *model_states))
        model_states = self._traced_sort_states(best_hyp_indices, *model_states)

    logger.debug("Finished after %d out of %d steps.", t, max_iterations)

    # (9) Sort the hypotheses within each sentence (normalization for finished hyps may have unsorted them).
    folded_accumulated_scores = scores_accumulated.reshape(batch_size, self.beam_size)
    indices = folded_accumulated_scores.argsort(dim=1, descending=False).reshape(-1)
    # 1 = scores_accumulated.size()[1]
    best_hyp_indices = indices.div(1, rounding_mode='floor').int() + offset
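    # argsort over the per-sentence view produces ranks in [0, beam_size); floor-division by 1 (the
    # number of columns in scores_accumulated) is a no-op kept for generality, and adding offset
    # turns each per-sentence rank back into a row index of the flat (batch * beam) layout.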
    scores_accumulated = scores_accumulated.index_select(0, best_hyp_indices)
    if self.num_target_factors > 1:
        accumulated_factor_scores = factor_scores_accumulated[0].index_select(0, best_hyp_indices)
        # (batch*beam, num_target_factors)
        scores_accumulated = pt.cat((scores_accumulated, accumulated_factor_scores), dim=1)
    best_hyp_indices_list.append(best_hyp_indices)
    lengths = lengths.index_select(0, best_hyp_indices)

    all_best_hyp_indices = pt.stack(best_hyp_indices_list, dim=1)
    all_best_word_indices = pt.stack(best_word_indices_list, dim=2)
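    # all_best_hyp_indices: (batch * beam, num_steps + 1); all_best_word_indices:
    # (batch * beam, num_target_factors, num_steps). The caller backtracks through these index
    # traces to assemble the final translations.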

    return SearchResult(best_hyp_indices=all_best_hyp_indices,
                        best_word_indices=all_best_word_indices,
                        accumulated_scores=scores_accumulated,
                        lengths=lengths,
                        estimated_reference_lengths=estimated_reference_lengths)
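
# Illustrative call sketch (not part of the original file): how a fully constructed BeamSearch
# instance is typically invoked, and the shapes of the returned index traces as built by forward()
# above. The tensor names and surrounding setup are assumptions for illustration only.
#
#     result = beam_search(source,                       # (batch, seq_len, num_factors) int ids
#                          source_length,                # (batch,) valid source lengths
#                          restrict_lexicon=None,        # or a lexicon.TopKLexicon instance
#                          max_output_lengths=max_out)   # (batch,) int32
#     # result.best_hyp_indices:  (batch * beam, num_steps + 1)
#     # result.best_word_indices: (batch * beam, num_target_factors, num_steps)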