sockeye/beam_search.py
def forward(self,
source: np.ndarray,
source_length: np.ndarray,
restrict_lexicon: Optional[lexicon.TopKLexicon],
raw_constraint_list: List[Optional[constrained.RawConstraintList]],
raw_avoid_list: List[Optional[constrained.RawConstraintList]],
max_output_lengths: np.ndarray) -> Tuple[np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
List[Optional[np.ndarray]],
List[Optional[constrained.ConstrainedHypothesis]]]:
"""
Translates multiple sentences using beam search.
:param source: Source ids. Shape: (batch_size, bucket_key, num_factors).
:param source_length: Valid source lengths. Shape: (batch_size,).
:param restrict_lexicon: Lexicon to use for vocabulary restriction.
:param raw_constraint_list: A list of optional lists containing phrases (as lists of target word IDs)
that must appear in each output.
:param raw_avoid_list: A list of optional lists containing phrases (as lists of target word IDs)
that must NOT appear in each output.
:param max_output_lengths: ndarray of maximum output lengths per input in source.
Shape: (batch_size,). Dtype: int32.
:return: List of best hypotheses indices, list of best word indices,
array of accumulated length-normalized negative log-probs, hypotheses lengths,
predicted lengths of references (if any), constraints (if any).
"""
batch_size = source.shape[0]
logger.debug("beam_search batch size: %d", batch_size)
# Maximum beam search iterations (determined by longest input with eos)
max_iterations = max_output_lengths.max().item()
logger.debug("max beam search iterations: %d", max_iterations)
sample_best_hyp_indices = None
if self._sample is not None:
utils.check_condition(restrict_lexicon is None,
"Sampling is not available when working with a restricted lexicon.")
sample_best_hyp_indices = np.arange(0, batch_size * self.beam_size, dtype='int32', ctx=self.context)
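# Note: when sampling, rows are not re-ranked across the beam, so the hypothesis indices
# passed to (and returned by) self._sample are simply the identity mapping over the
# (batch_size * beam_size) rows.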
# General data structure: (batch_size * beam_size) rows in total:
# the full beam for the first sentence, followed by the beam block for the next sentence, and so on.
# best_word_indices (also act as decoder input). Shape: (batch_size * beam_size, num_target_factors)
best_word_indices = np.full((batch_size * self.beam_size, self.num_target_factors),
fill_value=self.bos_id, ctx=self.context, dtype='int32')
# offset for hypothesis indices in batch decoding
offset = np.repeat(np.arange(0, batch_size * self.beam_size, self.beam_size,
dtype='int32', ctx=self.context), self.beam_size)
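# e.g. batch_size=2, beam_size=3 -> offset = [0, 0, 0, 3, 3, 3]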
# locations of each batch item when first dimension is (batch * beam)
batch_indices = np.arange(0, batch_size * self.beam_size, self.beam_size, dtype='int32', ctx=self.context)
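# first_step_mask is added to the scores at t==1 so that only the single active row per
# sentence (the rows indexed by batch_indices) can be extended; the remaining
# beam_size - 1 rows start at +inf.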
first_step_mask = np.full((batch_size * self.beam_size, 1),
fill_value=np.inf, ctx=self.context, dtype=self.dtype)
first_step_mask[batch_indices] = 0.0
# Best word and hypotheses indices across beam search steps from topk operation.
best_hyp_indices_list = [] # type: List[np.ndarray]
best_word_indices_list = [] # type: List[np.ndarray]
lengths = np.zeros((batch_size * self.beam_size, 1), ctx=self.context, dtype='int32')
finished = np.zeros((batch_size * self.beam_size, 1), ctx=self.context, dtype='int32')
# Extending max_output_lengths to shape (batch_size * beam_size, 1)
max_output_lengths = np.repeat(np.expand_dims(max_output_lengths, axis=1), self.beam_size, axis=0)
# scores_accumulated: smallest scores chosen from `scores` at each step (ascending, i.e. lower is better).
scores_accumulated = np.zeros((batch_size * self.beam_size, 1), ctx=self.context, dtype=self.dtype)
output_vocab_size = self.output_vocab_size
# If using a top-k lexicon, select param rows for logit computation that correspond to the
# target vocab for these sentences.
vocab_slice_ids = None  # type: Optional[np.ndarray]
if restrict_lexicon:
source_words = np.squeeze(np.split(source, self.num_source_factors, axis=2)[0], axis=2)
vocab_slice_ids, output_vocab_size, raw_constraint_list = _get_vocab_slice_ids(restrict_lexicon,
source_words,
raw_constraint_list,
self.eos_id, beam_size=1)
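# Auxiliary score distributions consumed by _update_scores (see step (2) below):
# - pad_dist: +inf everywhere except column zero, applied to finished rows so that only the
#   accumulated score in column zero survives.
# - eos_dist: +inf everywhere except the EOS column, presumably used to force <eos> once a
#   hypothesis reaches its maximum output length.
# - unk_dist: adds +inf to the <unk> column when UNK generation is disabled.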
pad_dist = np.full((batch_size * self.beam_size, output_vocab_size - 1),
fill_value=np.inf, ctx=self.context, dtype=self.dtype)
eos_dist = np.full((batch_size * self.beam_size, output_vocab_size),
fill_value=np.inf, ctx=self.context, dtype=self.dtype)
eos_dist[:, C.EOS_ID] = 0
unk_dist = None
if self.prevent_unk:
unk_dist = np.zeros_like(eos_dist)
unk_dist[:, C.UNK_ID] = np.inf # pylint: disable=E1137
# Initialize the beam to track constraint sets where target-side lexical constraints are present
constraints = constrained.init_batch(raw_constraint_list, self.beam_size, self.bos_id, self.eos_id)
if self.global_avoid_trie or any(raw_avoid_list):
avoid_states = constrained.AvoidBatch(batch_size, self.beam_size,
avoid_list=raw_avoid_list,
global_avoid_trie=self.global_avoid_trie)
avoid_states.consume(best_word_indices[:, 0]) # constraints operate only on primary target factor
# (0) encode source sentence, returns a list
model_states, estimated_reference_lengths = self._inference.encode_and_initialize(source, source_length)
# repeat states to beam_size
model_states = _repeat_states(model_states, self.beam_size, self._inference.state_structure())
# repeat estimated_reference_lengths to shape (batch_size * beam_size, 1)
estimated_reference_lengths = np.repeat(estimated_reference_lengths, self.beam_size, axis=0)
# Records items in the beam that are inactive. At the beginning (t==1), there is only one valid or active
# item on the beam for each sentence
inactive = np.zeros((batch_size * self.beam_size, 1), dtype='int32', ctx=self.context)
t = 1
for t in range(1, max_iterations + 1):  # the upper bound is exclusive, so +1 yields exactly max_iterations steps
# (1) obtain next predictions and advance models' state
# target_dists: (batch_size * beam_size, target_vocab_size)
target_dists, model_states, target_factors = self._inference.decode_step(best_word_indices,
model_states,
vocab_slice_ids)
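# target_factors holds predictions for any secondary target factors (presumably None for
# single-factor models); they are merged into best_word_indices during sorting below.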
# (2) Produces the accumulated cost of target words in each row.
# There is special treatment for finished and inactive rows: inactive rows are inf everywhere;
# finished rows are inf everywhere except column zero, which holds the accumulated model score
scores, lengths = self._update_scores(target_dists,
finished,
inactive,
scores_accumulated,
lengths,
max_output_lengths,
unk_dist,
pad_dist,
eos_dist)
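# scores: (batch_size * beam_size, output_vocab_size) accumulated costs per candidate extension;
# lengths: hypothesis lengths, incremented for rows that are not yet finished.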
# Mark entries that should be blocked as having a score of np.inf
if self.global_avoid_trie or any(raw_avoid_list):
block_indices = avoid_states.avoid()
if len(block_indices) > 0:
scores[block_indices] = np.inf
if self._sample is not None:
target_dists[block_indices] = np.inf
# (3) Get beam_size winning hypotheses for each sentence block separately. Only look as
# far as the active beam size for each sentence.
if self._sample is not None:
best_hyp_indices, best_word_indices, scores_accumulated = self._sample(scores,
target_dists,
finished,
sample_best_hyp_indices)
else:
# On the first timestep, all hypotheses have identical histories, so force topk() to choose extensions
# of the first row only by setting all other rows to inf
if t == 1:
scores += first_step_mask
best_hyp_indices, best_word_indices, scores_accumulated = self._top(scores, offset)
# Constraints for constrained decoding are processed sentence by sentence
if any(raw_constraint_list):
best_hyp_indices, best_word_indices, scores_accumulated, constraints, inactive = constrained.topk(
t,
batch_size,
self.beam_size,
inactive,
scores,
constraints,
best_hyp_indices,
best_word_indices,
scores_accumulated)
# Map from restricted to full vocab ids if needed
if restrict_lexicon:
best_word_indices = np.take(vocab_slice_ids, best_word_indices, axis=0)
# (4) Normalize the scores of newly finished hypotheses. Note that after this until the
# next call to topk(), hypotheses may not be in sorted order.
_sort_inputs = [best_hyp_indices, best_word_indices, finished, scores_accumulated, lengths,
estimated_reference_lengths]
if target_factors is not None:
_sort_inputs.append(target_factors)
best_word_indices, finished, scores_accumulated, lengths, estimated_reference_lengths = \
self._sort_norm_and_update_finished(*_sort_inputs)
# Collect best hypotheses, best word indices
best_word_indices_list.append(best_word_indices)
best_hyp_indices_list.append(best_hyp_indices)
if self._should_stop(finished, batch_size):
break
# (5) update models' state with winning hypotheses (ascending)
model_states = self._sort_states(best_hyp_indices, *model_states)
logger.debug("Finished after %d out of %d steps.", t, max_iterations)
# (6) Sort the hypotheses within each sentence (normalization for finished hyps may have unsorted them).
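# Per sentence: argsort over the folded (batch_size, beam_size) view gives the beam positions
# in ascending score order; unravel_index plus the precomputed offset maps those positions
# back to row indices in the flat (batch_size * beam_size) layout.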
scores_accumulated_shape = scores_accumulated.shape
folded_accumulated_scores = scores_accumulated.reshape((batch_size, -1))
indices = np.argsort(folded_accumulated_scores.astype('float32', copy=False), axis=1).reshape((-1,))
best_hyp_indices = np.unravel_index(indices, scores_accumulated_shape)[0].astype('int32') + offset
scores_accumulated = scores_accumulated.take(best_hyp_indices, axis=0)
best_hyp_indices_list.append(best_hyp_indices)
lengths = lengths.take(best_hyp_indices, axis=0)
all_best_hyp_indices = np.stack(best_hyp_indices_list, axis=1)
all_best_word_indices = np.stack(best_word_indices_list, axis=2)
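# all_best_hyp_indices: (batch_size * beam_size, num_steps + 1),
# all_best_word_indices: (batch_size * beam_size, num_target_factors, num_steps);
# presumably used by the caller to backtrack the best hypotheses step by step.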
constraints = [constraints[x] for x in best_hyp_indices.tolist()]
return all_best_hyp_indices, \
all_best_word_indices, \
scores_accumulated, \
lengths.astype('int32', copy=False), \
estimated_reference_lengths, \
constraints