def translate()

in sockeye/inference.py


    def translate(self, trans_inputs: List[TranslatorInput], fill_up_batches: bool = True) -> List[TranslatorOutput]:
        """
        Batch-translates a list of TranslatorInputs, returns a list of TranslatorOutputs.
        Empty or bad inputs are skipped.
        Splits inputs longer than Translator.max_input_length into segments of size max_input_length,
        and then groups segments into batches of at most Translator.max_batch_size.
        Inputs that were too long and therefore split are reassembled into a single output after translation.
        If fill_up_batches is set to True, underfilled batches are padded to Translator.max_batch_size; otherwise
        dynamic batch sizing is used, which comes at the cost of increased memory usage.

        :param trans_inputs: List of TranslatorInputs as returned by make_input().
        :param fill_up_batches: If True, underfilled batches are padded to Translator.max_batch_size.
        :return: List of translation results.
        """
        num_inputs = len(trans_inputs)
        translated_chunks = []  # type: List[IndexedTranslation]

        # split into chunks
        input_chunks = []  # type: List[IndexedTranslatorInput]
        for trans_input_idx, trans_input in enumerate(trans_inputs):
            # bad input
            if isinstance(trans_input, BadTranslatorInput):
                translated_chunks.append(IndexedTranslation(input_idx=trans_input_idx, chunk_idx=0,
                                                            translation=empty_translation(add_nbest=(self.nbest_size > 1))))
            # empty input
            elif len(trans_input.tokens) == 0:
                translated_chunks.append(IndexedTranslation(input_idx=trans_input_idx, chunk_idx=0,
                                                            translation=empty_translation(add_nbest=(self.nbest_size > 1))))
            else:
                if len(trans_input.tokens) > self.max_input_length:
                    # oversized input
                    logger.debug(
                        "Input %s has length (%d) that exceeds max input length (%d). "
                        "Splitting into chunks of size %d.",
                        trans_input.sentence_id, len(trans_input.tokens),
                        self.max_input_length, self.max_input_length)
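                    # Split into chunks of at most max_input_length tokens; with_eos() terminates each
                    # chunk with EOS so that it can be translated independently.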
                    chunks = [trans_input_chunk.with_eos()
                              for trans_input_chunk in
                              trans_input.chunks(self.max_input_length)]
                    input_chunks.extend([IndexedTranslatorInput(trans_input_idx, chunk_idx, chunk_input)
                                         for chunk_idx, chunk_input in enumerate(chunks)])
                else:
                    # regular input
                    input_chunks.append(IndexedTranslatorInput(trans_input_idx,
                                                               chunk_idx=0,
                                                               translator_input=trans_input.with_eos()))

            if trans_input.constraints is not None:
                logger.info("Input %s has %d %s: %s", trans_input.sentence_id,
                            len(trans_input.constraints),
                            "constraint" if len(trans_input.constraints) == 1 else "constraints",
                            ", ".join(" ".join(x) for x in trans_input.constraints))

        num_bad_empty = len(translated_chunks)

        # Sort longest to shortest so that underfilled batches contain shorter rather than longer sequences
        input_chunks = sorted(input_chunks, key=lambda chunk: len(chunk.translator_input.tokens), reverse=True)
        # translate in batch-sized blocks over input chunks
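        # With fill_up_batches=False the batch size shrinks to the number of available chunks (dynamic batch sizing)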
        batch_size = self.max_batch_size if fill_up_batches else min(len(input_chunks), self.max_batch_size)

        num_batches = 0
        for batch_id, batch in enumerate(utils.grouper(input_chunks, batch_size)):
            logger.debug("Translating batch %d", batch_id)

            rest = batch_size - len(batch)
            if fill_up_batches and rest > 0:
                logger.debug("Padding batch of size %d to full batch size (%d)", len(batch), batch_size)
                batch = batch + [batch[0]] * rest

            translator_inputs = [indexed_translator_input.translator_input for indexed_translator_input in batch]
            batch_translations = self._translate_np(*self._get_inference_input(translator_inputs))

            # truncate to remove filler translations
            if fill_up_batches and rest > 0:
                batch_translations = batch_translations[:-rest]

            for chunk, translation in zip(batch, batch_translations):
                translated_chunks.append(IndexedTranslation(chunk.input_idx, chunk.chunk_idx, translation))
            num_batches += 1
        # Sort by input idx and then chunk id
        translated_chunks = sorted(translated_chunks)
        num_chunks = len(translated_chunks)

        # Concatenate results
        results = []  # type: List[TranslatorOutput]
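        # translated_chunks is sorted by input_idx, so groupby yields one group of chunks per original input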
        chunks_by_input_idx = itertools.groupby(translated_chunks, key=lambda translation: translation.input_idx)
        for trans_input, (input_idx, translations_for_input_idx) in zip(trans_inputs, chunks_by_input_idx):
            translations_for_input_idx = list(translations_for_input_idx)  # type: ignore
            if len(translations_for_input_idx) == 1:  # type: ignore
                translation = translations_for_input_idx[0].translation  # type: ignore
            else:
                translations_to_concat = [translated_chunk.translation
                                          for translated_chunk in translations_for_input_idx]
                translation = self._concat_translations(translations_to_concat)

            results.append(self._make_result(trans_input, translation))

        num_outputs = len(results)

        logger.debug("Translated %d inputs (%d chunks) in %d batches to %d outputs. %d empty/bad inputs.",
                     num_inputs, num_chunks, num_batches, num_outputs, num_bad_empty)

        return results
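
A minimal usage sketch follows. It assumes a ready-made inference.Translator instance named translator (model loading is omitted), that make_input_from_plain_string is used as the make_input() helper referenced in the docstring, and that the translated string is exposed as TranslatorOutput.translation.

    from sockeye import inference

    # Sketch only: `translator` is assumed to be an already-constructed inference.Translator.
    sentences = ["this is a small test .", "and here is a somewhat longer example sentence ."]
    trans_inputs = [inference.make_input_from_plain_string(sentence_id=i, string=s)
                    for i, s in enumerate(sentences)]

    # Outputs come back in input order; over-long inputs are split, translated
    # chunk by chunk, and re-joined into a single output.
    for output in translator.translate(trans_inputs, fill_up_batches=True):
        print(output.sentence_id, output.translation)

Setting fill_up_batches=False switches to dynamic batch sizing as described in the docstring, at the cost of increased memory usage.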