def compute()

in training/NGramLogRegTraining.py


# NOTE: NGramsDataset, LogisticRegression, train_model, predict, custom_collate,
# batch_size, input_size and device are assumed to be defined elsewhere in this module.
import json
import os
from copy import deepcopy
from pathlib import Path

import torch
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import ConcatDataset, DataLoader, Subset


def compute(dataset_string, highlighted_train_set_path, non_highlighted_train_set_path,
            validation_set_path, test_set_path, no_runs, train_size, boostrap_split,
            train_split, validation_split, unlabelled_split, results_folder,
            score_to_optimize, dim_target=2):

    model_class = LogisticRegression

    if dataset_string == 'NGramsDataset':
        dataset_class = NGramsDataset
    else:
        raise ValueError(f"Unsupported dataset class '{dataset_string}'")

    highlighted_train_set = dataset_class(highlighted_train_set_path)
    non_highlighted_train_set = dataset_class(non_highlighted_train_set_path)
    test_set = dataset_class(test_set_path)

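    # Pool highlighted and non-highlighted examples into one dataset; the
    # train/validation/unlabelled index lists carve out disjoint subsets of it.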
    all_train_dataset = ConcatDataset((highlighted_train_set, non_highlighted_train_set))

    train_set = Subset(all_train_dataset, train_split)

    if validation_set_path is not None:
        validation_set = dataset_class(validation_set_path)
        # With an explicit validation set, the validation-split indices are
        # folded into the unlabelled pool instead
        unlabelled_set = Subset(all_train_dataset, validation_split + unlabelled_split)
    else:
        validation_set = Subset(all_train_dataset, validation_split)
        unlabelled_set = Subset(all_train_dataset, unlabelled_split)

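    # The test loader is built once; train and validation loaders are rebuilt
    # inside each run so that the training data is reshuffled every time.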
    test_loader = DataLoader(test_set, batch_size=batch_size, collate_fn=custom_collate, shuffle=False)

    best_vl_scores = {'accuracy': 0,
                      'precision': 0,
                      'recall': 0,
                      'f1': 0}

    # Grid search over the hyper-parameters below; each configuration is scored
    # by averaging validation metrics over no_runs runs
    best_params = None
    for learning_rate in [1e-2, 1e-3, 1e-4]:
        for weight_decay in [1e-1, 1e-2, 1e-4]:
            for num_epochs in [50, 100, 500]:

                vl_scores = {'accuracy': 0,
                             'precision': 0,
                             'recall': 0,
                             'f1': 0}

                for run in range(no_runs):
                    train_loader = DataLoader(train_set, batch_size=batch_size,
                                              collate_fn=custom_collate, shuffle=True)
                    valid_loader = DataLoader(validation_set, batch_size=batch_size,
                                              collate_fn=custom_collate, shuffle=False)

                    model = model_class(input_size, dim_target)

                    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
                    # gamma=1. keeps the learning rate constant, so this scheduler is effectively a no-op
                    scheduler = StepLR(optimizer, step_size=50, gamma=1.)

                    epoch_losses = train_model(train_loader, model, optimizer, scheduler,
                                               max_epochs=num_epochs, device=device)

                    vl_acc, vl_pr, vl_rec, vl_f1, _ = predict(valid_loader, model, device)

                    vl_scores['accuracy'] += float(vl_acc)
                    vl_scores['precision'] += float(vl_pr)
                    vl_scores['recall'] += float(vl_rec)
                    vl_scores['f1'] += float(vl_f1)

                # AVERAGE OVER RUNS
                for key in ['accuracy', 'precision', 'recall', 'f1']:
                    vl_scores[key] = vl_scores[key] / no_runs

                if vl_scores[score_to_optimize] > best_vl_scores[score_to_optimize]:
                    best_vl_scores = deepcopy(vl_scores)
                    best_params = deepcopy(
                        {'learning_rate': learning_rate,
                         'train_split': train_split, 'valid_split': validation_split,
                         'weight_decay': weight_decay, 'epochs': num_epochs})

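    # Final model assessment: retrain from scratch no_runs times with the best
    # hyper-parameters and average the test-set scores over the runs.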
    te_scores = {
        'best_params': best_params,
        'best_vl_scores': best_vl_scores,
        'test_scores': {'accuracy': 0,
                        'precision': 0,
                        'recall': 0,
                        'f1': 0}
    }

    for run in range(no_runs):
        # Rebuild the train loader rather than reusing the one left over from
        # the model-selection loop above
        train_loader = DataLoader(train_set, batch_size=batch_size,
                                  collate_fn=custom_collate, shuffle=True)

        model = model_class(input_size, dim_target)

        optimizer = torch.optim.Adam(model.parameters(), lr=best_params['learning_rate'],
                                     weight_decay=best_params['weight_decay'])
        scheduler = StepLR(optimizer, step_size=50, gamma=1.)  # Constant learning rate, effectively a no-op

        epoch_losses = train_model(train_loader, model, optimizer, scheduler,
                                   max_epochs=best_params['epochs'], device=device)
        te_acc, te_pr, te_rec, te_f1, _ = predict(test_loader, model, device)

        te_scores['test_scores']['accuracy'] += float(te_acc)
        te_scores['test_scores']['precision'] += float(te_pr)
        te_scores['test_scores']['recall'] += float(te_rec)
        te_scores['test_scores']['f1'] += float(te_f1)

    # AVERAGE OVER RUNS
    for key in ['accuracy', 'precision', 'recall', 'f1']:
        te_scores['test_scores'][key] = te_scores['test_scores'][key] / no_runs

    print(f'Best VL scores found are {best_vl_scores}')
    print(f'Best TE scores found are {te_scores["test_scores"]}')
    print(f'End of model assessment for train size {train_size}, test results are {te_scores}')

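    # Persist the selected hyper-parameters, validation scores and averaged
    # test scores as JSON.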
    os.makedirs(results_folder, exist_ok=True)

    with open(Path(results_folder, f'BertBaseline_size_{train_size}_runs_{no_runs}_test_results_bootstrap_{boostrap_split}.json'), 'w') as f:
        json.dump(te_scores, f)
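

A minimal sketch of how compute might be invoked; the paths, split index lists and other
argument values below are hypothetical placeholders, not values taken from the project.

compute(dataset_string='NGramsDataset',
        highlighted_train_set_path='data/highlighted_train.csv',
        non_highlighted_train_set_path='data/non_highlighted_train.csv',
        validation_set_path=None,
        test_set_path='data/test.csv',
        no_runs=3,
        train_size=100,
        boostrap_split=0,
        train_split=list(range(0, 80)),
        validation_split=list(range(80, 100)),
        unlabelled_split=list(range(100, 120)),
        results_folder='results/ngram_logreg',
        score_to_optimize='f1')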