def plot_scatter()

in python/vmaf/core/train_test_model.py [0:0]


    def plot_scatter(cls, ax, stats, **kwargs):
        """Scatter-plot ground-truth scores against predicted scores.

        Points are drawn with horizontal error bars of +/- 1.96 * stddev of the
        ground-truth label (a ~95% interval under a normal assumption).

        :param ax: matplotlib Axes used for the 'aggregate' plot. If 'aggregate'
            is not among the requested plots, the figure owning ``ax`` is closed.
        :param stats: mapping with 'ys_label' and 'ys_label_pred' (equal-length
            sequences) and optionally 'ys_label_stddev' (per-point label standard
            deviation; a missing/None value, or None/NaN entries, are drawn as 0).
        :param kwargs: optional keyword arguments:
            content_ids -- sequence, same length as ys_label; groups points by
                content (one colour per content) and drives the 'per_content' and
                'groundtruth_predicted_in_parallel' plots.
            point_labels -- sequence, same length as ys_label; text annotation
                drawn next to each point.
            plot_linear_fit -- bool (default False); draw the overall linear fit
                on ``ax`` and, in 'per_content' plots, both the overall and the
                per-content fit.
            do_plot -- list of options among 'aggregate', 'per_content',
                'groundtruth_predicted_in_parallel' (default ['aggregate']).
                'per_content' has no effect when content_ids is not given.
        :raises AssertionError: on length mismatches or invalid plotting options.
        """
        ys_label = np.array(stats['ys_label'])
        ys_label_pred = np.array(stats['ys_label_pred'])
        assert len(ys_label_pred) == len(ys_label)

        raw_stddev = stats.get('ys_label_stddev')
        if raw_stddev is None:
            # FIXME: setting std to 0 may be misleading
            ys_label_stddev = np.zeros(len(ys_label))
        else:
            # dtype=float maps None entries to NaN, so a single isnan mask zeroes
            # both None and NaN entries (an object-dtype fallback would miss NaN).
            ys_label_stddev = np.array(raw_stddev, dtype=float)
            ys_label_stddev[np.isnan(ys_label_stddev)] = 0
        assert len(ys_label_stddev) == len(ys_label)

        xlim, ylim = cls.get_xlim_ylim(ys_label, ys_label_pred, ys_label_stddev)

        content_ids = kwargs.get('content_ids')
        if content_ids is not None:
            assert len(content_ids) == len(ys_label)
        point_labels = kwargs.get('point_labels')
        if point_labels is not None:
            assert len(point_labels) == len(ys_label)
        plot_linear_fit = kwargs.get('plot_linear_fit', False)
        assert isinstance(plot_linear_fit, bool)

        do_plot = kwargs.get('do_plot', ['aggregate'])
        accepted_options = ['aggregate', 'per_content', 'groundtruth_predicted_in_parallel']
        assert isinstance(do_plot, list), f"do_plot needs to be a list of plotting options. Accepted options are " \
                                          f"{accepted_options}"
        for option in do_plot:
            assert option in accepted_options, f"{option} is not in {accepted_options}"

        def _fit_line_points(fit):
            # Two points on the fitted line, evaluated at the x-limits, as
            # expected by Axes.axline(xy1, xy2).
            params = fit[0]
            return ((xlim[0], linear_func(xlim[0], params[0], params[1])),
                    (xlim[1], linear_func(xlim[1], params[0], params[1])))

        overall_linear_fit = None
        if plot_linear_fit:
            overall_linear_fit = linear_fit(ys_label, ys_label_pred)
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)
            ax.axline(*_fit_line_points(overall_linear_fit), color='gray', linestyle='--')
            ax.legend(['overall fit'])

        if 'groundtruth_predicted_in_parallel' in do_plot:
            w, h = _get_plot_width_and_height(len(ys_label))
            fig_gt_pred, [ax_groundtruth, ax_predicted] = plt.subplots(figsize=[w, h * 2], ncols=1, nrows=2)
            ax_groundtruth.set_xlabel('Stimuli')
            ax_groundtruth.set_ylabel('True Score')
            ax_groundtruth.grid()
            ax_groundtruth.set_ylim(xlim)
            ax_predicted.set_xlabel('Stimuli')
            ax_predicted.set_ylabel('Predicted Score')
            ax_predicted.grid()
            ax_predicted.set_ylim(ylim)
        else:
            fig_gt_pred = None
            ax_groundtruth = None
            ax_predicted = None

        if content_ids is None:
            ax.errorbar(ys_label, ys_label_pred, xerr=1.96 * ys_label_stddev, marker='o', linestyle='')
            if 'groundtruth_predicted_in_parallel' in do_plot:
                # No content grouping available: plot all stimuli as a single
                # series instead of leaving the parallel figure empty.
                stimuli = np.arange(len(ys_label))
                ax_groundtruth.plot(stimuli, ys_label, '-^')
                ax_predicted.plot(stimuli, ys_label_pred, '-^')
        else:
            # Deterministic content order (and hence colour/legend order);
            # list(set(...)) depends on hashing and can vary between runs.
            try:
                unique_content_ids = sorted(set(content_ids))
            except TypeError:  # non-orderable (e.g. mixed-type) ids
                unique_content_ids = list(dict.fromkeys(content_ids))

            cmap = plt.get_cmap('jet')
            colors = [cmap(i) for i in np.linspace(0, 1, len(unique_content_ids))]
            # Converted once so per-content subsets can use fancy indexing.
            point_labels_arr = np.array(point_labels) if point_labels is not None else None

            for idx, curr_content_id in enumerate(unique_content_ids):
                color = colors[idx]
                content_title = f'Content id {str(curr_content_id)}'
                curr_idxs = indices(content_ids, lambda cid: cid == curr_content_id)
                curr_ys_label = ys_label[curr_idxs]
                curr_ys_label_pred = ys_label_pred[curr_idxs]
                curr_ys_label_stddev = ys_label_stddev[curr_idxs]

                if 'aggregate' in do_plot:
                    ax.errorbar(curr_ys_label, curr_ys_label_pred,
                                xerr=1.96 * curr_ys_label_stddev,
                                marker='o', linestyle='', label=curr_content_id, color=color)

                if 'per_content' in do_plot:
                    _, new_ax = plt.subplots(1, 1, figsize=ax.figure.get_size_inches())
                    new_ax.update_from(ax)
                    new_ax.set_xlim(xlim)
                    new_ax.set_ylim(ylim)

                    if plot_linear_fit:
                        curr_linear_fit = linear_fit(curr_ys_label, curr_ys_label_pred)
                        new_ax.axline(*_fit_line_points(overall_linear_fit), color='gray', linestyle='--')
                        new_ax.axline(*_fit_line_points(curr_linear_fit), color='red', linestyle='--')
                        new_ax.legend(['overall fit', 'current fit'])

                    new_ax.errorbar(
                        curr_ys_label, curr_ys_label_pred,
                        xerr=1.96 * curr_ys_label_stddev,
                        marker='o', linestyle='', label=curr_content_id, color=color,
                    )

                    new_ax.set_title(content_title)
                    new_ax.set_xlabel('True Score')
                    new_ax.set_ylabel("Predicted Score")
                    new_ax.grid()

                    # `is not None` rather than truthiness: point_labels may be a
                    # numpy array, whose truth value is ambiguous.
                    if point_labels_arr is not None:
                        curr_point_labels = point_labels_arr[curr_idxs]
                        for text, x, y in zip(curr_point_labels, curr_ys_label, curr_ys_label_pred):
                            new_ax.annotate(text, (x, y))

                if 'groundtruth_predicted_in_parallel' in do_plot:
                    ax_groundtruth.plot(curr_idxs, curr_ys_label, '-^', color=color, label=content_title)
                    ax_predicted.plot(curr_idxs, curr_ys_label_pred, '-^', color=color, label=content_title)

        if 'aggregate' in do_plot and point_labels is not None:
            for text, x, y in zip(point_labels, ys_label, ys_label_pred):
                ax.annotate(text, (x, y))

        # ax is always passed in; if it was not used for the aggregate plot,
        # close its figure so the unused figure is not left open.
        # TODO: can be improved
        if 'aggregate' not in do_plot:
            plt.close(ax.figure)

        if 'groundtruth_predicted_in_parallel' in do_plot:
            # Only the per-content series carry labels; skip the legend otherwise.
            if content_ids is not None:
                ax_groundtruth.legend()
            fig_gt_pred.tight_layout()