in python/vmaf/core/train_test_model.py [0:0]
def plot_scatter(cls, ax, stats, **kwargs):
    """Scatter-plot predicted scores against ground-truth scores.

    Ground-truth values go on the x axis and predictions on the y axis, with
    horizontal error bars of ``1.96 * stddev`` (presumably a 95% interval
    half-width -- confirm against the producer of ``stats``).

    Args:
        ax: Matplotlib axes receiving the aggregate plot. When 'aggregate' is
            not requested, the figure owning ``ax`` is closed before returning.
        stats: Mapping with 'ys_label' (ground truth) and 'ys_label_pred'
            (predictions), both of equal length. An optional 'ys_label_stddev'
            entry of the same length supplies per-point standard deviations;
            ``None`` and NaN entries are treated as 0, and a missing entry
            means 0 everywhere.
        **kwargs:
            content_ids: optional sequence, one id per point. When given,
                points are grouped and coloured by id.
            point_labels: optional sequence, one text label per point, used
                to annotate points.
            plot_linear_fit: bool (default False). Draw the overall linear fit
                as a dashed gray line; per-content figures additionally get
                the fit of their own points as a dashed red line.
            do_plot: list of options (default ``['aggregate']``) drawn from
                'aggregate', 'per_content' and
                'groundtruth_predicted_in_parallel'.

    Returns:
        None. Output is drawn on ``ax`` and on any figures created through
        ``plt.subplots``.

    Raises:
        AssertionError: on length mismatches, a non-bool ``plot_linear_fit``,
            a non-list ``do_plot`` or an unknown ``do_plot`` option.

    Note:
        'per_content' figures and the data of the
        'groundtruth_predicted_in_parallel' figure are only produced when
        ``content_ids`` is supplied; without it only ``ax`` receives points.
    """
    ys_label = np.array(stats['ys_label'])
    ys_label_pred = np.array(stats['ys_label_pred'])
    assert len(ys_label_pred) == len(ys_label)

    if 'ys_label_stddev' in stats:
        # dtype=float turns None entries into NaN, so a single NaN sweep
        # zeroes both kinds of missing value. np.array copies, so the
        # caller's data is never modified in place.
        ys_label_stddev = np.array(stats['ys_label_stddev'], dtype=float)
        ys_label_stddev[np.isnan(ys_label_stddev)] = 0
    else:
        # FIXME: setting std to 0 may be misleading
        ys_label_stddev = np.zeros(len(ys_label))
    assert len(ys_label_stddev) == len(ys_label)

    xlim, ylim = cls.get_xlim_ylim(ys_label, ys_label_pred, ys_label_stddev)

    content_ids = kwargs.get('content_ids')
    if content_ids is not None:
        assert len(content_ids) == len(ys_label)

    point_labels = kwargs.get('point_labels')
    if point_labels is not None:
        assert len(point_labels) == len(ys_label)

    plot_linear_fit = kwargs.get('plot_linear_fit', False)
    assert isinstance(plot_linear_fit, bool)

    do_plot = kwargs.get('do_plot', ['aggregate'])
    accepted_options = ['aggregate', 'per_content', 'groundtruth_predicted_in_parallel']
    assert isinstance(do_plot, list), f"do_plot needs to be a list of plotting options. Accepted options are " \
                                      f"{accepted_options}"
    for option in do_plot:
        assert option in accepted_options, f"{option} is not in {accepted_options}"

    plot_aggregate = 'aggregate' in do_plot
    plot_per_content = 'per_content' in do_plot
    plot_parallel = 'groundtruth_predicted_in_parallel' in do_plot

    def draw_fit_line(target_ax, fit, color):
        # fit[0] carries the two parameters handed to linear_func; the line
        # is pinned at both ends of the x range.
        param_a, param_b = fit[0][0], fit[0][1]
        target_ax.axline(
            (xlim[0], linear_func(xlim[0], param_a, param_b)),
            (xlim[1], linear_func(xlim[1], param_a, param_b)),
            color=color, linestyle='--',
        )

    overall_linear_fit = None
    if plot_linear_fit:
        overall_linear_fit = linear_fit(ys_label, ys_label_pred)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        draw_fit_line(ax, overall_linear_fit, 'gray')
        ax.legend(['overall fit'])

    if plot_parallel:
        w, h = _get_plot_width_and_height(len(ys_label))
        fig_gt_pred, [ax_groundtruth, ax_predicted] = plt.subplots(figsize=[w, h * 2], ncols=1, nrows=2)
        ax_groundtruth.set_xlabel('Stimuli')
        ax_groundtruth.set_ylabel('True Score')
        ax_groundtruth.grid()
        # xlim is the range used for true scores, which sit on this plot's y axis
        ax_groundtruth.set_ylim(xlim)
        ax_predicted.set_xlabel('Stimuli')
        ax_predicted.set_ylabel('Predicted Score')
        ax_predicted.grid()
        ax_predicted.set_ylim(ylim)
    else:
        fig_gt_pred = None
        ax_groundtruth = None
        ax_predicted = None

    if content_ids is None:
        ax.errorbar(ys_label, ys_label_pred, xerr=1.96 * ys_label_stddev, marker='o', linestyle='')
    else:
        # De-duplicate in first-appearance order so colour and legend
        # assignment is reproducible across runs (set() order is not,
        # for string ids under hash randomisation).
        unique_content_ids = list(dict.fromkeys(content_ids))
        cmap = plt.get_cmap('jet')
        colors = [cmap(i) for i in np.linspace(0, 1, len(unique_content_ids))]

        # Converted once up front; only the per-content figures index into it.
        point_label_array = None
        if plot_per_content and point_labels is not None:
            point_label_array = np.array(point_labels)

        for idx, curr_content_id in enumerate(unique_content_ids):
            curr_idxs = indices(content_ids, lambda cid: cid == curr_content_id)
            curr_ys_label = ys_label[curr_idxs]
            curr_ys_label_pred = ys_label_pred[curr_idxs]
            curr_ys_label_stddev = ys_label_stddev[curr_idxs]
            curr_color = colors[idx % len(colors)]

            if plot_aggregate:
                ax.errorbar(curr_ys_label, curr_ys_label_pred,
                            xerr=1.96 * curr_ys_label_stddev,
                            marker='o', linestyle='', label=curr_content_id, color=curr_color)

            if plot_per_content:
                new_fig, new_ax = plt.subplots(1, 1, figsize=ax.figure.get_size_inches())
                new_ax.update_from(ax)
                new_ax.set_xlim(xlim)
                new_ax.set_ylim(ylim)
                if plot_linear_fit:
                    curr_linear_fit = linear_fit(curr_ys_label, curr_ys_label_pred)
                    draw_fit_line(new_ax, overall_linear_fit, 'gray')
                    draw_fit_line(new_ax, curr_linear_fit, 'red')
                    new_ax.legend(['overall fit', 'current fit'])
                new_ax.errorbar(
                    curr_ys_label, curr_ys_label_pred,
                    xerr=1.96 * curr_ys_label_stddev,
                    marker='o', linestyle='', label=curr_content_id, color=curr_color
                )
                new_ax.set_title(f'Content id {str(curr_content_id)}')
                new_ax.set_xlabel('True Score')
                new_ax.set_ylabel("Predicted Score")
                new_ax.grid()
                # Explicit None test: truthiness of a NumPy array is ambiguous.
                if point_label_array is not None:
                    curr_point_labels = point_label_array[curr_idxs]
                    assert len(curr_point_labels) == len(curr_ys_label)
                    for i, curr_point_label in enumerate(curr_point_labels):
                        new_ax.annotate(curr_point_label, (curr_ys_label[i], curr_ys_label_pred[i]))

            if plot_parallel:
                ax_groundtruth.plot(curr_idxs, curr_ys_label, '-^', color=curr_color,
                                    label=f'Content id {str(curr_content_id)}')
                ax_predicted.plot(curr_idxs, curr_ys_label_pred, '-^', color=curr_color,
                                  label=f'Content id {str(curr_content_id)}')

    if plot_aggregate and point_labels is not None:
        assert len(point_labels) == len(ys_label)
        for i, point_label in enumerate(point_labels):
            ax.annotate(point_label, (ys_label[i], ys_label_pred[i]))

    # need the following because ax is passed anyway; if not used for aggregate, close it
    # TODO: can be improved
    if not plot_aggregate:
        plt.close(ax.figure)

    if plot_parallel:
        ax_groundtruth.legend()
        fig_gt_pred.tight_layout()