def train_test_vmaf_on_dataset()

in python/vmaf/routine.py [0:0]


def _read_assets_with_groundtruth(dataset, kwargs):
    """Read the assets of *dataset*, deriving groundtruth by subjective modeling if needed.

    If every asset returned by ``read_dataset`` already carries a groundtruth,
    those assets are returned as-is. Otherwise the dataset is run through a
    sureal subjective model (``kwargs['subj_model_class']``, default
    ``DmosModel``; reader ``kwargs['dataset_reader_class']``, default
    ``RawDatasetReader``). The assets of the resulting aggregated dataset are
    then returned.

    Args:
        dataset: dataset object/module accepted by ``read_dataset``.
        kwargs: the caller's keyword arguments; forwarded to ``read_dataset``,
            ``run_modeling`` and ``to_aggregated_dataset``.

    Returns:
        ``(assets, raw_assets)`` where ``raw_assets`` is the originally read
        asset list when subjective modeling was performed, else ``None``.
    """
    assets = read_dataset(dataset, **kwargs)
    # Explicit check instead of try/assert/except AssertionError: asserts are
    # stripped under ``python -O``, which would silently skip subjective modeling.
    if all(asset.groundtruth is not None for asset in assets):
        return assets, None

    # no groundtruth, try to do subjective modeling
    from sureal.dataset_reader import RawDatasetReader
    from sureal.subjective_model import DmosModel

    subj_model_class = kwargs.get('subj_model_class')
    if subj_model_class is None:
        subj_model_class = DmosModel
    dataset_reader_class = kwargs.get('dataset_reader_class')
    if dataset_reader_class is None:
        dataset_reader_class = RawDatasetReader

    subjective_model = subj_model_class(dataset_reader_class(dataset))
    subjective_model.run_modeling(**kwargs)
    dataset_aggregate = subjective_model.to_aggregated_dataset(**kwargs)
    return read_dataset(dataset_aggregate, **kwargs), assets


def _assemble_features(assets, feature_dict, feature_option_dict, logger,
                       fifo_mode, result_store, parallelize, processes,
                       aggregate_method):
    """Run a FeatureAssembler over *assets* and set each result's score aggregation.

    Shared by the train and test paths so both are configured identically
    (including ``processes``).

    Returns:
        ``(fassembler, results)`` — the assembler and its ``results`` list.
    """
    fassembler = FeatureAssembler(
        feature_dict=feature_dict,
        feature_option_dict=feature_option_dict,
        assets=assets,
        logger=logger,
        fifo_mode=fifo_mode,
        delete_workdir=True,
        result_store=result_store,
        optional_dict=None,  # WARNING: feature param not passed
        optional_dict2=None,
        parallelize=parallelize,
        processes=processes,
    )
    fassembler.run()
    results = fassembler.results
    for result in results:
        result.set_score_aggregate_method(aggregate_method)
    return fassembler, results


def _raw_groundtruths_of(raw_assets):
    """Return each raw asset's ``raw_groundtruth`` as a list, or None if *raw_assets* is None."""
    if raw_assets is None:
        return None
    return [asset.raw_groundtruth for asset in raw_assets]


def _log_or_print(logger, msg):
    """Send *msg* to ``logger.info`` when a logger is given, otherwise print it."""
    if logger:
        logger.info(msg)
    else:
        print(msg)


def _plot_prediction_scatter(ax, model_class, model, stats, assets, dataset):
    """Draw a true-vs-predicted scatter of *stats* on *ax*, labelled by content id."""
    content_ids = [asset.content_id for asset in assets]
    model_class.plot_scatter(ax, stats, content_ids=content_ids)
    ax.set_xlabel('True Score')
    ax.set_ylabel("Predicted Score")
    ax.grid()
    ax.set_title("Dataset: {dataset}, Model: {model}\n{stats}".format(
        dataset=dataset.dataset_name,
        model=model.model_id,
        stats=model_class.format_stats_for_plot(stats)
    ))


def train_test_vmaf_on_dataset(train_dataset, test_dataset,
                               feature_param, model_param,
                               train_ax, test_ax, result_store,
                               logger=None, fifo_mode=True,
                               output_model_filepath=None,
                               aggregate_method=np.mean,
                               **kwargs):
    """Train a quality model on *train_dataset* and optionally evaluate it on *test_dataset*.

    For each dataset, assets lacking groundtruth are first turned into an
    aggregated dataset via sureal subjective modeling. Features are then
    extracted with ``FeatureAssembler`` using ``feature_param.feature_dict``.
    A model of class ``TrainTestModel.find_subclass(model_param.model_type)``
    is trained on the training features. Its stats are logged (or printed when
    no logger is given) and optionally plotted.

    Args:
        train_dataset: dataset to train on.
        test_dataset: dataset to evaluate on, or None to skip testing.
        feature_param: must have ``feature_dict``; ``feature_optional_dict`` is
            used if present.
        model_param: provides ``model_type`` and ``model_param_dict``. If the
            dict contains ``score_clip`` / ``score_transform``, these are
            attached to the model before prediction.
        train_ax, test_ax: plotting axes, or None to skip plotting.
        result_store: forwarded to ``FeatureAssembler``.
        logger: optional logger; when falsy, stats are printed.
        fifo_mode: forwarded to ``FeatureAssembler``.
        output_model_filepath: if given, the trained model is written there;
            the extension must be ``.pkl`` or ``.json``.
        aggregate_method: per-frame score aggregation set on every feature result.
        **kwargs: ``parallelize`` (bool, default True), ``processes`` (positive
            int or None; requires ``parallelize``), ``subj_model_class``,
            ``dataset_reader_class``. All kwargs are also forwarded to
            ``read_dataset``, the subjective model, ``model.train`` and
            ``VmafQualityRunner.predict_with_model``.

    Returns:
        ``(train_fassembler, train_assets, train_stats, test_fassembler,
        test_assets, test_stats, model)``; the three test entries are None
        when *test_dataset* is None.
    """
    train_assets, train_raw_assets = _read_assets_with_groundtruth(train_dataset, kwargs)

    parallelize = kwargs.get('parallelize', True)
    assert isinstance(parallelize, bool)

    processes = kwargs.get('processes', None)
    if processes is not None:
        assert isinstance(processes, int) and processes > 0
        assert parallelize is True, 'if processes is not None, parallelize must be True'

    assert hasattr(feature_param, 'feature_dict')
    feature_dict = feature_param.feature_dict
    feature_option_dict = getattr(feature_param, 'feature_optional_dict', None)

    # === train model on train dataset ===

    train_fassembler, train_features = _assemble_features(
        train_assets, feature_dict, feature_option_dict, logger, fifo_mode,
        result_store, parallelize, processes, aggregate_method)

    model_type = model_param.model_type
    model_param_dict = model_param.model_param_dict
    model_class = TrainTestModel.find_subclass(model_type)

    train_xys = model_class.get_xys_from_results(train_features)
    train_xs = model_class.get_xs_from_results(train_features)
    train_ys = model_class.get_ys_from_results(train_features)

    model = model_class(model_param_dict, logger)
    model.train(train_xys, feature_option_dict=feature_option_dict, **kwargs)

    # append additional information to model before saving, so that
    # VmafQualityRunner can read and process
    model.append_info('feature_dict', feature_dict)  # need feature_dict so that VmafQualityRunner knows how to call FeatureAssembler
    if 'score_clip' in model_param_dict:
        VmafQualityRunner.set_clip_score(model, model_param_dict['score_clip'])
    if 'score_transform' in model_param_dict:
        VmafQualityRunner.set_transform_score(model, model_param_dict['score_transform'])

    train_ys_pred = VmafQualityRunner.predict_with_model(model, train_xs, **kwargs)['ys_pred']
    train_stats = model.get_stats(train_ys['label'], train_ys_pred,
                                  ys_label_raw=_raw_groundtruths_of(train_raw_assets))

    _log_or_print(logger, 'Stats on training data: {}'.format(model.format_stats_for_print(train_stats)))

    # save model
    if output_model_filepath is not None:
        ext = os.path.splitext(output_model_filepath)[1]
        supported_formats = ['.pkl', '.json']
        VmafQualityRunnerModelMixin._assert_extension_format(supported_formats, ext)
        if ext == '.pkl':
            model.to_file(output_model_filepath, format='pkl')
        elif ext == '.json':
            model.to_file(output_model_filepath, format='json', combined=True)
        else:
            # explicit raise (not ``assert False``) so it still fires under ``python -O``
            raise AssertionError('unsupported model file extension: {}'.format(ext))

    if train_ax is not None:
        _plot_prediction_scatter(train_ax, model_class, model, train_stats,
                                 train_assets, train_dataset)

    # === test model on test dataset ===

    if test_dataset is None:
        return train_fassembler, train_assets, train_stats, None, None, None, model

    test_assets, test_raw_assets = _read_assets_with_groundtruth(test_dataset, kwargs)

    test_fassembler, test_features = _assemble_features(
        test_assets, feature_dict, feature_option_dict, logger, fifo_mode,
        result_store, parallelize, processes, aggregate_method)

    test_xs = model_class.get_xs_from_results(test_features)
    test_ys = model_class.get_ys_from_results(test_features)

    test_ys_pred = VmafQualityRunner.predict_with_model(model, test_xs, **kwargs)['ys_pred']
    test_stats = model.get_stats(test_ys['label'], test_ys_pred,
                                 ys_label_raw=_raw_groundtruths_of(test_raw_assets))

    _log_or_print(logger, 'Stats on testing data: {}'.format(model_class.format_stats_for_print(test_stats)))

    if test_ax is not None:
        _plot_prediction_scatter(test_ax, model_class, model, test_stats,
                                 test_assets, test_dataset)

    return train_fassembler, train_assets, train_stats, test_fassembler, test_assets, test_stats, model