def _fairseq_generate()

in access/fairseq/base.py


import os
import re
import shutil
import tempfile
from collections import defaultdict
from pathlib import Path

from fairseq import options
# the generation entry point has moved between fairseq versions; recent
# releases expose it as fairseq_cli.generate (assumed here)
from fairseq_cli import generate

# repo helpers (module path assumed): log_stdout redirects stdout to a file,
# yield_lines iterates over a file's lines, write_lines writes a list of lines
from access.utils.helpers import log_stdout, write_lines, yield_lines


def _fairseq_generate(complex_filepath,
                      output_pred_filepath,
                      checkpoint_paths,
                      complex_dictionary_path,
                      simple_dictionary_path,
                      beam=5,
                      hypothesis_num=1,
                      lenpen=1.,
                      diverse_beam_groups=None,
                      diverse_beam_strength=0.5,
                      sampling=False,
                      batch_size=128):
    # fairseq needs a data directory that contains dict.complex.txt and
    # dict.simple.txt. Build one in a temp dir: copy the input complex file in
    # and create a dummy simple file (generation still requires a target file).
    tmp_dir = Path(tempfile.mkdtemp())
    new_complex_filepath = tmp_dir / 'tmp.complex-simple.complex'
    dummy_simple_filepath = tmp_dir / 'tmp.complex-simple.simple'
    shutil.copy(complex_filepath, new_complex_filepath)
    shutil.copy(complex_filepath, dummy_simple_filepath)
    shutil.copy(complex_dictionary_path, tmp_dir / 'dict.complex.txt')
    shutil.copy(simple_dictionary_path, tmp_dir / 'dict.simple.txt')
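    # build fairseq's generation parser and assemble argv-style arguments for it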
    generate_parser = options.get_generation_parser()
    args = [
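        # positional argument: data directory containing the dict files copied above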
        tmp_dir,
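        # ':'-joined checkpoint paths are decoded as an ensemble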
        '--path',
        ':'.join([str(path) for path in checkpoint_paths]),
        '--beam',
        beam,
        '--nbest',
        hypothesis_num,
        '--lenpen',
        lenpen,
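        # -1 disables diverse beam search (fairseq's default)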
        '--diverse-beam-groups',
        diverse_beam_groups if diverse_beam_groups is not None else -1,
        '--diverse-beam-strength',
        diverse_beam_strength,
        '--batch-size',
        batch_size,
        '--raw-text',
        '--print-alignment',
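        # 'tmp' matches the tmp.complex-simple.* file prefix created above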
        '--gen-subset',
        'tmp',
        # We don't want to reload pretrained embeddings
        '--model-overrides',
        {
            'encoder_embed_path': None,
            'decoder_embed_path': None
        },
    ]
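    # optionally sample from the top-k next tokens instead of standard beam search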
    if sampling:
        args.extend([
            '--sampling',
            '--sampling-topk',
            10,
        ])
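    # fairseq's parser expects strings; str() also renders the model-overrides
    # dict into a literal that fairseq parses back into a dict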
    args = [str(arg) for arg in args]
    generate_args = options.parse_args_and_arch(generate_parser, args)
    out_filepath = tmp_dir / 'generation.out'
    with log_stdout(out_filepath, mute_stdout=True):
        # run generation; log_stdout captures fairseq's stdout (where the
        # hypotheses are printed) into generation.out
        generate.main(generate_args)
    # Retrieve the generated hypotheses from the captured output

    def parse_all_hypotheses(out_filepath):
        hypotheses_dict = defaultdict(list)
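        # fairseq prints each hypothesis as 'H-<sample_id>\t<score>\t<text>'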
        for line in yield_lines(out_filepath):
            match = re.match(r'^H-(\d+)\t-?\d+\.\d+\t(.*)$', line)
            if match:
                sample_id, hypothesis = match.groups()
                hypotheses_dict[int(sample_id)].append(hypothesis)
        # Restore the original sample order (fairseq processes batches sorted by length)
        return [hypotheses_dict[i] for i in range(len(hypotheses_dict))]

    all_hypotheses = parse_all_hypotheses(out_filepath)
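    # keep the hypothesis_num-th best hypothesis (1-indexed) for each source sentence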
    predictions = [hypotheses[hypothesis_num - 1] for hypotheses in all_hypotheses]
    write_lines(predictions, output_pred_filepath)
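    # remove the copied inputs (the dictionaries and generation.out remain in tmp_dir)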
    os.remove(dummy_simple_filepath)
    os.remove(new_complex_filepath)
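
For illustration, a call might look like the following. Every path below is
made up: the dictionary files are the dict.complex.txt / dict.simple.txt
produced by fairseq preprocessing, and the checkpoint comes from training.

from pathlib import Path

_fairseq_generate(
    complex_filepath='data/test.complex',        # one complex sentence per line
    output_pred_filepath='data/test.pred',       # predictions written here, line-aligned
    checkpoint_paths=[Path('exp/checkpoints/checkpoint_best.pt')],
    complex_dictionary_path='exp/dict.complex.txt',
    simple_dictionary_path='exp/dict.simple.txt',
    beam=8,
    hypothesis_num=1,                            # keep the top-scoring hypothesis
)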