def compute()

in training/NeuralPatternMatchingTraining.py [0:0]


def _resolve_dataset_class(dataset_string):
    """Return the NRE dataset class registered under ``dataset_string``.

    Raises:
        ValueError: if ``dataset_string`` is not a supported dataset name.
    """
    if dataset_string == 'NREHatespeechDataset':
        return NREHatespeechDataset
    if dataset_string == 'NRESpouseDataset':
        return NRESpouseDataset
    if dataset_string == 'NREMovieReviewDataset':
        return NREMovieReviewDataset
    raise ValueError(f'Unknown dataset {dataset_string!r}; expected one of '
                     f'NREHatespeechDataset, NRESpouseDataset, NREMovieReviewDataset')


def _collect_and_store_targets(dataset, save_path):
    """Concatenate the targets of ``dataset`` (in dataset order) and ``torch.save`` them.

    Returns the concatenated targets tensor, or ``None`` if the dataset yields no batches
    (``None`` is then also what gets saved).
    """
    # IMPORTANT: shuffle must stay False, the targets are matched row-by-row with the
    # stored predictions, which are also computed with shuffle=False.
    loader = DataLoader(dataset, batch_size=256, collate_fn=custom_collate, shuffle=False,
                        num_workers=2)
    batches = []
    for _, _, _, targets, _ in loader:
        targets, _ = targets
        batches.append(targets)

    all_targets = torch.cat(batches, dim=0) if batches else None
    torch.save(all_targets, save_path)
    return all_targets


def _load_predictions(path, dim_target):
    """Load a predictions tensor saved at ``path`` onto the CPU, reshaped to ``(-1, dim_target)``."""
    predictions = torch.load(path, map_location='cpu')
    return predictions.cpu().reshape(-1, dim_target)


def _binary_scores(y_true, y_pred):
    """Return accuracy / precision / recall / F1 of ``y_pred`` against ``y_true``, as percentages."""
    return {'accuracy': float(accuracy_score(y_true, y_pred) * 100),
            'precision': float(precision_score(y_true, y_pred) * 100),
            'recall': float(recall_score(y_true, y_pred) * 100),
            'f1': float(f1_score(y_true, y_pred) * 100)}


def compute(model_string, dataset_string, highlighted_set_path, non_highlighted_set_path, validation_set_path, test_set_path, no_Lfs, train_size, boostrap_split, train_split, validation_split, unlabelled_split,
            results_folder, score_to_optimize, dim_target=2, rationale_noise=0):
    """Grid-search the neural pattern matching models and write the best test scores to JSON.

    For every hyper-parameter configuration the models are trained with ``bagging``.
    Their predictions on the train / validation / test / unlabelled data are stored
    through ``compute_predictions_for_DP``, turned into labelling-function outputs with
    ``process_outputs`` and aggregated:
    - with a label model (``RandomSearchTuner(LabelModelNoSeed)``) when ``no_Lfs != 1``;
    - by taking the single labelling function's column directly when ``no_Lfs == 1``.
    The configuration with the best validation ``score_to_optimize`` wins, and its test
    scores are reported.

    Args:
        model_string: model identifier forwarded to ``bagging``.
        dataset_string: one of ``'NREHatespeechDataset'``, ``'NRESpouseDataset'``,
            ``'NREMovieReviewDataset'``.
        highlighted_set_path: path of the highlighted training data. If it contains
            ``'fasttext'`` the embedding dimension is 300, otherwise 768.
        non_highlighted_set_path: path of the non-highlighted training data.
        validation_set_path: path of an explicit validation set, or ``None`` to use
            ``validation_split`` of the training data instead.
        test_set_path: path of the test set.
        no_Lfs: number of labelling functions (rules).
        train_size: used only to name the stored files.
        boostrap_split: used only to name the stored files.
        train_split: indices into the concatenation of the highlighted and
            non-highlighted sets.
        validation_split: indices into the same concatenation. They are added to the
            unlabelled data when an explicit validation set is given. Combined with
            ``unlabelled_split`` via ``+``, so presumably lists.
        unlabelled_split: indices into the same concatenation.
        results_folder: output folder. Intermediate tensors go to ``<results_folder>/stored_results``.
        score_to_optimize: one of ``'accuracy'``, ``'precision'``, ``'recall'``, ``'f1'``.
            Also passed as ``metric`` to the label-model tuner.
        dim_target: predictions are reshaped to ``(-1, dim_target)``.
        rationale_noise: forwarded to the highlighted / non-highlighted dataset constructors.

    Raises:
        ValueError: if ``dataset_string`` is not a supported dataset name.
    """
    from itertools import product

    dataset_class = _resolve_dataset_class(dataset_string)

    highlighted_set = dataset_class(highlighted_set_path, rationale_noise=rationale_noise)
    non_highlighted_set = dataset_class(non_highlighted_set_path, rationale_noise=rationale_noise)
    test_set = dataset_class(test_set_path)

    all_train_dataset = ConcatDataset((highlighted_set, non_highlighted_set))
    train_set = Subset(all_train_dataset, train_split)

    if validation_set_path is not None:
        validation_set = dataset_class(validation_set_path)
        # Validation split is not interesting if we have an explicit validation set
        unlabelled_set = Subset(all_train_dataset, validation_split + unlabelled_split)
    else:
        validation_set = Subset(all_train_dataset, validation_split)
        unlabelled_set = Subset(all_train_dataset, unlabelled_split)

    # Created up front: predictions are written here before anything else.
    stored_results = Path(results_folder, 'stored_results')
    os.makedirs(stored_results, exist_ok=True)

    def predictions_path(dataset_type):
        return Path(stored_results,
                    f'predictions_{train_size}_size_{dataset_type}_{no_Lfs}_rules_bootstrap_{boostrap_split}.torch')

    # Targets do not depend on any hyper-parameter: compute and store them once.
    targets = {
        dataset_type: _collect_and_store_targets(
            dataset,
            Path(stored_results,
                 f'all_targets_{dataset_type}_{train_size}_size_{no_Lfs}_rules_bootstrap_{boostrap_split}.torch'))
        for dataset_type, dataset in (('train', train_set), ('validation', validation_set), ('test', test_set))
    }

    # Scores are computed against the original labels.
    y_valid_score = np.copy(targets['validation'])
    y_test_score = np.copy(targets['test'])
    # For Data Programming the label model works with 2 in place of 0 (its predictions
    # are mapped back 2 -> 0 below before scoring).
    targets_valid = np.copy(targets['validation'])
    targets_valid[targets_valid == 0] = 2

    use_label_model = no_Lfs != 1

    embedding_dim = 300 if 'fasttext' in str(highlighted_set_path) else 768

    best_vl_scores = {'accuracy': 0.,
                      'precision': 0.,
                      'recall': 0.,
                      'f1': 0.}

    te_scores = {
        'best_params': {},
        'best_vl_scores': {},
        'test_scores': {'accuracy': 0,
                        'precision': 0,
                        'recall': 0,
                        'f1': 0}
    }

    # Same iteration order as nested loops (last factor varies fastest), so ties resolve identically.
    grid = product([float(np.exp(1)), 5, 10],  # highlights_pow_base
                   [1e-2],                     # lr
                   [1e-2, 1e-3],               # l1
                   [1e-3, 1e-4],               # l2
                   [5, 10],                    # no_prototypes
                   [500],                      # num_epochs
                   [10, 100])                  # gating_param

    for hpb, lr, l1, l2, no_prototypes, num_epochs, gating_param in grid:

        models = bagging(model_string, no_Lfs, train_set,
                         lr=lr, l1_coeff=l1, l2=l2, max_epochs=num_epochs,
                         no_prototypes=no_prototypes, embedding_dim=embedding_dim,
                         gating_param=gating_param, batch_size=32,
                         save_path=None, highlights_pow_base=hpb)

        # Compute and store the prediction of each NRE
        for dataset_type, dataset in [('train', train_set), ('validation', validation_set),
                                      ('test', test_set), ('unlabelled', unlabelled_set)]:
            if dataset_type == 'unlabelled' and not use_label_model:
                continue

            # IMPORTANT: SHUFFLE MUST STAY FALSE (predictions are aligned with the stored targets)
            loader = DataLoader(dataset, batch_size=256, collate_fn=custom_collate, shuffle=False,
                                num_workers=0)
            compute_predictions_for_DP(models, loader, save_path=predictions_path(dataset_type))

        if use_label_model:
            train_predictions = _load_predictions(predictions_path('train'), dim_target)
            unlabelled_predictions = _load_predictions(predictions_path('unlabelled'), dim_target)
        valid_predictions = _load_predictions(predictions_path('validation'), dim_target)
        test_predictions = _load_predictions(predictions_path('test'), dim_target)

        for threshold in [0.01, 0.05]:

            if use_label_model:
                Ls_train = process_outputs(train_predictions, no_Lfs, threshold)
                Ls_unlabelled = process_outputs(unlabelled_predictions, no_Lfs, threshold)
            Ls_valid = process_outputs(valid_predictions, no_Lfs, threshold)
            Ls_test = process_outputs(test_predictions, no_Lfs, threshold)

            if use_label_model:
                # Concatenate train and "unlabelled" data
                Ls_dataset = np.concatenate((Ls_train, Ls_unlabelled), axis=0)

                search_space = {
                    'n_epochs': [100, 500],
                    'lr': {'range': [0.01, 0.001], 'scale': 'log'},
                    'show_plots': True,
                }

                tuner = RandomSearchTuner(LabelModelNoSeed)  # , seed=123)

                # ------------ DANGER ZONE: be careful here! ------------ #
                # Train on train+unlabelled because it is unsupervised (exploit unlabelled data),
                # and "optimize" on the small validation set.
                label_aggregator = tuner.search(
                    search_space,
                    train_args=[Ls_dataset],
                    X_dev=Ls_valid, Y_dev=targets_valid.squeeze(),
                    max_search=10, verbose=False, metric=score_to_optimize,
                    shuffle=False
                    # Leave it False, ow gen_splits generates different splits compared to linear baseline
                )

                Y_vl = label_aggregator.predict(Ls_valid)
                Y_test = label_aggregator.predict(Ls_test)
                # ------------ END OF DANGER ZONE ------------ #
            else:
                # A single labelling function: use its output directly.
                Y_vl = Ls_valid[:, 0]
                Y_test = Ls_test[:, 0]

            Y_vl[Y_vl == 2] = 0
            Y_test[Y_test == 2] = 0

            vl_scores = _binary_scores(y_valid_score, Y_vl)
            test_scores = _binary_scores(y_test_score, Y_test)

            if vl_scores[score_to_optimize] > best_vl_scores[score_to_optimize]:
                best_vl_scores = deepcopy(vl_scores)
                te_scores['best_params'] = deepcopy(
                    {'learning_rate': lr, 'l1': l1, 'l2': l2,
                     'train_split': train_split,
                     'validation_split': validation_split,
                     'no_prototypes': no_prototypes,
                     'gating_param': gating_param,
                     'threshold': threshold,
                     'error_multiplier': hpb,
                     'epochs': num_epochs})
                te_scores['best_vl_scores'] = best_vl_scores
                te_scores['test_scores'] = test_scores

    print(f'Best VL scores found is {best_vl_scores}')
    print(f'Best TE scores found is {te_scores["test_scores"]}')
    print(f'End of model assessment for train size {train_size} and {no_Lfs} rules, test results are {te_scores}')

    os.makedirs(results_folder, exist_ok=True)

    with open(
            Path(results_folder, f'NRE_size_{train_size}_rules_{no_Lfs}_test_results_bootstrap_{boostrap_split}.json'), 'w') as f:
        json.dump(te_scores, f)