in sockeye/inference.py [0:0]
def _get_inference_input(self,
                         trans_inputs: List[TranslatorInput]) -> Tuple[np.ndarray,
                                                                       np.ndarray,
                                                                       Optional[lexicon.TopKLexicon],
                                                                       List[Optional[constrained.RawConstraintList]],
                                                                       List[Optional[constrained.RawConstraintList]],
                                                                       np.ndarray]:
"""
Assembles the numerical data for the batch. This comprises an NDArray for the source sentences,
the bucket key (padded source length), and a list of raw constraint lists, one for each sentence in the batch,
an NDArray of maximum output lengths for each sentence in the batch.
Each raw constraint list contains phrases in the form of lists of integers in the target language vocabulary.
:param trans_inputs: List of TranslatorInputs.
:return ndarray of source ids (shape=(batch_size, bucket_key, num_factors)),
ndarray of valid source lengths, lexicon for vocabulary restriction, list of raw constraint
lists, and list of phrases to avoid, and an ndarray of maximum output
lengths.
"""
    batch_size = len(trans_inputs)
    lengths = [len(inp) for inp in trans_inputs]
    max_length = max(len(inp) for inp in trans_inputs)
    # assembling source ids on cpu array (faster) and copy to Translator.context (potentially GPU) in one go below.
    source = np.zeros((batch_size, max_length, self.num_source_factors), dtype=np.float32, ctx=context.cpu())

    restrict_lexicon = None  # type: Optional[lexicon.TopKLexicon]
    raw_constraints = [None] * batch_size  # type: List[Optional[constrained.RawConstraintList]]
    raw_avoid_list = [None] * batch_size  # type: List[Optional[constrained.RawConstraintList]]
    max_output_lengths = []  # type: List[int]

    for j, trans_input in enumerate(trans_inputs):
        num_tokens = len(trans_input)  # includes eos
        max_output_lengths.append(self._get_max_output_length(num_tokens))
        source[j, :num_tokens, 0] = data_io.tokens2ids(trans_input.tokens, self.source_vocabs[0])

        factors = trans_input.factors if trans_input.factors is not None else []
        num_factors = 1 + len(factors)
        if num_factors != self.num_source_factors:
            logger.warning("Input %d factors, but model(s) expect %d", num_factors,
                           self.num_source_factors)
        for i, factor in enumerate(factors[:self.num_source_factors - 1], start=1):
            # fill in as many factors as there are tokens
            source[j, :num_tokens, i] = data_io.tokens2ids(factor, self.source_vocabs[i])[:num_tokens]
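        # Positions past num_tokens are left untouched and keep their initial value of 0,
        # so shorter sentences are implicitly padded out to max_length.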
        # Check if vocabulary selection/restriction is enabled:
        # - First, see if the translator input provides a lexicon (used for multiple lexicons)
        # - If not, see if the translator itself provides a lexicon (used for single lexicon)
        # - The same lexicon must be used for all inputs in the batch.
        if trans_input.restrict_lexicon is not None:
            if restrict_lexicon is not None and restrict_lexicon is not trans_input.restrict_lexicon:
                logger.warning("Sentence %s: different restrict_lexicon specified, will overrule previous. "
                               "All inputs in batch must use same lexicon." % trans_input.sentence_id)
            restrict_lexicon = trans_input.restrict_lexicon
        elif self.restrict_lexicon is not None:
            if isinstance(self.restrict_lexicon, dict):
                # This code should not be reachable since the case is checked when creating
                # translator inputs. It is included here to guarantee that the translator can
                # handle any valid input regardless of whether it was checked at creation time.
                logger.warning("Sentence %s: no restrict_lexicon specified for input when using multiple lexicons, "
                               "defaulting to first lexicon for entire batch." % trans_input.sentence_id)
                restrict_lexicon = list(self.restrict_lexicon.values())[0]
            else:
                restrict_lexicon = self.restrict_lexicon
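        # At this point restrict_lexicon holds either None or the lexicon selected so far;
        # the last one set is applied to the whole batch.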
        if trans_input.constraints is not None:
            raw_constraints[j] = [data_io.tokens2ids(phrase, self.vocab_targets[0]) for phrase in
                                  trans_input.constraints]

        if trans_input.avoid_list is not None:
            raw_avoid_list[j] = [data_io.tokens2ids(phrase, self.vocab_targets[0]) for phrase in
                                 trans_input.avoid_list]
            if any(self.unk_id in phrase for phrase in raw_avoid_list[j]):
                logger.warning("Sentence %s: %s was found in the list of phrases to avoid; "
                               "this may indicate improper preprocessing.", trans_input.sentence_id, C.UNK_SYMBOL)
    source = np.array(source, ctx=self.context)
    source_length = np.array(lengths, ctx=self.context, dtype=self.dtype)  # shape: (batch_size,)
    max_output_lengths = np.array(max_output_lengths, ctx=self.context, dtype='int32')

    return source, source_length, restrict_lexicon, raw_constraints, raw_avoid_list, max_output_lengths
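
A minimal usage sketch (an assumption, not part of the original file): it presumes an already
loaded sockeye.inference.Translator instance named `translator` and uses Sockeye's
`make_input_from_plain_string` helper to build TranslatorInput objects; the names and setup
here are illustrative only.

    # Hypothetical example: `translator` is assumed to be a loaded sockeye.inference.Translator.
    inputs = [make_input_from_plain_string(sentence_id=0, string="ein haus am see"),
              make_input_from_plain_string(sentence_id=1, string="das ist ein test")]

    (source,             # (batch_size, max_length, num_factors) source ids, padded to the longest input
     source_length,      # (batch_size,) valid lengths before padding
     restrict_lexicon,   # TopKLexicon or None
     raw_constraints,    # per-sentence lists of target-id phrases that must appear
     raw_avoid_list,     # per-sentence lists of target-id phrases to avoid
     max_output_lengths  # (batch_size,) maximum decode length per sentence
     ) = translator._get_inference_input(inputs)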