def run_translate()

in sockeye/translate_pt.py [0:0]


def _select_translate_device(args: argparse.Namespace) -> pt.device:
    """Return the torch device to translate on (CPU if requested or if CUDA is unavailable)."""
    use_cpu = args.use_cpu
    if not pt.cuda.is_available():
        logger.info("CUDA not available, using cpu")
        use_cpu = True
    return pt.device('cpu') if use_cpu else pt.device('cuda', args.device_id)


def _load_restrict_lexicons(args: argparse.Namespace,
                            source_vocabs,
                            target_vocabs) -> "Optional[Union[TopKLexicon, Dict[str, TopKLexicon]]]":
    """Load the vocabulary-restriction lexicon(s) given in ``args.restrict_lexicon``.

    Every lexicon is built over the first source and first target vocabulary and loaded
    with ``k=args.restrict_lexicon_topk``.

    :param args: Parsed command-line arguments.
    :param source_vocabs: Source vocabularies returned by ``load_models`` (only index 0 is used).
    :param target_vocabs: Target vocabularies returned by ``load_models`` (only index 0 is used).
    :return: None if no lexicon was requested. A single TopKLexicon (used for all inputs) if exactly
             one entry was given. Otherwise a dict mapping each entry's key to its TopKLexicon.
    """
    if args.restrict_lexicon is None:
        return None

    logger.info(str(args.restrict_lexicon))
    if len(args.restrict_lexicon) == 1:
        # Single lexicon used for all inputs.
        # Handle a single arg of key:path or path (parsed as path:path), so only the path (index 1) matters.
        single_lexicon = TopKLexicon(source_vocabs[0], target_vocabs[0])
        single_lexicon.load(args.restrict_lexicon[0][1], k=args.restrict_lexicon_topk)
        return single_lexicon

    check_condition(args.json_input,
                    "JSON input is required when using multiple lexicons for vocabulary restriction")
    # Multiple lexicons with specified names
    named_lexicons = {}
    for key, path in args.restrict_lexicon:
        lexicon = TopKLexicon(source_vocabs[0], target_vocabs[0])
        lexicon.load(path, k=args.restrict_lexicon_topk)
        named_lexicons[key] = lexicon
    return named_lexicons


def _resolve_brevity_penalty(args: argparse.Namespace, models) -> "Tuple[float, float]":
    """Return ``(brevity_penalty_weight, constant_length_ratio)`` for ``args.brevity_penalty_type``.

    - constant: weight from ``args``. The ratio is ``args.brevity_penalty_constant_length_ratio``
      if it is > 0, else the mean of the models' ``length_ratio_mean``.
    - learned: weight from ``args``, ratio -1.0.
    - none: weight 0.0, ratio -1.0.

    NOTE(review): -1.0 appears to act as a "no constant length ratio" sentinel for the
    Translator. Confirm against ``inference_pt.Translator``.

    :param args: Parsed command-line arguments.
    :param models: Loaded models. Iterated for ``length_ratio_mean`` in the constant case.
    :raises ValueError: If the brevity penalty type is unknown.
    """
    brevity_penalty_weight = args.brevity_penalty_weight
    if args.brevity_penalty_type == C.BREVITY_PENALTY_CONSTANT:
        if args.brevity_penalty_constant_length_ratio > 0.0:
            constant_length_ratio = args.brevity_penalty_constant_length_ratio
        else:
            constant_length_ratio = sum(model.length_ratio_mean for model in models) / len(models)
            logger.info("Using average of constant length ratios saved in the model configs: %f",
                        constant_length_ratio)
    elif args.brevity_penalty_type == C.BREVITY_PENALTY_LEARNED:
        constant_length_ratio = -1.0
    elif args.brevity_penalty_type == C.BREVITY_PENALTY_NONE:
        brevity_penalty_weight = 0.0
        constant_length_ratio = -1.0
    else:
        raise ValueError("Unknown brevity penalty type %s" % args.brevity_penalty_type)
    return brevity_penalty_weight, constant_length_ratio


def run_translate(args: argparse.Namespace):
    """Translate the input described by ``args`` with the configured model(s).

    Steps, in order:
    1. Seed the RNGs and set up logging.
    2. Create the output handler, forcing JSON output for n-best translation.
    3. Pick the device and load the models and vocabularies.
    4. Load the optional restriction lexicon(s) and resolve the brevity penalty.
    5. Build the scorer and Translator, then delegate to ``read_and_translate``.

    :param args: Parsed command-line arguments. ``args.output_type`` may be overwritten.
    :raises ValueError: If ``args.brevity_penalty_type`` is unknown.
    """
    # Seed randomly unless a seed has been passed
    seed_rngs(args.seed if args.seed is not None else int(time.time()))

    if args.output is not None:
        setup_main_logger(console=not args.quiet,
                          file_logging=not args.no_logfile,
                          path="%s.%s" % (args.output, C.LOG_NAME),
                          level=args.loglevel)
    else:
        setup_main_logger(file_logging=False, level=args.loglevel)

    log_basic_info(args)

    # n-best translation requires the JSON output handler; override any other choice.
    if args.nbest_size > 1 and args.output_type != C.OUTPUT_HANDLER_JSON:
        logger.warning(
            "For nbest translation, you must specify `--output-type %s`; overriding your setting of '%s'.",
            C.OUTPUT_HANDLER_JSON, args.output_type)
        args.output_type = C.OUTPUT_HANDLER_JSON
    output_handler = get_output_handler(args.output_type,
                                        args.output)

    device = _select_translate_device(args)
    logger.info(f"Translate Device: {device}")
    models, source_vocabs, target_vocabs = load_models(device=device,
                                                       model_folders=args.models,
                                                       checkpoints=args.checkpoints,
                                                       dtype=args.dtype,
                                                       inference_only=True,
                                                       mc_dropout=args.mc_dropout)

    restrict_lexicon = _load_restrict_lexicons(args, source_vocabs, target_vocabs)

    brevity_penalty_weight, constant_length_ratio = _resolve_brevity_penalty(args, models)

    # Switch every model to eval mode before decoding.
    # NOTE(review): confirm this does not cancel the MC-dropout behaviour requested via
    # load_models(mc_dropout=args.mc_dropout).
    for model in models:
        model.eval()

    scorer = inference_pt.CandidateScorer(
        length_penalty_alpha=args.length_penalty_alpha,
        length_penalty_beta=args.length_penalty_beta,
        brevity_penalty_weight=brevity_penalty_weight)
    # Give the scorer the dtype of the (first) model.
    scorer.to(models[0].dtype)

    translator = inference_pt.Translator(device=device,
                                         ensemble_mode=args.ensemble_mode,
                                         scorer=scorer,
                                         batch_size=args.batch_size,
                                         beam_size=args.beam_size,
                                         beam_search_stop=args.beam_search_stop,
                                         nbest_size=args.nbest_size,
                                         models=models,
                                         source_vocabs=source_vocabs,
                                         target_vocabs=target_vocabs,
                                         restrict_lexicon=restrict_lexicon,
                                         strip_unknown_words=args.strip_unknown_words,
                                         sample=args.sample,
                                         output_scores=output_handler.reports_score(),
                                         constant_length_ratio=constant_length_ratio,
                                         max_output_length_num_stds=args.max_output_length_num_stds,
                                         max_input_length=args.max_input_length,
                                         max_output_length=args.max_output_length,
                                         softmax_temperature=args.softmax_temperature,
                                         prevent_unk=args.prevent_unk,
                                         greedy=args.greedy)

    read_and_translate(translator=translator,
                       output_handler=output_handler,
                       chunk_size=args.chunk_size,
                       input_file=args.input,
                       input_factors=args.input_factors,
                       input_is_json=args.json_input)