def finetune_and_predict_on_dataset()

in muss/fairseq/main.py [0:0]


def finetune_and_predict_on_dataset(finetuning_dataset, exp_dir, **kwargs):
    """Search preprocessor parameters on a dataset and write valid/test predictions.

    Two prediction files are produced inside ``exp_dir``, named
    ``{prefix}_{dataset}_valid-test_{dataset}_{phase}.pred`` for the ``valid``
    and ``test`` phases, where ``prefix`` is ``finetune`` (or ``finetune_fast``
    when ``fast_parametrization_search`` is truthy). If both files already
    exist, nothing is recomputed.

    Args:
        finetuning_dataset: Dataset identifier; interpolated into the output
            file names and passed to ``get_data_filepath`` / ``get_dataset_dir``.
        exp_dir: Experiment directory. Must support the ``/`` operator and
            yield objects with ``.exists()`` (i.e. a ``pathlib.Path``).
        **kwargs: Forwarded to ``find_best_parametrization`` and
            ``fairseq_get_simplifier``. Must contain a ``train_kwargs`` dict.
            Any ``evaluate_kwargs`` supplied by the caller is overwritten, and
            ``preprocessors_kwargs`` is replaced by the search result.

    Returns:
        None.

    Raises:
        KeyError: If ``train_kwargs`` is not present in ``kwargs``.
    """
    # NOTE: only the outer ``kwargs`` mapping is private to this call; this
    # writes into the ``train_kwargs`` dict object owned by the caller.
    kwargs['train_kwargs']['ngpus'] = 1
    prefix = 'finetune'
    if kwargs.get('fast_parametrization_search', False):
        prefix += '_fast'
    # Single source of truth for the phases so that file names and the loop
    # below cannot drift apart. Order matters: 'valid' must run before 'test'.
    phases = ('valid', 'test')
    pred_filepaths = {
        phase: exp_dir / f'{prefix}_{finetuning_dataset}_valid-test_{finetuning_dataset}_{phase}.pred'
        for phase in phases
    }
    if all(path.exists() for path in pred_filepaths.values()):
        return
    for phase in phases:
        orig_sents_path = get_data_filepath(finetuning_dataset, phase, 'complex')
        refs_sents_paths = list(get_dataset_dir(finetuning_dataset).glob(f'{phase}.simple*'))
        kwargs['evaluate_kwargs'] = {
            'test_set': 'custom',
            'orig_sents_path': orig_sents_path,
            'refs_sents_paths': refs_sents_paths,
        }
        if phase == 'valid':
            # Finetune preprocessors_kwargs only on valid; the value stays in
            # ``kwargs`` and is reused unchanged for the test phase.
            kwargs['preprocessors_kwargs'] = find_best_parametrization(exp_dir, **kwargs)
        simplifier = fairseq_get_simplifier(exp_dir, **kwargs)
        # The simplifier's return value is used as the copy source, so it is
        # presumably the path of a file holding the predictions -- confirm
        # against fairseq_get_simplifier.
        shutil.copyfile(simplifier(orig_sents_path), pred_filepaths[phase])