# def compute()
#
# in training/BertBaselineTraining.py [0:0]


_METRIC_NAMES = ('accuracy', 'precision', 'recall', 'f1')


def _build_model(model_class, model_string, in_features, dim_target, hidden, hpb, t_d):
    """Instantiate ``model_class`` with the constructor arguments its hyper-parameters call for."""
    if hidden is None:
        # Models without a hidden layer take only input and output sizes.
        return model_class(in_features, dim_target)
    if hpb is not None:
        assert model_string == 'MLPOnTokensWithHighlights'
        return model_class(in_features, dim_target, hidden, highlights_pow_base=hpb)  # For MLP ablation study
    if t_d is None:
        assert model_string == 'MLPOnTokens' or model_string == 'NBOW'
        return model_class(in_features, dim_target, hidden)
    assert model_string == 'DAN'
    return model_class(in_features, dim_target, hidden, t_d)


def _fit_and_score(model, train_loader, eval_loader, learning_rate, weight_decay, num_epochs):
    """Train ``model`` on ``train_loader`` and return its metrics on ``eval_loader`` as a dict of floats."""
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # gamma = decaying factor; with gamma=1. the learning rate never changes,
    # so this scheduler is effectively a no-op placeholder.
    scheduler = StepLR(optimizer, step_size=50, gamma=1.)

    train_model(train_loader, model, optimizer, scheduler,
                max_epochs=num_epochs, device=device)

    acc, pr, rec, f1, _ = predict(eval_loader, model, device)
    return {'accuracy': float(acc),
            'precision': float(pr),
            'recall': float(rec),
            'f1': float(f1)}


def compute(model_string, dataset_string, highlighted_train_set_path, non_highlighted_train_set_path, validation_set_path, test_set_path, no_runs, train_size, boostrap_split, train_split, validation_split, unlabelled_split,
            results_folder, score_to_optimize, dim_target=2):
    """Grid-search a baseline model on the validation set, then assess it on the test set.

    Every hyper-parameter configuration is trained ``no_runs`` times and its
    validation metrics are averaged. The configuration with the highest
    averaged ``score_to_optimize`` is then retrained ``no_runs`` times on the
    same training subset and scored on the test set. The averaged test
    scores, the winning configuration and its validation scores are printed
    and dumped as JSON into ``results_folder``.

    The function relies on module-level names defined outside this block:
    ``batch_size``, ``custom_collate``, ``input_size``, ``device``,
    ``train_model`` and ``predict``.

    Args:
        model_string: Name of the model class to train. One of
            'LogisticRegressionOnTokens', 'MLPOnTokens', 'NBOW', 'NBOW2',
            'DAN', 'MLPOnTokensWithHighlights'.
        dataset_string: Name of the dataset class. One of
            'BertBaselineTokensDataset', 'SpouseMLPDataset'.
        highlighted_train_set_path: Path handed to the dataset class; the
            resulting dataset comes first in the concatenated training pool.
        non_highlighted_train_set_path: Path handed to the dataset class; the
            resulting dataset comes second in the concatenated training pool.
        validation_set_path: Path of an explicit validation set, or ``None``
            to validate on ``validation_split`` of the training pool.
        test_set_path: Path handed to the dataset class for the test set.
        no_runs: Number of repetitions averaged per configuration (>= 1).
        train_size: Only used in the log message and the output file name.
        boostrap_split: Only used in the output file name.
        train_split: Indices into the concatenated training pool used for
            training.
        validation_split: Indices into the concatenated training pool used
            for validation when ``validation_set_path`` is ``None``.
        unlabelled_split: Accepted for interface compatibility; currently
            unused by this function.
        results_folder: Directory (created if missing) for the JSON report.
        score_to_optimize: Metric used for model selection; one of
            'accuracy', 'precision', 'recall', 'f1'.
        dim_target: Second positional argument of every model constructor.

    Raises:
        ValueError: If ``no_runs`` is smaller than 1, or ``model_string``,
            ``dataset_string`` or ``score_to_optimize`` is not recognised.
    """
    from itertools import product

    # Fail fast: all of these would otherwise surface only after expensive
    # training (ZeroDivisionError / KeyError) or as a message-less error.
    if no_runs < 1:
        raise ValueError(f'no_runs must be at least 1, got {no_runs!r}')
    if score_to_optimize not in _METRIC_NAMES:
        raise ValueError(f'Unknown score_to_optimize {score_to_optimize!r}, '
                         f'expected one of {_METRIC_NAMES}')

    # if/elif chains are kept on purpose: only the selected class name is
    # evaluated, so the function does not require every class to be in scope.
    if model_string == 'LogisticRegressionOnTokens':
        model_class = LogisticRegressionOnTokens
    elif model_string == 'MLPOnTokens':
        model_class = MLPOnTokens
    elif model_string == 'NBOW':
        model_class = NBOW
    elif model_string == 'NBOW2':
        model_class = NBOW2
    elif model_string == 'DAN':
        model_class = DAN
    elif model_string == 'MLPOnTokensWithHighlights':
        model_class = MLPOnTokensWithHighlights
    else:
        raise ValueError(f'Unknown model_string {model_string!r}')

    if dataset_string == 'BertBaselineTokensDataset':
        dataset_class = BertBaselineTokensDataset
    elif dataset_string == 'SpouseMLPDataset':
        dataset_class = SpouseMLPDataset
    else:
        raise ValueError(f'Unknown dataset_string {dataset_string!r}')

    highlighted_train_set = dataset_class(highlighted_train_set_path)
    non_highlighted_train_set = dataset_class(non_highlighted_train_set_path)
    test_set = dataset_class(test_set_path)

    all_train_dataset = ConcatDataset((highlighted_train_set, non_highlighted_train_set))
    train_set = Subset(all_train_dataset, train_split)

    if validation_set_path is not None:
        # Validation split is not interesting if we have an explicit validation set
        validation_set = dataset_class(validation_set_path)
    else:
        validation_set = Subset(all_train_dataset, validation_split)

    # The loaders do not depend on the hyper-parameters, so they are built
    # once and shared by the grid search and the final assessment.
    train_loader = DataLoader(train_set, batch_size=batch_size,
                              collate_fn=custom_collate, shuffle=True)
    valid_loader = DataLoader(validation_set, batch_size=batch_size,
                              collate_fn=custom_collate, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=batch_size, collate_fn=custom_collate, shuffle=False)

    # These are our hyper-parameters.
    # NOTE(review): an earlier `if 'NBOW' in model_string` branch assigned
    # hidden_units = [0.], but the chain below started with a fresh `if`, so
    # its `else` always overwrote those values. NBOW/NBOW2 have therefore
    # always run with the defaults; that effective behaviour is kept here.
    # Confirm whether NBOW was meant to have its own grid.
    if model_string == 'DAN':
        hidden_units = [8, 32]
        highlight_pow_bases = [None]
        tokens_dropout = [0.5, 0.3, 0.]
    elif model_string == 'MLPOnTokens':
        hidden_units = [8, 32]
        highlight_pow_bases = [None]
        tokens_dropout = [None]
    elif model_string == 'MLPOnTokensWithHighlights':
        hidden_units = [8, 32]
        highlight_pow_bases = [float(np.exp(1)), 5, 10]
        tokens_dropout = [None]
    else:
        hidden_units = [None]
        tokens_dropout = [None]
        highlight_pow_bases = [None]

    best_vl_scores = dict.fromkeys(_METRIC_NAMES, 0)
    best_params = None

    # product() varies the right-most factor fastest, i.e. the same visiting
    # order as the equivalent nested loops (epochs innermost).
    grid = product(tokens_dropout, highlight_pow_bases, hidden_units,
                   [1e-2, 1e-3], [1e-2, 1e-4], [100, 500])

    for t_d, hpb, hidden, learning_rate, weight_decay, num_epochs in grid:
        totals = dict.fromkeys(_METRIC_NAMES, 0)

        for _ in range(no_runs):
            model = _build_model(model_class, model_string, input_size, dim_target,
                                 hidden, hpb, t_d)
            run_scores = _fit_and_score(model, train_loader, valid_loader,
                                        learning_rate, weight_decay, num_epochs)
            for key in _METRIC_NAMES:
                totals[key] = totals[key] + run_scores[key]

        # AVERAGE OVER RUNS
        vl_scores = {key: totals[key] / no_runs for key in _METRIC_NAMES}

        # Always accept the first configuration so that best_params is set
        # even when no configuration scores strictly above zero.
        if best_params is None or vl_scores[score_to_optimize] > best_vl_scores[score_to_optimize]:
            best_vl_scores = vl_scores
            best_params = deepcopy(
                {'learning_rate': learning_rate,
                 'train_split': train_split, 'valid_split': validation_split,
                 'weight_decay': weight_decay, 'epochs': num_epochs,
                 'error_base': hpb,
                 'tokens_dropout': t_d,
                 'hidden_units': hidden})

    test_totals = dict.fromkeys(_METRIC_NAMES, 0)

    for _ in range(no_runs):
        if model_string == 'DAN':
            # DAN is the only model whose grid has tokens_dropout values, so
            # this matches the branch that used to emit this message.
            print('final on DAN')
        model = _build_model(model_class, model_string, input_size, dim_target,
                             best_params['hidden_units'], best_params['error_base'],
                             best_params['tokens_dropout'])
        run_scores = _fit_and_score(model, train_loader, test_loader,
                                    best_params['learning_rate'],
                                    best_params['weight_decay'],
                                    best_params['epochs'])
        for key in _METRIC_NAMES:
            test_totals[key] = test_totals[key] + run_scores[key]

    te_scores = {
        'best_params': best_params,
        'best_vl_scores': best_vl_scores,
        # AVERAGE OVER RUNS
        'test_scores': {key: test_totals[key] / no_runs for key in _METRIC_NAMES}
    }

    print(f'Best VL scores found is {best_vl_scores}')
    print(f'Best TE scores found is {te_scores["test_scores"]}')
    print(f'End of model assessment for train size {train_size}, test results are {te_scores}')

    # exist_ok avoids the check-then-create race between concurrent runs.
    os.makedirs(results_folder, exist_ok=True)

    with open(Path(results_folder, f'BertBaseline_size_{train_size}_runs_{no_runs}_test_results_bootstrap_{boostrap_split}.json'), 'w') as f:
        json.dump(te_scores, f)