in generate_cmlm.py [0:0]
# Imports for the names used below. In the Mask-Predict fork of fairseq,
# `strategies` ships alongside the standard fairseq modules and `pybleu` is a
# separate helper module (an assumption based on how they are used here);
# `dehyphenate` and `generate_batched_itr` are defined elsewhere in this file.
import torch

from fairseq import progress_bar, strategies, tasks, utils
from fairseq.meters import TimeMeter

import pybleu
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu
    torch.manual_seed(args.seed)
    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))
    # Set dictionaries
    # src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary
    # Load decoding strategy
    strategy = strategies.setup_strategy(args)
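    # The strategy object encapsulates the iterative decoding procedure
    # (e.g. mask-predict for a CMLM); which one is built is selected via args.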
    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _ = utils.load_ensemble_for_inference(
        args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides))
    if use_cuda:
        models = [model.cuda() for model in models]
    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
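    # Note (assumption from fairseq semantics): make_generation_fast_ applies
    # inference-only rewrites such as the beamable batched matmul, and the
    # models cannot be switched back into training mode afterwards.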
    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)
    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)
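    # Batches are capped by max_tokens/max_sentences, their sentence counts are
    # rounded to a multiple of 8 for GPU efficiency, and the dataset can be cut
    # into num_shards pieces so parallel workers each generate a disjoint shard.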
    results = []
    scorer = pybleu.PyBleuScorer()
    num_sentences = 0
    has_target = True
    timer = TimeMeter()

    with progress_bar.build_progress_bar(args, itr) as t:
        translations = generate_batched_itr(
            t, strategy, models, tgt_dict,
            length_beam_size=args.length_beam,
            use_gold_target_len=args.gold_target_len,
        )
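        # generate_batched_itr (defined elsewhere in this file) is assumed to
        # yield one (sample_id, src_tokens, target_tokens, hypos) tuple per
        # sentence, with hypos already reduced to the single best hypothesis.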
        for sample_id, src_tokens, target_tokens, hypos in translations:
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
            else:
                src_str = tgt_dict.string(src_tokens, args.remove_bpe)
                if args.dehyphenate:
                    src_str = dehyphenate(src_str)
                if has_target:
                    target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
                    if args.dehyphenate:
                        target_str = dehyphenate(target_str)

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))
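            # The S-/T-/H-/A- prefixes follow the fairseq output convention:
            # source, target, hypothesis, and alignment lines respectively.
            # utils.post_process_prediction detokenizes the hypothesis,
            # replacing <unk> tokens via align_dict when one is given and
            # stripping BPE markers when --remove-bpe is set.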
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypos.int().cpu(),
                src_str=src_str,
                alignment=None,
                align_dict=align_dict,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe,
            )
            if args.dehyphenate:
                hypo_str = dehyphenate(hypo_str)

            if not args.quiet:
                print('H-{}\t{}'.format(sample_id, hypo_str))
                if args.print_alignment:
                    print('A-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(lambda x: str(utils.item(x)), alignment))
                    ))
                print()
            # Score only the top hypothesis
            if has_target:
                if align_dict is not None or args.remove_bpe is not None:
                    # Convert back to tokens for evaluation with unk replacement
                    # and/or without BPE (note: this re-encoding is not consumed
                    # by the string-based PyBleuScorer below)
                    target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True)
                results.append((target_str, hypo_str))

            num_sentences += 1
    if has_target:
        print('Time = {}'.format(timer.elapsed_time))
        ref, out = zip(*results)
        print('| Generate {} with beam={}: BLEU4 = {:2.2f}'.format(
            args.gen_subset, args.beam, scorer.score(ref, out)))
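# Example invocation (a sketch; the data path and checkpoint name are
# placeholders, and flag names are inferred from the args read above):
#   python generate_cmlm.py data-bin/<dataset> --path <checkpoint.pt> \
#       --gen-subset test --length-beam 5 --remove-bpe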