in python/vmaf/routine.py [0:0]
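# Module-level imports this function relies on but which fall outside the excerpt.
# The paths below follow the vmaf package layout as a best-effort guess; adjust if
# your checkout differs. read_dataset is defined earlier in this same routine.py.
import numpy as np

from vmaf.core.train_test_model import RegressorMixin, ClassifierMixin
from vmaf.tools.misc import get_file_name_without_extension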
def run_test_on_dataset(test_dataset, runner_class, ax,
                        result_store, model_filepath,
                        parallelize=True, fifo_mode=True,
                        aggregate_method=np.mean,
                        type='regressor',
                        **kwargs):
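    """Run a quality runner on each asset of a test dataset and report accuracy.

    Reads the dataset's assets (running subjective modeling first if no
    per-asset groundtruth is present), evaluates runner_class on them, prints
    prediction-vs-groundtruth statistics, and, when ax is given, draws a
    scatter plot of predicted vs. true scores. Returns (test_assets, results).
    """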
    test_assets = read_dataset(test_dataset, **kwargs)
    test_raw_assets = None
    try:
        for test_asset in test_assets:
            assert test_asset.groundtruth is not None
    except AssertionError:
        # no groundtruth, try to do subjective modeling
        from sureal.dataset_reader import RawDatasetReader
        from sureal.subjective_model import DmosModel
        subj_model_class = kwargs['subj_model_class'] if 'subj_model_class' in kwargs and kwargs['subj_model_class'] is not None else DmosModel
        dataset_reader_class = kwargs['dataset_reader_class'] if 'dataset_reader_class' in kwargs else RawDatasetReader
        subjective_model = subj_model_class(dataset_reader_class(test_dataset))
        subjective_model.run_modeling(**kwargs)
        test_dataset_aggregate = subjective_model.to_aggregated_dataset(**kwargs)
        test_raw_assets = test_assets
        test_assets = read_dataset(test_dataset_aggregate, **kwargs)
    # assemble the optional_dict passed to the runner: model filepath(s) plus
    # per-run options picked out of kwargs
    optional_dict = kwargs['optional_dict'] if 'optional_dict' in kwargs else None

    if model_filepath is not None:
        if not optional_dict:
            optional_dict = {}
        optional_dict['model_filepath'] = model_filepath
        if 'model_720_filepath' in kwargs and kwargs['model_720_filepath'] is not None:
            optional_dict['720model_filepath'] = kwargs['model_720_filepath']
        if 'model_480_filepath' in kwargs and kwargs['model_480_filepath'] is not None:
            optional_dict['480model_filepath'] = kwargs['model_480_filepath']
        if 'model_2160_filepath' in kwargs and kwargs['model_2160_filepath'] is not None:
            optional_dict['2160model_filepath'] = kwargs['model_2160_filepath']

    if 'enable_transform_score' in kwargs and kwargs['enable_transform_score'] is not None:
        if not optional_dict:
            optional_dict = {}
        optional_dict['enable_transform_score'] = kwargs['enable_transform_score']

    if 'disable_clip_score' in kwargs and kwargs['disable_clip_score'] is not None:
        if not optional_dict:
            optional_dict = {}
        optional_dict['disable_clip_score'] = kwargs['disable_clip_score']

    if 'subsample' in kwargs and kwargs['subsample'] is not None:
        if not optional_dict:
            optional_dict = {}
        optional_dict['subsample'] = kwargs['subsample']

    if 'additional_optional_dict' in kwargs and kwargs['additional_optional_dict'] is not None:
        assert isinstance(kwargs['additional_optional_dict'], dict)
        if not optional_dict:
            optional_dict = {}
        optional_dict.update(kwargs['additional_optional_dict'])

    # number of worker processes; only meaningful when parallelize is True
    if 'processes' in kwargs and kwargs['processes'] is not None:
        assert isinstance(kwargs['processes'], int)
        processes = kwargs['processes']
    else:
        processes = None
    if processes is not None:
        assert parallelize is True, 'if processes is not None, parallelize must be True'
    # run the quality runner on all test assets
    runner = runner_class(
        test_assets,
        None, fifo_mode=fifo_mode,
        delete_workdir=True,
        result_store=result_store,
        optional_dict=optional_dict,
        optional_dict2=None,
    )
    runner.run(parallelize=parallelize, processes=processes)
    results = runner.results

    for result in results:
        result.set_score_aggregate_method(aggregate_method)
    # pick the model class used for stats and plotting; fall back to the
    # generic regressor/classifier mixin if the runner does not declare one
    try:
        model_type = runner.get_train_test_model_class()
    except Exception:
        if type == 'regressor':
            model_type = RegressorMixin
        elif type == 'classifier':
            model_type = ClassifierMixin
        else:
            assert False, 'unknown type: {}'.format(type)

    split_test_indices_for_perf_ci = kwargs['split_test_indices_for_perf_ci'] \
        if 'split_test_indices_for_perf_ci' in kwargs else False
    # plot
    groundtruths = list(map(lambda asset: asset.groundtruth, test_assets))
    predictions = list(map(lambda result: result[runner_class.get_score_key()], results))
    raw_groundtruths = None if test_raw_assets is None else \
        list(map(lambda asset: asset.raw_groundtruth, test_raw_assets))
    groundtruths_std = None if test_assets is None else \
        list(map(lambda asset: asset.groundtruth_std, test_assets))
    try:
        predictions_bagging = list(map(lambda result: result[runner_class.get_bagging_score_key()], results))
        predictions_stddev = list(map(lambda result: result[runner_class.get_stddev_score_key()], results))
        predictions_ci95_low = list(map(lambda result: result[runner_class.get_ci95_low_score_key()], results))
        predictions_ci95_high = list(map(lambda result: result[runner_class.get_ci95_high_score_key()], results))
        predictions_all_models = list(map(lambda result: result[runner_class.get_all_models_score_key()], results))
        # transpose the list of lists, so that the outer list holds the predictions of each model separately
        predictions_all_models = np.array(predictions_all_models).T.tolist()
        num_models = np.shape(predictions_all_models)[0]
        stats = model_type.get_stats(groundtruths, predictions,
                                     ys_label_raw=raw_groundtruths,
                                     ys_label_pred_bagging=predictions_bagging,
                                     ys_label_pred_stddev=predictions_stddev,
                                     ys_label_pred_ci95_low=predictions_ci95_low,
                                     ys_label_pred_ci95_high=predictions_ci95_high,
                                     ys_label_pred_all_models=predictions_all_models,
                                     ys_label_stddev=groundtruths_std,
                                     split_test_indices_for_perf_ci=split_test_indices_for_perf_ci)
    except Exception as e:
        print('Warning: stats calculation failed, falling back to default stats calculation: {}'.format(e))
        stats = model_type.get_stats(groundtruths, predictions,
                                     ys_label_raw=raw_groundtruths,
                                     ys_label_stddev=groundtruths_std,
                                     split_test_indices_for_perf_ci=split_test_indices_for_perf_ci)
        num_models = 1

    print('Stats on testing data: {}'.format(model_type.format_stats_for_print(stats)))
    # printing stats if multiple models are present
    if 'SRCC_across_model_distribution' in stats \
            and 'PCC_across_model_distribution' in stats \
            and 'RMSE_across_model_distribution' in stats:
        print('Stats on testing data (across multiple models, using all test indices): {}'.format(
            model_type.format_across_model_stats_for_print(model_type.extract_across_model_stats(stats))))

    if split_test_indices_for_perf_ci:
        print('Stats on testing data (single model, multiple test sets): {}'
              .format(model_type.format_stats_across_test_splits_for_print(model_type.extract_across_test_splits_stats(stats))))
    if ax is not None:
        content_ids = list(map(lambda asset: asset.content_id, test_assets))

        if 'point_label' in kwargs and kwargs['point_label'] is not None:
            if kwargs['point_label'] == 'asset_id':
                point_labels = list(map(lambda asset: asset.asset_id, test_assets))
            elif kwargs['point_label'] == 'dis_path':
                point_labels = list(map(lambda asset: get_file_name_without_extension(asset.dis_path), test_assets))
            else:
                raise AssertionError("Unknown point_label {}".format(kwargs['point_label']))
        else:
            point_labels = None

        model_type.plot_scatter(ax, stats, content_ids=content_ids, point_labels=point_labels, **kwargs)
        ax.set_xlabel('True Score')
        ax.set_ylabel('Predicted Score')
        ax.grid()
        ax.set_title("{runner}{num_models}\n{stats}".format(
            runner=runner_class.TYPE,
            stats=model_type.format_stats_for_plot(stats),
            num_models=", {} models".format(num_models) if num_models > 1 else "",
        ))

    return test_assets, results
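

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of routine.py): a minimal way to
# drive run_test_on_dataset by hand, roughly what the run_testing command-line
# tool does. VmafQualityRunner, FileSystemResultStore and import_python_file
# are referenced per the vmaf package layout as an assumption, and the dataset
# path is a hypothetical placeholder; adjust both to your setup.
# ---------------------------------------------------------------------------
import matplotlib.pyplot as plt

from vmaf.core.quality_runner import VmafQualityRunner
from vmaf.core.result_store import FileSystemResultStore
from vmaf.tools.misc import import_python_file

# load a dataset description file (hypothetical path)
test_dataset = import_python_file('resource/dataset/example_dataset.py')

fig, ax = plt.subplots(figsize=(6, 6))
test_assets, results = run_test_on_dataset(
    test_dataset, VmafQualityRunner, ax,
    result_store=FileSystemResultStore(),
    model_filepath=None,  # None: let the runner fall back to its default model
    parallelize=True,
)
plt.show()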