# in generate.py [0:0]
def main(args):
    """Generate the first hypothesis for every sample and print it as JSON.

    Loads the task, the ``args.gen_subset`` split and the model ensemble named
    by ``args.path`` (checkpoint paths separated by ``:``), moves the models
    and each batch to CUDA, runs inference, and prints one JSON object per
    sample with the keys ``src``, ``pred``, ``src_len`` and ``pred_len``
    (lengths are whitespace-token counts of the corresponding strings).

    Args:
        args: Parsed command-line namespace. It is mutated in place:
            ``beam`` and ``nbest`` are forced to 1 and ``max_tokens`` to
            10000 before anything else reads them.

    Raises:
        AssertionError: If ``args.path`` is ``None``.
    """
    assert args.path is not None, '--path required for generation!'

    # Only the single best hypothesis is ever printed, so pin the search
    # settings regardless of what was passed on the command line.
    args.beam = args.nbest = 1
    args.max_tokens = int(1e4)

    utils.import_user_module(args)

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # The task may not expose a source dictionary; the target one is required.
    src_dict = getattr(task, 'source_dictionary', None)
    tgt_dict = task.target_dictionary

    # NOTE(security): ``eval`` executes arbitrary Python taken from the
    # ``--model-overrides`` command-line string. Only run this with trusted
    # arguments; ``ast.literal_eval`` is the safer choice if the overrides are
    # always plain literals -- TODO confirm before switching.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=args.beam,
            need_attn=False,
        )
        model.cuda()

    # The position limit must satisfy the task and every model in the ensemble.
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]
    )
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    generator = task.build_generator(args)

    with progress_bar.build_progress_bar(args, itr) as t:
        for sample in t:
            sample = utils.move_to_cuda(sample)
            if 'net_input' not in sample:
                continue

            # The fourth positional argument is the prefix-token slot; no
            # prefix is forced here.
            hypos = task.inference_step(generator, models, sample, None)

            for i, _sample_id in enumerate(sample['id'].tolist()):
                # Remove padding
                # NOTE(review): source tokens are stripped with the *target*
                # dictionary's pad index; presumably both dictionaries share
                # the same pad id -- confirm.
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :],
                    tgt_dict.pad(),
                )
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens, args.remove_bpe)
                else:
                    src_str = ""

                # Process top prediction: only the first hypothesis is used.
                hypo = hypos[i][0]
                _hypo_tokens, hypo_str, _alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=None,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                result = dict(
                    src=src_str,
                    pred=hypo_str,
                    src_len=len(src_str.split()),
                    pred_len=len(hypo_str.split()),
                )
                print(json.dumps(result))