in sockeye/inference.py [0:0]
def __init__(self,
             context: mx.context.Context,
             ensemble_mode: str,
             scorer: CandidateScorer,
             batch_size: int,
             beam_search_stop: str,
             models: List[SockeyeModel],
             source_vocabs: List[vocab.Vocab],
             target_vocabs: List[vocab.Vocab],
             beam_size: int = 5,
             nbest_size: int = 1,
             restrict_lexicon: Optional[Union[lexicon.TopKLexicon, Dict[str, lexicon.TopKLexicon]]] = None,
             avoid_list: Optional[str] = None,
             strip_unknown_words: bool = False,
             sample: Optional[int] = None,
             output_scores: bool = False,
             constant_length_ratio: float = 0.0,
             hybridize: bool = True,
             max_output_length_num_stds: int = C.DEFAULT_NUM_STD_MAX_OUTPUT_LENGTH,
             max_input_length: Optional[int] = None,
             max_output_length: Optional[int] = None,
             softmax_temperature: Optional[float] = None,
             prevent_unk: bool = False,
             greedy: bool = False) -> None:
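    """
    Set up a Translator: record decoding options, derive a common maximum
    input/output length across the given models, and construct the search
    algorithm (greedy or beam search) used for decoding.
    """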
    self.context = context
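    # int8 here means weight-quantized models; their outputs are still float32,
    # so the translator reports fp32 as its working dtype in that case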
    self.dtype = C.DTYPE_FP32 if models[0].dtype == C.DTYPE_INT8 else models[0].dtype
    self._scorer = scorer
    self.batch_size = batch_size
    self.beam_size = beam_size
    self.beam_search_stop = beam_search_stop
    self.source_vocabs = source_vocabs
    self.vocab_targets = target_vocabs
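    # reversed vocabularies map target token ids back to surface strings when assembling outputs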
    self.vocab_targets_inv = [vocab.reverse_vocab(v) for v in self.vocab_targets]
    self.restrict_lexicon = restrict_lexicon
    assert C.PAD_ID == 0, "pad id should be 0"
    self.stop_ids = {C.EOS_ID, C.PAD_ID}  # type: Set[int]
    self.strip_ids = self.stop_ids.copy()  # ids to strip from the output
    self.unk_id = C.UNK_ID
    if strip_unknown_words:
        self.strip_ids.add(self.unk_id)
    self.models = models
    # After loading, the models are guaranteed to agree on max_input_length, max_output_length and batch size.
    # Set a common maximum output length for all models.
    self._max_input_length, self._get_max_output_length = models_max_input_output_length(
        models,
        max_output_length_num_stds,
        forced_max_input_length=max_input_length,
        forced_max_output_length=max_output_length)
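    # an n-best list requires a beam at least nbest_size wide, and search must run
    # until every hypothesis in the beam has finished ('all')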
    self.nbest_size = nbest_size
    utils.check_condition(self.beam_size >= nbest_size, 'nbest_size must be smaller than or equal to beam_size.')
    if self.nbest_size > 1:
        utils.check_condition(self.beam_search_stop == C.BEAM_SEARCH_STOP_ALL,
                              "nbest_size > 1 requires beam_search_stop to be set to 'all'")
    self._search = get_search_algorithm(
        models=self.models,
        beam_size=self.beam_size,
        context=self.context,
        vocab_target=target_vocabs[0],  # only primary target factor used for constrained decoding
        output_scores=output_scores,
        sample=sample,
        ensemble_mode=ensemble_mode,
        beam_search_stop=beam_search_stop,
        scorer=self._scorer,
        constant_length_ratio=constant_length_ratio,
        avoid_list=avoid_list,
        hybridize=hybridize,
        softmax_temperature=softmax_temperature,
        prevent_unk=prevent_unk,
        greedy=greedy)
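    # pre-bind stop ids and scorer; this callable stitches the per-chunk translations
    # of an over-long input back together (n-best aware when nbest_size > 1)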
    self._concat_translations = partial(_concat_nbest_translations if self.nbest_size > 1 else _concat_translations,
                                        stop_ids=self.stop_ids,
                                        scorer=self._scorer)  # type: Callable
logger.info("Translator (%d model(s) beam_size=%d algorithm=%s, beam_search_stop=%s max_input_length=%s "
"nbest_size=%s ensemble_mode=%s max_batch_size=%d avoiding=%d dtype=%s softmax_temperature=%s)",
len(self.models),
self.beam_size,
"GreedySearch" if isinstance(self._search, GreedySearch) else "BeamSearch",
self.beam_search_stop,
self.max_input_length,
self.nbest_size,
"None" if len(self.models) == 1 else ensemble_mode,
self.max_batch_size,
0 if self._search.global_avoid_trie is None else len(self._search.global_avoid_trie),
self.dtype,
softmax_temperature)