def run_translate()

in sockeye/translate.py

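For context, the function relies on module-level imports in sockeye/translate.py. The following is a sketch reconstructed from the names used below; exact import lines may differ by version, read_and_translate is defined later in the same module, and load_models is imported inline in the function body:

import argparse
import time
from contextlib import ExitStack
from typing import Dict, Optional, Union

from . import constants as C
from . import inference
from . import utils
from .lexicon import TopKLexicon
from .log import setup_main_logger
from .output_handler import get_output_handler
from .utils import check_condition, determine_context, log_basic_info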

def run_translate(args: argparse.Namespace):

    # Seed randomly unless a seed has been passed
    utils.seed_rngs(args.seed if args.seed is not None else int(time.time()))

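    # When writing translations to a file, also log to "<output>.<LOG_NAME>";
    # otherwise log to the console only.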
    if args.output is not None:
        setup_main_logger(console=not args.quiet,
                          file_logging=not args.no_logfile,
                          path="%s.%s" % (args.output, C.LOG_NAME),
                          level=args.loglevel)
    else:
        setup_main_logger(file_logging=False, level=args.loglevel)

    log_basic_info(args)

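    # nbest lists can only be represented as JSON, so force that output type if needed.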
    if args.nbest_size > 1:
        if args.output_type != C.OUTPUT_HANDLER_JSON:
            logger.warning("For nbest translation, you must specify --output-type '%s'; overriding your setting of '%s'.",
                           C.OUTPUT_HANDLER_JSON, args.output_type)
            args.output_type = C.OUTPUT_HANDLER_JSON
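    # The output handler formats and writes each finished translation (plain text, JSON, with scores, etc.).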
    output_handler = get_output_handler(args.output_type,
                                        args.output)
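    # Hybridization (MXNet Gluon graph optimization) can be disabled to ease debugging.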
    hybridize = not args.no_hybridization

    with ExitStack() as exit_stack:
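        # Acquire the inference device: CPU or a single, optionally lock-protected GPU;
        # the exit stack releases any acquired device locks when the block exits.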
        check_condition(len(args.device_ids) == 1, "translate only supports single device for now")
        context = determine_context(device_ids=args.device_ids,
                                    use_cpu=args.use_cpu,
                                    disable_device_locking=args.disable_device_locking,
                                    lock_dir=args.lock_dir,
                                    exit_stack=exit_stack)[0]
        logger.info("Translate Device: %s", context)
        from sockeye.model import load_models
        models, source_vocabs, target_vocabs = load_models(context=context,
                                                           model_folders=args.models,
                                                           checkpoints=args.checkpoints,
                                                           dtype=args.dtype,
                                                           hybridize=hybridize,
                                                           inference_only=True,
                                                           mc_dropout=args.mc_dropout)

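        # Optionally restrict the target vocabulary with one or more top-k lexicons.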
        restrict_lexicon = None  # type: Optional[Union[TopKLexicon, Dict[str, TopKLexicon]]]
        if args.restrict_lexicon is not None:
            logger.info(str(args.restrict_lexicon))
            if len(args.restrict_lexicon) == 1:
                # Single lexicon used for all inputs.
                restrict_lexicon = TopKLexicon(source_vocabs[0], target_vocabs[0])
                # Handle a single arg of key:path or path (parsed as path:path)
                restrict_lexicon.load(args.restrict_lexicon[0][1], k=args.restrict_lexicon_topk)
            else:
                check_condition(args.json_input,
                                "JSON input is required when using multiple lexicons for vocabulary restriction")
                # Multiple lexicons with specified names
                restrict_lexicon = dict()
                for key, path in args.restrict_lexicon:
                    lexicon = TopKLexicon(source_vocabs[0], target_vocabs[0])
                    lexicon.load(path, k=args.restrict_lexicon_topk)
                    restrict_lexicon[key] = lexicon

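        # Resolve the brevity penalty: a constant length ratio (explicit, or averaged
        # over the model configs), a learned ratio, or no penalty at all.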
        brevity_penalty_weight = args.brevity_penalty_weight
        if args.brevity_penalty_type == C.BREVITY_PENALTY_CONSTANT:
            if args.brevity_penalty_constant_length_ratio > 0.0:
                constant_length_ratio = args.brevity_penalty_constant_length_ratio
            else:
                constant_length_ratio = sum(model.length_ratio_mean for model in models) / len(models)
                logger.info("Using average of constant length ratios saved in the model configs: %f",
                            constant_length_ratio)
        elif args.brevity_penalty_type == C.BREVITY_PENALTY_LEARNED:
            constant_length_ratio = -1.0
        elif args.brevity_penalty_type == C.BREVITY_PENALTY_NONE:
            brevity_penalty_weight = 0.0
            constant_length_ratio = -1.0
        else:
            raise ValueError("Unknown brevity penalty type %s" % args.brevity_penalty_type)

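        # Scores finished hypotheses by combining the length and brevity penalties.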
        scorer = inference.CandidateScorer(
            length_penalty_alpha=args.length_penalty_alpha,
            length_penalty_beta=args.length_penalty_beta,
            brevity_penalty_weight=brevity_penalty_weight)

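        # Assemble the beam-search translator from the models, vocabularies, and decoding options.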
        translator = inference.Translator(context=context,
                                          ensemble_mode=args.ensemble_mode,
                                          scorer=scorer,
                                          batch_size=args.batch_size,
                                          beam_size=args.beam_size,
                                          beam_search_stop=args.beam_search_stop,
                                          nbest_size=args.nbest_size,
                                          models=models,
                                          source_vocabs=source_vocabs,
                                          target_vocabs=target_vocabs,
                                          restrict_lexicon=restrict_lexicon,
                                          avoid_list=args.avoid_list,
                                          strip_unknown_words=args.strip_unknown_words,
                                          sample=args.sample,
                                          output_scores=output_handler.reports_score(),
                                          constant_length_ratio=constant_length_ratio,
                                          max_output_length_num_stds=args.max_output_length_num_stds,
                                          max_input_length=args.max_input_length,
                                          max_output_length=args.max_output_length,
                                          hybridize=hybridize,
                                          softmax_temperature=args.softmax_temperature,
                                          prevent_unk=args.prevent_unk,
                                          greedy=args.greedy)
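        # Read inputs (file or stdin) in chunks, translate them, and write results via the output handler.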
        read_and_translate(translator=translator,
                           output_handler=output_handler,
                           chunk_size=args.chunk_size,
                           input_file=args.input,
                           input_factors=args.input_factors,
                           input_is_json=args.json_input)
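

For reference, a minimal sketch of how run_translate is typically wired into the CLI entry point. This assumes a module-level "from . import arguments"; ConfigArgumentParser and add_translate_cli_args live in sockeye.arguments, but the exact wiring may differ between versions:

def main():
    # Build the translate CLI argument parser and dispatch to run_translate.
    params = arguments.ConfigArgumentParser(description='Translate CLI')
    arguments.add_translate_cli_args(params)
    run_translate(params.parse_args())


if __name__ == "__main__":
    main()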