in python/vmaf/routine.py [0:0]
def train_test_vmaf_on_dataset(train_dataset, test_dataset,
                               feature_param, model_param,
                               train_ax, test_ax, result_store,
                               logger=None, fifo_mode=True,
                               output_model_filepath=None,
                               aggregate_method=np.mean,
                               **kwargs):
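    """Train a VMAF model on train_dataset and, if test_dataset is not None,
    evaluate the trained model on it.

    Features are extracted with FeatureAssembler according to feature_param,
    the model described by model_param is fitted on the training features, and
    the fitted model is optionally written to output_model_filepath. Scatter
    plots are drawn on train_ax / test_ax when they are provided.

    Returns (train_fassembler, train_assets, train_stats,
             test_fassembler, test_assets, test_stats, model).
    """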
    train_assets = read_dataset(train_dataset, **kwargs)
    train_raw_assets = None
    try:
        for train_asset in train_assets:
            assert train_asset.groundtruth is not None
    except AssertionError:
        # no groundtruth, try to do subjective modeling
        from sureal.dataset_reader import RawDatasetReader
        from sureal.subjective_model import DmosModel
        subj_model_class = kwargs['subj_model_class'] if kwargs.get('subj_model_class') is not None else DmosModel
        dataset_reader_class = kwargs.get('dataset_reader_class', RawDatasetReader)
        subjective_model = subj_model_class(dataset_reader_class(train_dataset))
        subjective_model.run_modeling(**kwargs)
        train_dataset_aggregate = subjective_model.to_aggregated_dataset(**kwargs)
        train_raw_assets = train_assets
        train_assets = read_dataset(train_dataset_aggregate, **kwargs)
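    # Illustrative note (hypothetical call): the subjective-modeling fallback above can be
    # steered through kwargs, e.g.
    #   train_test_vmaf_on_dataset(..., subj_model_class=MosModel)
    # where MosModel would come from sureal.subjective_model.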
    parallelize = kwargs.get('parallelize', True)
    assert isinstance(parallelize, bool)

    processes = kwargs.get('processes', None)
    if processes is not None:
        assert isinstance(processes, int) and processes > 0
        assert parallelize is True, 'if processes is not None, parallelize must be True'

    assert hasattr(feature_param, 'feature_dict')
    feature_dict = feature_param.feature_dict
    feature_option_dict = feature_param.feature_optional_dict if hasattr(feature_param, 'feature_optional_dict') else None
    train_fassembler = FeatureAssembler(
        feature_dict=feature_dict,
        feature_option_dict=feature_option_dict,
        assets=train_assets,
        logger=logger,
        fifo_mode=fifo_mode,
        delete_workdir=True,
        result_store=result_store,
        optional_dict=None,  # WARNING: feature param not passed
        optional_dict2=None,
        parallelize=parallelize,
        processes=processes,
    )
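    # Run feature extraction over all training assets. If result_store points at a
    # populated store (e.g. a file-system result store), previously computed feature
    # results are typically reused rather than recomputed.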
    train_fassembler.run()
    train_features = train_fassembler.results
    for result in train_features:
        result.set_score_aggregate_method(aggregate_method)

    model_type = model_param.model_type
    model_param_dict = model_param.model_param_dict
    model_class = TrainTestModel.find_subclass(model_type)
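    # For orientation (illustrative, not an exact spec): get_xys_from_results() assembles a
    # dict of per-asset feature values keyed by feature name together with the ground truth
    # under 'label'; get_xs_from_results() and get_ys_from_results() return, roughly, the
    # feature-only and label-only sides of that dict, as used below.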
    train_xys = model_class.get_xys_from_results(train_features)
    train_xs = model_class.get_xs_from_results(train_features)
    train_ys = model_class.get_ys_from_results(train_features)

    model = model_class(model_param_dict, logger)
    model.train(train_xys, feature_option_dict=feature_option_dict, **kwargs)
    # append additional information to model before saving, so that
    # VmafQualityRunner can read and process
    model.append_info('feature_dict', feature_param.feature_dict)  # need feature_dict so that VmafQualityRunner knows how to call FeatureAssembler
    if 'score_clip' in model_param_dict:
        VmafQualityRunner.set_clip_score(model, model_param_dict['score_clip'])
    if 'score_transform' in model_param_dict:
        VmafQualityRunner.set_transform_score(model, model_param_dict['score_transform'])
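    # Illustrative (hypothetical) model_param_dict entries that would trigger the two
    # branches above, in the spirit of the parameter files shipped with vmaf:
    #   'score_clip': [0.0, 100.0]
    #   'score_transform': {'p0': ..., 'p1': ..., 'p2': ...}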
    train_ys_pred = VmafQualityRunner.predict_with_model(model, train_xs, **kwargs)['ys_pred']

    raw_groundtruths = None if train_raw_assets is None else \
        list(map(lambda asset: asset.raw_groundtruth, train_raw_assets))

    train_stats = model.get_stats(train_ys['label'], train_ys_pred, ys_label_raw=raw_groundtruths)

    log = 'Stats on training data: {}'.format(model.format_stats_for_print(train_stats))
    if logger:
        logger.info(log)
    else:
        print(log)
    # save model
    if output_model_filepath is not None:
        file_ext = os.path.splitext(output_model_filepath)[1]
        supported_formats = ['.pkl', '.json']
        VmafQualityRunnerModelMixin._assert_extension_format(supported_formats, file_ext)
        if file_ext == '.pkl':
            model.to_file(output_model_filepath, format='pkl')
        elif file_ext == '.json':
            model.to_file(output_model_filepath, format='json', combined=True)
        else:
            assert False, 'unsupported model file extension: {}'.format(file_ext)
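    # Note: a model saved here is intended to be consumed by VmafQualityRunner (for
    # example via the run_vmaf command line's --model option); it can also be loaded
    # back programmatically through the TrainTestModel loading helpers, though the
    # exact call signature depends on the vmaf version in use.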
    if train_ax is not None:
        train_content_ids = list(map(lambda asset: asset.content_id, train_assets))
        model_class.plot_scatter(train_ax, train_stats, content_ids=train_content_ids)
        train_ax.set_xlabel('True Score')
        train_ax.set_ylabel("Predicted Score")
        train_ax.grid()
        train_ax.set_title("Dataset: {dataset}, Model: {model}\n{stats}".format(
            dataset=train_dataset.dataset_name,
            model=model.model_id,
            stats=model_class.format_stats_for_plot(train_stats)
        ))
    # === test model on test dataset ===
    if test_dataset is None:
        test_assets = None
        test_stats = None
        test_fassembler = None
    else:
        test_assets = read_dataset(test_dataset, **kwargs)
        test_raw_assets = None
        try:
            for test_asset in test_assets:
                assert test_asset.groundtruth is not None
        except AssertionError:
            # no groundtruth, try to do subjective modeling
            from sureal.dataset_reader import RawDatasetReader
            from sureal.subjective_model import DmosModel
            subj_model_class = kwargs['subj_model_class'] if kwargs.get('subj_model_class') is not None else DmosModel
            dataset_reader_class = kwargs.get('dataset_reader_class', RawDatasetReader)
            subjective_model = subj_model_class(dataset_reader_class(test_dataset))
            subjective_model.run_modeling(**kwargs)
            test_dataset_aggregate = subjective_model.to_aggregated_dataset(**kwargs)
            test_raw_assets = test_assets
            test_assets = read_dataset(test_dataset_aggregate, **kwargs)
        test_fassembler = FeatureAssembler(
            feature_dict=feature_dict,
            feature_option_dict=feature_option_dict,
            assets=test_assets,
            logger=logger,
            fifo_mode=fifo_mode,
            delete_workdir=True,
            result_store=result_store,
            optional_dict=None,  # WARNING: feature param not passed
            optional_dict2=None,
            parallelize=parallelize,
        )
        test_fassembler.run()
        test_features = test_fassembler.results
        for result in test_features:
            result.set_score_aggregate_method(aggregate_method)
        test_xs = model_class.get_xs_from_results(test_features)
        test_ys = model_class.get_ys_from_results(test_features)

        test_ys_pred = VmafQualityRunner.predict_with_model(model, test_xs, **kwargs)['ys_pred']

        raw_groundtruths = None if test_raw_assets is None else \
            list(map(lambda asset: asset.raw_groundtruth, test_raw_assets))

        test_stats = model.get_stats(test_ys['label'], test_ys_pred, ys_label_raw=raw_groundtruths)

        log = 'Stats on testing data: {}'.format(model_class.format_stats_for_print(test_stats))
        if logger:
            logger.info(log)
        else:
            print(log)
        if test_ax is not None:
            test_content_ids = list(map(lambda asset: asset.content_id, test_assets))
            model_class.plot_scatter(test_ax, test_stats, content_ids=test_content_ids)
            test_ax.set_xlabel('True Score')
            test_ax.set_ylabel("Predicted Score")
            test_ax.grid()
            test_ax.set_title("Dataset: {dataset}, Model: {model}\n{stats}".format(
                dataset=test_dataset.dataset_name,
                model=model.model_id,
                stats=model_class.format_stats_for_plot(test_stats)
            ))
    return train_fassembler, train_assets, train_stats, test_fassembler, test_assets, test_stats, model
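# --- Usage sketch (illustrative, not part of the routine itself) ---
# A minimal, hedged example of how train_test_vmaf_on_dataset might be driven. The
# dataset / parameter file names and the matplotlib setup below are assumptions made
# for illustration, not canonical vmaf inputs:
#
#   import matplotlib.pyplot as plt
#   from vmaf.tools.misc import import_python_file
#
#   train_dataset = import_python_file('my_train_dataset.py')    # hypothetical dataset file
#   test_dataset = import_python_file('my_test_dataset.py')      # hypothetical dataset file
#   feature_param = import_python_file('my_feature_param.py')    # must define feature_dict
#   model_param = import_python_file('my_model_param.py')        # must define model_type, model_param_dict
#
#   fig, (train_ax, test_ax) = plt.subplots(1, 2, figsize=(12, 5))
#   train_test_vmaf_on_dataset(
#       train_dataset, test_dataset, feature_param, model_param,
#       train_ax, test_ax, result_store=None,
#       output_model_filepath='my_vmaf_model.json')
#   fig.savefig('train_test_scatter.png')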