in sockeye/translate.py [0:0]
def run_translate(args: argparse.Namespace):
    """
    Entry point for the `sockeye-translate` CLI: configures logging, loads the
    model ensemble onto a single device, builds a Translator from the parsed
    command-line arguments, and translates the input stream.

    :param args: Parsed command-line arguments (model folders, decoding
                 options, I/O options). `args.output_type` may be mutated to
                 JSON when nbest output is requested.
    :raises ValueError: If `args.brevity_penalty_type` is not a known type.
    """
    # Seed randomly unless a seed has been passed.
    utils.seed_rngs(args.seed if args.seed is not None else int(time.time()))

    if args.output is not None:
        setup_main_logger(console=not args.quiet,
                          file_logging=not args.no_logfile,
                          path="%s.%s" % (args.output, C.LOG_NAME),
                          level=args.loglevel)
    else:
        # No output file: log to console only.
        setup_main_logger(file_logging=False, level=args.loglevel)
    log_basic_info(args)

    # nbest lists can only be represented by the JSON handler; override the
    # user's handler choice if necessary (with a warning).
    if args.nbest_size > 1:
        if args.output_type != C.OUTPUT_HANDLER_JSON:
            logger.warning("For nbest translation, you must specify `--output-type '%s'; overriding your setting of '%s'.",
                           C.OUTPUT_HANDLER_JSON, args.output_type)
            args.output_type = C.OUTPUT_HANDLER_JSON
    output_handler = get_output_handler(args.output_type,
                                        args.output)

    hybridize = not args.no_hybridization

    with ExitStack() as exit_stack:
        check_condition(len(args.device_ids) == 1, "translate only supports single device for now")
        context = determine_context(device_ids=args.device_ids,
                                    use_cpu=args.use_cpu,
                                    disable_device_locking=args.disable_device_locking,
                                    lock_dir=args.lock_dir,
                                    exit_stack=exit_stack)[0]
        logger.info("Translate Device: %s", context)

        # NOTE(review): function-local import, presumably to avoid a circular
        # import at module load time — confirm before hoisting to file level.
        from sockeye.model import load_models
        models, source_vocabs, target_vocabs = load_models(context=context,
                                                           model_folders=args.models,
                                                           checkpoints=args.checkpoints,
                                                           dtype=args.dtype,
                                                           hybridize=hybridize,
                                                           inference_only=True,
                                                           mc_dropout=args.mc_dropout)

        restrict_lexicon = _load_restrict_lexicon(args, source_vocabs, target_vocabs)
        brevity_penalty_weight, constant_length_ratio = _brevity_penalty_params(args, models)

        scorer = inference.CandidateScorer(
            length_penalty_alpha=args.length_penalty_alpha,
            length_penalty_beta=args.length_penalty_beta,
            brevity_penalty_weight=brevity_penalty_weight)

        translator = inference.Translator(context=context,
                                          ensemble_mode=args.ensemble_mode,
                                          scorer=scorer,
                                          batch_size=args.batch_size,
                                          beam_size=args.beam_size,
                                          beam_search_stop=args.beam_search_stop,
                                          nbest_size=args.nbest_size,
                                          models=models,
                                          source_vocabs=source_vocabs,
                                          target_vocabs=target_vocabs,
                                          restrict_lexicon=restrict_lexicon,
                                          avoid_list=args.avoid_list,
                                          strip_unknown_words=args.strip_unknown_words,
                                          sample=args.sample,
                                          output_scores=output_handler.reports_score(),
                                          constant_length_ratio=constant_length_ratio,
                                          max_output_length_num_stds=args.max_output_length_num_stds,
                                          max_input_length=args.max_input_length,
                                          max_output_length=args.max_output_length,
                                          hybridize=hybridize,
                                          softmax_temperature=args.softmax_temperature,
                                          prevent_unk=args.prevent_unk,
                                          greedy=args.greedy)

        read_and_translate(translator=translator,
                           output_handler=output_handler,
                           chunk_size=args.chunk_size,
                           input_file=args.input,
                           input_factors=args.input_factors,
                           input_is_json=args.json_input)


def _load_restrict_lexicon(args: argparse.Namespace,
                           source_vocabs,
                           target_vocabs):
    """
    Build the output-vocabulary restriction lexicon(s) from `--restrict-lexicon`.

    Returns None when no lexicon was requested, a single TopKLexicon when one
    lexicon is given, or a dict of named TopKLexicons when several are given
    (the multi-lexicon case requires JSON input so each input can select one).
    """
    restrict_lexicon: Optional[Union[TopKLexicon, Dict[str, TopKLexicon]]] = None
    if args.restrict_lexicon is not None:
        logger.info(str(args.restrict_lexicon))
        if len(args.restrict_lexicon) == 1:
            # Single lexicon used for all inputs.
            restrict_lexicon = TopKLexicon(source_vocabs[0], target_vocabs[0])
            # Handle a single arg of key:path or path (parsed as path:path).
            restrict_lexicon.load(args.restrict_lexicon[0][1], k=args.restrict_lexicon_topk)
        else:
            check_condition(args.json_input,
                            "JSON input is required when using multiple lexicons for vocabulary restriction")
            # Multiple lexicons with specified names.
            restrict_lexicon = dict()
            for key, path in args.restrict_lexicon:
                lexicon = TopKLexicon(source_vocabs[0], target_vocabs[0])
                lexicon.load(path, k=args.restrict_lexicon_topk)
                restrict_lexicon[key] = lexicon
    return restrict_lexicon


def _brevity_penalty_params(args: argparse.Namespace, models):
    """
    Resolve (brevity_penalty_weight, constant_length_ratio) from the CLI args.

    For the 'constant' type, the ratio comes from the CLI flag if positive,
    otherwise from the average of the length ratios stored in the model
    configs. 'learned' and 'none' use -1.0 (and 'none' zeroes the weight).

    :raises ValueError: For an unknown `args.brevity_penalty_type`.
    """
    brevity_penalty_weight = args.brevity_penalty_weight
    if args.brevity_penalty_type == C.BREVITY_PENALTY_CONSTANT:
        if args.brevity_penalty_constant_length_ratio > 0.0:
            constant_length_ratio = args.brevity_penalty_constant_length_ratio
        else:
            constant_length_ratio = sum(model.length_ratio_mean for model in models) / len(models)
            logger.info("Using average of constant length ratios saved in the model configs: %f",
                        constant_length_ratio)
    elif args.brevity_penalty_type == C.BREVITY_PENALTY_LEARNED:
        constant_length_ratio = -1.0
    elif args.brevity_penalty_type == C.BREVITY_PENALTY_NONE:
        brevity_penalty_weight = 0.0
        constant_length_ratio = -1.0
    else:
        raise ValueError("Unknown brevity penalty type %s" % args.brevity_penalty_type)
    return brevity_penalty_weight, constant_length_ratio