def plot_summary_of_informativeness()

in sample_info/modules/visualizations.py [0:0]


def _plot_example_row(fig, gs, row_slice, indices, data, panel_label, label_dy):
    """Draw one row of example images, one per grid-spec column starting at column 3.

    :param fig: matplotlib figure owning ``gs``.
    :param gs: grid spec created by ``fig.add_gridspec``.
    :param row_slice: slice of grid-spec rows the image row occupies.
    :param indices: dataset indices of the examples to show, drawn left to right.
    :param data: indexable dataset returning ``(x, y)`` pairs.
    :param panel_label: bold letter drawn next to the first image of the row.
    :param label_dy: vertical offset (figure coordinates) of the panel letter
        above the first image's top edge.
    """
    for col, idx in enumerate(indices):
        ax = fig.add_subplot(gs[row_slice, 3 + col])
        if col == 0:
            # Place the panel letter just up-left of the first image in the row.
            pos = ax.get_position()
            fig.text(pos.x0 - 0.05, pos.y1 + label_dy, panel_label, size=28, weight='bold')

        x, _ = data[idx]
        # NOTE(review): revert_normalization(...)[0] presumably returns a batch;
        # the first element is the single un-normalized example — confirm.
        x = revert_normalization(x, data)[0]
        x = utils.to_numpy(x)
        x = get_image(x)
        ax.imshow(x, vmin=0, vmax=1)
        ax.set_axis_off()


def plot_summary_of_informativeness(data, informativeness_scores, label_names=None, save_name=None,
                                    plt=None, is_label_one_hot=False, **kwargs):
    """Plot a three-panel summary of per-example informativeness scores.

    Panel A shows per-class histograms of the scores. Panel B shows the (up to)
    10 least informative examples. Panel C shows the (up to) 10 most informative
    examples.

    :param data: indexable dataset of ``(x, y)`` pairs. It is also passed to
        ``revert_normalization`` before the images are displayed.
    :param informativeness_scores: np.ndarray of informativeness scores, one
        per example of ``data`` and in the same order. It is converted with
        ``convert_scores_to_numpy``.
    :param label_names: optional sequence indexed by integer label, used for
        legend entries. When None, the integer label itself is shown.
    :param save_name: if not None, the figure is saved via ``savefig``.
    :param plt: optional ``matplotlib.pyplot`` module. When None, it is obtained
        from ``import_matplotlib(agg=True, use_style=False)``.
    :param is_label_one_hot: if True, labels are converted to class indices
        with ``argmax``.
    :param kwargs: accepted for interface compatibility; unused.
    :return: ``(fig, None)``.
    """
    if plt is None:
        _, plt = import_matplotlib(agg=True, use_style=False)

    informativeness_scores = convert_scores_to_numpy(informativeness_scores)

    fig = plt.figure(constrained_layout=True, figsize=(24, 5))
    # 22 x 13 grid: columns 0-2 hold the histogram, columns 3-12 hold the images.
    gs = fig.add_gridspec(22, 13)
    ax_left = fig.add_subplot(gs[1:22, :3])

    # as_tensor avoids the copy (and warning) that torch.tensor makes on tensors.
    ys = [torch.as_tensor(y) for _, y in data]
    if is_label_one_hot:
        ys = [torch.argmax(y) for y in ys]
    ys = np.array([y.item() for y in ys])

    for y in sorted(set(ys)):
        mask = (ys == y)
        label = str(y) if label_names is None else label_names[y]
        ax_left.hist(informativeness_scores[mask], bins=30, label=label, alpha=0.6)
    ax_left.legend()
    ax_left.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
    ax_left.set_xlabel('Informativeness of an example')
    ax_left.set_ylabel('Count')
    left_pos = ax_left.get_position()
    fig.text(left_pos.x0 - 0.05, left_pos.y1 + 0.065, 'A', size=28, weight='bold')

    order = np.argsort(informativeness_scores)
    # Clamp so datasets with fewer than 10 examples do not raise IndexError.
    num_shown = min(10, len(order))
    least_informative = order[:num_shown]
    most_informative = order[len(order) - num_shown:]

    _plot_example_row(fig, gs, slice(1, 11), least_informative, data, 'B', 0.065)
    _plot_example_row(fig, gs, slice(12, 22), most_informative, data, 'C', 0.042)

    if save_name is not None:
        savefig(fig, save_name)

    return fig, None