def run_test_on_dataset()

in python/vmaf/routine.py

# Imports the function depends on. numpy is certain; the vmaf-internal
# import paths are assumptions based on the repo layout and may differ
# between versions. read_dataset is defined in this same module.
import numpy as np

from vmaf.core.train_test_model import RegressorMixin, ClassifierMixin
from vmaf.tools.misc import get_file_name_without_extension

def run_test_on_dataset(test_dataset, runner_class, ax,
                        result_store, model_filepath,
                        parallelize=True, fifo_mode=True,
                        aggregate_method=np.mean,
                        type='regressor',
                        **kwargs):
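    """Run runner_class over the assets of test_dataset and report its accuracy.

    If the dataset carries no per-asset groundtruth, subjective modeling is
    run first to aggregate the raw scores. model_filepath and various kwargs
    are forwarded to the runner through optional_dict. When ax is not None,
    a predicted-vs-true scatter plot is drawn on it. Returns the tuple
    (test_assets, results).
    """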

    test_assets = read_dataset(test_dataset, **kwargs)
    test_raw_assets = None
    try:
        for test_asset in test_assets:
            assert test_asset.groundtruth is not None
    except AssertionError:
        # no groundtruth, try to do subjective modeling
        from sureal.dataset_reader import RawDatasetReader
        from sureal.subjective_model import DmosModel
        subj_model_class = kwargs.get('subj_model_class') or DmosModel
        dataset_reader_class = kwargs.get('dataset_reader_class', RawDatasetReader)
        subjective_model = subj_model_class(dataset_reader_class(test_dataset))
        subjective_model.run_modeling(**kwargs)
        test_dataset_aggregate = subjective_model.to_aggregated_dataset(**kwargs)
        test_raw_assets = test_assets
        test_assets = read_dataset(test_dataset_aggregate, **kwargs)

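    # assemble optional_dict for the runner: start from any caller-supplied
    # dict, then layer on model paths and score-processing options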
    optional_dict = kwargs.get('optional_dict')

    if model_filepath is not None:
        if not optional_dict:
            optional_dict = {}
        optional_dict['model_filepath'] = model_filepath
        if kwargs.get('model_720_filepath') is not None:
            optional_dict['720model_filepath'] = kwargs['model_720_filepath']
        if kwargs.get('model_480_filepath') is not None:
            optional_dict['480model_filepath'] = kwargs['model_480_filepath']
        if kwargs.get('model_2160_filepath') is not None:
            optional_dict['2160model_filepath'] = kwargs['model_2160_filepath']

    if kwargs.get('enable_transform_score') is not None:
        if not optional_dict:
            optional_dict = {}
        optional_dict['enable_transform_score'] = kwargs['enable_transform_score']

    if kwargs.get('disable_clip_score') is not None:
        if not optional_dict:
            optional_dict = {}
        optional_dict['disable_clip_score'] = kwargs['disable_clip_score']

    if kwargs.get('subsample') is not None:
        if not optional_dict:
            optional_dict = {}
        optional_dict['subsample'] = kwargs['subsample']

    if kwargs.get('additional_optional_dict') is not None:
        assert isinstance(kwargs['additional_optional_dict'], dict)
        if not optional_dict:
            optional_dict = {}
        optional_dict.update(kwargs['additional_optional_dict'])

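    # an explicit worker count only makes sense when parallelization is on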
    processes = kwargs.get('processes')
    if processes is not None:
        assert isinstance(processes, int)
        assert parallelize is True, 'if processes is not None, parallelize must be True'

    # instantiate the runner and run it over all test assets
    runner = runner_class(
        test_assets,
        None, fifo_mode=fifo_mode,
        delete_workdir=True,
        result_store=result_store,
        optional_dict=optional_dict,
        optional_dict2=None,
    )
    runner.run(parallelize=parallelize, processes=processes)
    results = runner.results

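    # aggregate per-frame scores into per-asset scores (np.mean by default)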
    for result in results:
        result.set_score_aggregate_method(aggregate_method)

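    # choose the stats/plot helper class: prefer the one declared by the
    # runner; otherwise fall back to a mixin selected by the 'type' argument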
    try:
        model_type = runner.get_train_test_model_class()
    except Exception:
        if type == 'regressor':
            model_type = RegressorMixin
        elif type == 'classifier':
            model_type = ClassifierMixin
        else:
            assert False, 'unknown type: {}'.format(type)

    split_test_indices_for_perf_ci = kwargs.get('split_test_indices_for_perf_ci', False)

    # gather ground-truth and predicted scores for stats and plotting
    groundtruths = list(map(lambda asset: asset.groundtruth, test_assets))
    predictions = list(map(lambda result: result[runner_class.get_score_key()], results))
    raw_groundtruths = None if test_raw_assets is None else \
        list(map(lambda asset: asset.raw_groundtruth, test_raw_assets))
    groundtruths_std = None if test_assets is None else \
        list(map(lambda asset: asset.groundtruth_std, test_assets))
    try:
        predictions_bagging = list(map(lambda result: result[runner_class.get_bagging_score_key()], results))
        predictions_stddev = list(map(lambda result: result[runner_class.get_stddev_score_key()], results))
        predictions_ci95_low = list(map(lambda result: result[runner_class.get_ci95_low_score_key()], results))
        predictions_ci95_high = list(map(lambda result: result[runner_class.get_ci95_high_score_key()], results))
        predictions_all_models = list(map(lambda result: result[runner_class.get_all_models_score_key()], results))

        # transpose the list of lists, so that the outer list has the predictions for each model separately
        predictions_all_models = np.array(predictions_all_models).T.tolist()
        num_models = np.shape(predictions_all_models)[0]

        stats = model_type.get_stats(groundtruths, predictions,
                                     ys_label_raw=raw_groundtruths,
                                     ys_label_pred_bagging=predictions_bagging,
                                     ys_label_pred_stddev=predictions_stddev,
                                     ys_label_pred_ci95_low=predictions_ci95_low,
                                     ys_label_pred_ci95_high=predictions_ci95_high,
                                     ys_label_pred_all_models=predictions_all_models,
                                     ys_label_stddev=groundtruths_std,
                                     split_test_indices_for_perf_ci=split_test_indices_for_perf_ci)
    except Exception as e:
        print('Warning: stats calculation failed; falling back to default stats calculation: {}'.format(e))
        stats = model_type.get_stats(groundtruths, predictions,
                                     ys_label_raw=raw_groundtruths,
                                     ys_label_stddev=groundtruths_std,
                                     split_test_indices_for_perf_ci=split_test_indices_for_perf_ci)
        num_models = 1

    print('Stats on testing data: {}'.format(model_type.format_stats_for_print(stats)))

    # printing stats if multiple models are present
    if 'SRCC_across_model_distribution' in stats \
            and 'PCC_across_model_distribution' in stats \
            and 'RMSE_across_model_distribution' in stats:
        print('Stats on testing data (across multiple models, using all test indices): {}'.format(
            model_type.format_across_model_stats_for_print(model_type.extract_across_model_stats(stats))))

    if split_test_indices_for_perf_ci:
        print('Stats on testing data (single model, multiple test sets): {}'
              .format(model_type.format_stats_across_test_splits_for_print(model_type.extract_across_test_splits_stats(stats))))

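    # if a matplotlib axis was provided, scatter-plot predicted vs. true scores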
    if ax is not None:
        content_ids = list(map(lambda asset: asset.content_id, test_assets))

        if kwargs.get('point_label') is not None:
            if kwargs['point_label'] == 'asset_id':
                point_labels = list(map(lambda asset: asset.asset_id, test_assets))
            elif kwargs['point_label'] == 'dis_path':
                point_labels = list(map(lambda asset: get_file_name_without_extension(asset.dis_path), test_assets))
            else:
                raise AssertionError("Unknown point_label {}".format(kwargs['point_label']))
        else:
            point_labels = None

        model_type.plot_scatter(ax, stats, content_ids=content_ids, point_labels=point_labels, **kwargs)
        ax.set_xlabel('True Score')
        ax.set_ylabel('Predicted Score')
        ax.grid()
        ax.set_title("{runner}{num_models}\n{stats}".format(
            runner=runner_class.TYPE,
            stats=model_type.format_stats_for_plot(stats),
            num_models=", {} models".format(num_models) if num_models > 1 else "",
        ))

    return test_assets, results
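
A minimal calling sketch follows. It is an illustration under stated assumptions, not part of routine.py: the dataset path is a placeholder, and VmafQualityRunner paired with FileSystemResultStore is just one plausible runner/result-store combination from the vmaf package.

import matplotlib.pyplot as plt

from vmaf.core.quality_runner import VmafQualityRunner
from vmaf.core.result_store import FileSystemResultStore
from vmaf.routine import run_test_on_dataset
from vmaf.tools.misc import import_python_file

# hypothetical dataset file, in the format read_dataset expects
test_dataset = import_python_file('path/to/test_dataset.py')

fig, ax = plt.subplots()
test_assets, results = run_test_on_dataset(
    test_dataset, VmafQualityRunner, ax,
    result_store=FileSystemResultStore(),
    model_filepath=None,  # None: the runner falls back to its default model
    parallelize=True,
)
plt.show()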