in sockeye/inference_pt.py [0:0]
def translate(self, trans_inputs: List[TranslatorInput], fill_up_batches: bool = True) -> List[TranslatorOutput]:
    """
    Translates a list of TranslatorInputs and returns one TranslatorOutput per input.

    Bad and empty inputs are answered immediately with an empty translation.
    Inputs longer than Translator.max_input_length are split into segments of
    at most max_input_length tokens; the segments are translated separately
    and their translations are concatenated back into a single output.
    Segments are batched into groups of at most Translator.max_batch_size.
    With fill_up_batches=True, underfilled batches are padded up to
    max_batch_size; otherwise the batch size adapts to the data, which costs
    additional memory.

    :param trans_inputs: List of TranslatorInputs as returned by make_input().
    :param fill_up_batches: If True, underfilled batches are padded to Translator.max_batch_size.
    :return: List of translation results.
    """
    total_inputs = len(trans_inputs)
    finished = []  # type: List[IndexedTranslation]
    pending = []  # type: List[IndexedTranslatorInput]

    for input_idx, trans_input in enumerate(trans_inputs):
        # Bad or empty inputs never reach the model; emit an empty translation right away.
        if isinstance(trans_input, BadTranslatorInput) or len(trans_input.tokens) == 0:
            finished.append(IndexedTranslation(input_idx=input_idx, chunk_idx=0,
                                               translation=empty_translation(add_nbest=(self.nbest_size > 1))))
            continue
        if len(trans_input.tokens) > self.max_input_length:
            # Oversized input: split into max_input_length-sized segments.
            logger.debug(
                "Input %s has length (%d) that exceeds max input length (%d). "
                "Splitting into chunks of size %d.",
                trans_input.sentence_id, len(trans_input.tokens),
                self.max_input_length, self.max_input_length)
            segments = [segment.with_eos() for segment in trans_input.chunks(self.max_input_length)]
            pending.extend(IndexedTranslatorInput(input_idx, chunk_idx, segment)
                           for chunk_idx, segment in enumerate(segments))
        else:
            # Regular input: a single segment covers the whole input.
            pending.append(IndexedTranslatorInput(input_idx,
                                                  chunk_idx=0,
                                                  translator_input=trans_input.with_eos()))
        if trans_input.constraints is not None:
            logger.info("Input %s has %d %s: %s", trans_input.sentence_id,
                        len(trans_input.constraints),
                        "constraint" if len(trans_input.constraints) == 1 else "constraints",
                        ", ".join(" ".join(x) for x in trans_input.constraints))

    num_bad_empty = len(finished)

    # Longest segments first, so batches mix sequences of similar length
    # and shorter sequences fill the tail batches.
    pending.sort(key=lambda item: len(item.translator_input.tokens), reverse=True)

    # Translate the pending segments in batch-sized groups.
    batch_size = self.max_batch_size if fill_up_batches else min(len(pending), self.max_batch_size)
    num_batches = 0
    for batch_id, batch in enumerate(utils.grouper(pending, batch_size)):
        logger.debug("Translating batch %d", batch_id)
        fill = batch_size - len(batch)
        if fill_up_batches and fill > 0:
            logger.debug("Padding batch of size %d to full batch size (%d)", len(batch), batch_size)
            batch = batch + [batch[0]] * fill
        batch_inputs = [item.translator_input for item in batch]
        with pt.inference_mode():
            translations = self._translate_np(*self._get_inference_input(batch_inputs))  # type: ignore
        if fill_up_batches and fill > 0:
            # Drop the filler duplicates that were appended above.
            translations = translations[:-fill]
        for item, translation in zip(batch, translations):
            finished.append(IndexedTranslation(item.input_idx, item.chunk_idx, translation))
        num_batches += 1

    # Restore original order: by input index, then by chunk index.
    finished = sorted(finished)
    num_chunks = len(finished)

    # Reassemble outputs. Every non-skipped input produced at least one chunk,
    # so groupby over the sorted chunks aligns one-to-one with trans_inputs.
    results = []  # type: List[TranslatorOutput]
    grouped = itertools.groupby(finished, key=lambda indexed: indexed.input_idx)
    for trans_input, (_, chunk_group) in zip(trans_inputs, grouped):
        chunk_group = list(chunk_group)  # type: ignore
        if len(chunk_group) == 1:  # type: ignore
            translation = chunk_group[0].translation  # type: ignore
        else:
            translation = self._concat_translations([chunk.translation for chunk in chunk_group])
        results.append(self._make_result(trans_input, translation))

    num_outputs = len(results)
    logger.debug("Translated %d inputs (%d chunks) in %d batches to %d outputs. %d empty/bad inputs.",
                 total_inputs, num_chunks, num_batches, num_outputs, num_bad_empty)
    return results