in sample_info/modules/visualizations.py [0:0]
def _plot_example_row(fig, gs, data, indices, rows, panel_label, label_y_offset):
    """Draw one strip of example images, one per gridspec column starting at column 3.

    :param fig: matplotlib figure that owns ``gs``
    :param gs: gridspec created by ``fig.add_gridspec``; image ``k`` goes in ``gs[rows, 3 + k]``
    :param data: indexable dataset whose items unpack as ``(x, y)``
    :param indices: dataset indices of the examples to draw (may be fewer than the number of columns)
    :param rows: row slice of ``gs`` occupied by this strip
    :param panel_label: bold letter drawn to the upper-left of the first image
    :param label_y_offset: vertical offset (figure coordinates) of the panel letter above the first image
    """
    for col, idx in enumerate(indices):
        ax = fig.add_subplot(gs[rows, 3 + col])
        if col == 0:
            pos = ax.get_position()
            fig.text(pos.x0 - 0.05, pos.y1 + label_y_offset, panel_label, size=28, weight='bold')
        x, _ = data[idx]
        # Undo dataset normalization and convert to an imshow-ready array.
        x = revert_normalization(x, data)[0]
        x = utils.to_numpy(x)
        x = get_image(x)
        ax.imshow(x, vmin=0, vmax=1)
        ax.set_axis_off()


def plot_summary_of_informativeness(data, informativeness_scores, label_names=None, save_name=None,
                                    plt=None, is_label_one_hot=False, **kwargs):
    """Plot per-class informativeness histograms next to the least/most informative examples.

    Panel A is an overlaid histogram of informativeness scores, one per class label.
    Panel B shows the (up to) 10 examples with the lowest scores.
    Panel C shows the (up to) 10 examples with the highest scores.

    :param data: dataset that is both iterable and indexable, yielding ``(x, y)`` pairs
    :param informativeness_scores: informativeness scores, one per example in ``data`` (same order);
        converted with ``convert_scores_to_numpy``
    :param label_names: optional sequence mapping a class index to a display name for the legend
    :param save_name: if given, the figure is saved via ``savefig(fig, save_name)``
    :param plt: matplotlib.pyplot-like module; imported with the Agg backend when ``None``
    :param is_label_one_hot: if True, labels are one-hot vectors and are reduced with ``argmax``
    :param kwargs: accepted for interface compatibility; unused
    :return: ``(fig, None)``
    """
    if plt is None:
        _, plt = import_matplotlib(agg=True, use_style=False)
    informativeness_scores = convert_scores_to_numpy(informativeness_scores)

    n_examples = 10  # images per strip in panels B and C
    fig = plt.figure(constrained_layout=True, figsize=(24, 5))
    # Columns 0-2 hold panel A; columns 3..3+n_examples-1 hold one image each for panels B/C.
    gs = fig.add_gridspec(22, 3 + n_examples)

    # Panel A: per-class histogram of scores.
    ax_left = fig.add_subplot(gs[1:22, :3])
    # as_tensor avoids copying (and the copy-construct warning) when y is already a tensor.
    ys = [torch.as_tensor(y) for _, y in data]
    if is_label_one_hot:
        ys = [torch.argmax(y) for y in ys]
    ys = np.array([y.item() for y in ys])
    for y in np.unique(ys):
        mask = (ys == y)
        label = str(y) if label_names is None else label_names[y]
        ax_left.hist(informativeness_scores[mask], bins=30, label=label, alpha=0.6)
    ax_left.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
    ax_left.set_xlabel('Informativeness of an example')
    ax_left.set_ylabel('Count')
    ax_left.legend()
    left_pos = ax_left.get_position()
    fig.text(left_pos.x0 - 0.05, left_pos.y1 + 0.065, 'A', size=28, weight='bold')

    # Panels B and C: extreme examples. Slicing yields at most n_examples indices, and fewer
    # when the dataset is small, so the helper iterates over exactly what is available.
    order = np.argsort(informativeness_scores)
    least_informative = order[:n_examples]
    most_informative = order[-n_examples:]
    _plot_example_row(fig, gs, data, least_informative, slice(1, 11), 'B', 0.065)
    _plot_example_row(fig, gs, data, most_informative, slice(12, 22), 'C', 0.042)

    if save_name is not None:
        savefig(fig, save_name)
    return fig, None