# access/fairseq/base.py
def _fairseq_generate(complex_filepath,
                      output_pred_filepath,
                      checkpoint_paths,
                      complex_dictionary_path,
                      simple_dictionary_path,
                      beam=5,
                      hypothesis_num=1,
                      lenpen=1.,
                      diverse_beam_groups=None,
                      diverse_beam_strength=0.5,
                      sampling=False,
                      batch_size=128):
    """Run fairseq generation on `complex_filepath` and write the selected
    hypothesis of each sample to `output_pred_filepath`.

    Args:
        complex_filepath: File with one source (complex) sentence per line.
        output_pred_filepath: Destination file; one prediction per line, in
            the original sample order.
        checkpoint_paths: Iterable of model checkpoint paths; multiple paths
            are colon-joined for fairseq ensembling.
        complex_dictionary_path: fairseq dictionary for the source side,
            copied into the temp data dir as dict.complex.txt.
        simple_dictionary_path: fairseq dictionary for the target side,
            copied into the temp data dir as dict.simple.txt.
        beam: Beam size (--beam).
        hypothesis_num: 1-indexed rank of the hypothesis to keep per sample;
            also used as --nbest so fairseq emits enough hypotheses.
        lenpen: Length penalty (--lenpen).
        diverse_beam_groups: Number of diverse beam groups, or None to
            disable (fairseq's sentinel for "disabled" is -1).
        diverse_beam_strength: Strength of diverse beam search
            (--diverse-beam-strength).
        sampling: If True, sample from the top-10 tokens instead of pure
            beam search.
        batch_size: Sentences per batch (--batch-size).
    """
    # fairseq expects a data directory containing the raw-text split files
    # and the dictionaries, so assemble one in a temporary location.
    tmp_dir = Path(tempfile.mkdtemp())
    try:
        new_complex_filepath = tmp_dir / 'tmp.complex-simple.complex'
        dummy_simple_filepath = tmp_dir / 'tmp.complex-simple.simple'
        shutil.copy(complex_filepath, new_complex_filepath)
        # fairseq requires a target-side file even at inference time; reuse
        # the source file as a dummy target.
        shutil.copy(complex_filepath, dummy_simple_filepath)
        shutil.copy(complex_dictionary_path, tmp_dir / 'dict.complex.txt')
        shutil.copy(simple_dictionary_path, tmp_dir / 'dict.simple.txt')
        generate_parser = options.get_generation_parser()
        args = [
            tmp_dir,
            '--path',
            ':'.join([str(path) for path in checkpoint_paths]),
            '--beam',
            beam,
            '--nbest',
            hypothesis_num,
            '--lenpen',
            lenpen,
            '--diverse-beam-groups',
            # -1 is fairseq's sentinel for "diverse beam search disabled"
            diverse_beam_groups if diverse_beam_groups is not None else -1,
            '--diverse-beam-strength',
            diverse_beam_strength,
            '--batch-size',
            batch_size,
            '--raw-text',
            '--print-alignment',
            '--gen-subset',
            'tmp',
            # We don't want to reload pretrained embeddings
            '--model-overrides',
            {
                'encoder_embed_path': None,
                'decoder_embed_path': None
            },
        ]
        if sampling:
            args.extend([
                '--sampling',
                '--sampling-topk',
                10,
            ])
        args = [str(arg) for arg in args]
        generate_args = options.parse_args_and_arch(generate_parser, args)
        out_filepath = tmp_dir / 'generation.out'
        with log_stdout(out_filepath, mute_stdout=True):
            # evaluate model in batch mode
            generate.main(generate_args)

        # Retrieve translations from fairseq's captured stdout; hypothesis
        # lines look like "H-<sample_id>\t<score>\t<hypothesis>".
        def parse_all_hypotheses(out_filepath):
            hypotheses_dict = defaultdict(list)
            for line in yield_lines(out_filepath):
                match = re.match(r'^H-(\d+)\t-?\d+\.\d+\t(.*)$', line)
                if match:
                    sample_id, hypothesis = match.groups()
                    hypotheses_dict[int(sample_id)].append(hypothesis)
            # Restore the original sample order (fairseq batches by length,
            # so output order does not match input order).
            return [hypotheses_dict[i] for i in range(len(hypotheses_dict))]

        all_hypotheses = parse_all_hypotheses(out_filepath)
        # Keep the hypothesis_num-th hypothesis (1-indexed) of each sample.
        predictions = [hypotheses[hypothesis_num - 1] for hypotheses in all_hypotheses]
        write_lines(predictions, output_pred_filepath)
    finally:
        # Fix: the original removed only the two split files, leaking the
        # copied dictionaries, generation.out and the temp directory itself
        # on every call (and everything on an exception). Remove the whole
        # directory instead.
        shutil.rmtree(tmp_dir)