in pytorch_translate/generate.py [0:0]
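# Note: this excerpt assumes generate.py's module-level imports, e.g.
# collections, pickle, numpy as np, torch, and the fairseq/pytorch_translate
# helpers referenced below (bleu, progress_bar, utils, TimeMeter,
# StopwatchMeter, pytorch_translate_data).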
def _generate_score(models, args, task, dataset, modify_target_dict):
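    """Translate `dataset` with the ensemble `models` and score the output.

    Returns a tuple of (scorer, num_sentences, gen_timer, translation_samples).
    """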
use_cuda = torch.cuda.is_available() and not args.cpu
# Load ensemble
if not args.quiet:
print(
"| loading model(s) from {}".format(
", ".join(args.path.split(CHECKPOINT_PATHS_DELIMITER))
)
)
# Optimize ensemble for generation
for model in models:
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
need_attn=True,
)
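    # Build the beam-search sequence generator over the model ensemble.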
translator = build_sequence_generator(args, task, models)
# Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement; empty if no alignment dictionary path)
align_dict = utils.load_align_dict(args.replace_unk)
print("seed number is" + str(args.max_examples_to_evaluate_seed))
if args.max_examples_to_evaluate > 0:
pytorch_translate_data.subsample_pair_dataset(
dataset, args.max_examples_to_evaluate, args.max_examples_to_evaluate_seed
)
    # Keep track of translations: initialize with empty translations
    # and zero probability scores
translated_sentences = [""] * len(dataset)
translated_scores = [0.0] * len(dataset)
hypos_list = []
collect_output_hypos = getattr(args, "output_hypos_binary_path", False)
if collect_output_hypos:
output_hypos_token_arrays = [None] * len(dataset)
# Generate and compute BLEU score
dst_dict = task.target_dictionary
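    # Use sacreBLEU on detokenized strings when requested; otherwise fall
    # back to the tokenized BLEU scorer over token IDs.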
if args.sacrebleu:
scorer = bleu.SacrebleuScorer(bleu.SacrebleuConfig())
else:
scorer = bleu.Scorer(
bleu.BleuConfig(
pad=dst_dict.pad(),
eos=dst_dict.eos(),
unk=dst_dict.unk(),
)
)
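    # Build a (possibly sharded) batch iterator over the evaluation dataset,
    # capped by both token and sentence budgets.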
itr = task.get_batch_iterator(
dataset=dataset,
max_tokens=args.max_tokens,
max_sentences=args.batch_size,
max_positions=utils.resolve_max_positions(
task.max_positions(), *[model.max_positions() for model in models]
),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=8,
num_shards=args.num_shards,
shard_id=args.shard_id,
num_workers=args.num_workers,
).next_epoch_itr(shuffle=False)
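    # Oracle BLEU scores the single best hypothesis in each beam against the
    # reference, an upper bound on what hypothesis reranking could achieve.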
oracle_scorer = None
if args.report_oracle_bleu:
oracle_scorer = bleu.Scorer(
bleu.BleuConfig(
pad=dst_dict.pad(),
eos=dst_dict.eos(),
unk=dst_dict.unk(),
)
)
rescorer = None
num_sentences = 0
translation_samples = []
translation_info_list = []
with progress_bar.build_progress_bar(args, itr) as t:
wps_meter = TimeMeter()
gen_timer = StopwatchMeter()
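        # Generate translations batch by batch. For many-to-one multilingual
        # models, the first target token encodes the target language, so it
        # is fed to the generator as a forced prefix (prefix_size=1).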
translations = translator.generate_batched_itr(
t,
maxlen_a=args.max_len_a,
maxlen_b=args.max_len_b,
cuda=use_cuda,
timer=gen_timer,
prefix_size=1
if pytorch_translate_data.is_multilingual_many_to_one(args)
else 0,
)
for trans_info in _iter_translations(
args, task, dataset, translations, align_dict, rescorer, modify_target_dict
):
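            # SacrebleuScorer consumes detokenized strings via add_string();
            # the tokenized scorer consumes tensors of token IDs via add().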
if hasattr(scorer, "add_string"):
scorer.add_string(trans_info.target_str, trans_info.hypo_str)
else:
scorer.add(trans_info.target_tokens, trans_info.hypo_tokens)
if oracle_scorer is not None:
oracle_scorer.add(trans_info.target_tokens, trans_info.best_hypo_tokens)
if getattr(args, "translation_output_file", False):
translated_sentences[trans_info.sample_id] = trans_info.hypo_str
if getattr(args, "translation_probs_file", False):
translated_scores[trans_info.sample_id] = trans_info.hypo_score
if getattr(args, "hypotheses_export_path", False):
hypos_list.append(trans_info.hypos)
if collect_output_hypos:
output_hypos_token_arrays[
trans_info.sample_id
] = trans_info.best_hypo_tokens
if args.translation_info_export_path is not None:
# Strip expensive data from hypotheses before saving
hypos = [
{k: v for k, v in hypo.items() if k in ["tokens", "score"]}
for hypo in trans_info.hypos
]
                # Make sure everything is on CPU before exporting
hypos = [
{"score": hypo["score"], "tokens": hypo["tokens"].cpu()}
for hypo in hypos
]
translation_info_list.append(
{
"src_tokens": trans_info.src_tokens.cpu(),
"target_tokens": trans_info.target_tokens,
"hypos": hypos,
}
)
translation_samples.append(
collections.OrderedDict(
{
"sample_id": trans_info.sample_id.item(),
"src_str": trans_info.src_str,
"target_str": trans_info.target_str,
"hypo_str": trans_info.hypo_str,
}
)
)
wps_meter.update(trans_info.src_tokens.size(0))
t.log({"wps": round(wps_meter.avg)})
num_sentences += 1
# If applicable, save collected hypothesis tokens to binary output file
if collect_output_hypos:
output_dataset = pytorch_translate_data.InMemoryIndexedDataset()
output_dataset.load_from_sequences(output_hypos_token_arrays)
output_dataset.save(args.output_hypos_binary_path)
if args.output_source_binary_path:
dataset.src.save(args.output_source_binary_path)
if args.translation_info_export_path is not None:
        with open(args.translation_info_export_path, "wb") as f:
            pickle.dump(translation_info_list, f)
    # If applicable, save the translations and scores to the output files.
    # These two outputs are used in dual learning for weighted backtranslation.
if getattr(args, "translation_output_file", False) and getattr(
args, "translation_probs_file", False
):
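        # hypo_score is a log-probability, hence np.exp() below to write a
        # probability per hypothesis.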
with open(args.translation_output_file, "w") as translation_file, open(
args.translation_probs_file, "w"
) as score_file:
for hypo_str, hypo_score in zip(translated_sentences, translated_scores):
if len(hypo_str.strip()) > 0:
print(hypo_str, file=translation_file)
print(np.exp(hypo_score), file=score_file)
    # Export raw hypotheses, e.g. for external evaluation
if getattr(args, "hypotheses_export_path", False):
with open(args.hypotheses_export_path, "w") as out_file:
for hypos in hypos_list:
for hypo in hypos:
print(
task.tgt_dict.string(
hypo["tokens"], bpe_symbol=args.post_process
),
file=out_file,
)
if oracle_scorer is not None:
print(f"| Oracle BLEU (best hypo in beam): {oracle_scorer.result_string()}")
return scorer, num_sentences, gen_timer, translation_samples
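
# Hypothetical driver sketch (not generate.py's actual entry point): assuming
# `models`, `task`, `dataset`, and `args` have been set up as above, the
# function would be invoked roughly like this:
#
#     scorer, num_sentences, gen_timer, _ = _generate_score(
#         models, args, task, dataset, modify_target_dict=False
#     )
#     print(
#         f"| Translated {num_sentences} sentences "
#         f"({gen_timer.n} tokens) in {gen_timer.sum:.1f}s"
#     )
#     print(f"| {scorer.result_string()}")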