def __heatmap_make()

in src/graphing/heatmap.py [0:0]


def __heatmap_make(data, row_labels, col_labels, ax=None, cbar_kw=None, cbarlabel="",
                   fig=None, threshold=0.5, **kwargs):
    """
    Create a heatmap from a numpy array and two lists of labels.

    Cells whose value is below ``threshold`` are drawn in white, and the
    colorbar gets a "min" extension showing that under-range colour.

    Parameters
    ----------
    data
        A 2D numpy array of shape (N, M).
    row_labels
        A list or array of length N with the labels for the rows.
    col_labels
        A list or array of length M with the labels for the columns.
    ax
        A `matplotlib.axes.Axes` instance to which the heatmap is plotted.  If
        not provided, use current axes or create a new one.  Optional.
    cbar_kw
        A dictionary with arguments to `matplotlib.Figure.colorbar`.  Its
        entries override the default ``extend="min"``.  Optional.
    cbarlabel
        The label for the colorbar.  Optional.
    fig
        The figure used to create the colorbar.  Defaults to the figure that
        owns ``ax``.  Optional.
    threshold
        Lower colour limit (``vmin``) passed to `imshow`; cells below it are
        painted with the colormap's "under" colour (white).  Defaults to 0.5.
    **kwargs
        All other arguments are forwarded to `imshow`.

    Returns
    -------
    matplotlib.image.AxesImage
        The image returned by `imshow`.
    """
    import copy  # local: only needed to copy the colormap below

    # asanyarray keeps ndarray subclasses (e.g. masked arrays) intact while
    # also accepting nested lists, which have no ``.shape``.
    data = np.asanyarray(data)
    if ax is None:
        ax = plt.gca()
    if fig is None:
        fig = ax.figure

    # Plot the heatmap; values below ``threshold`` fall under the norm range.
    im = ax.imshow(data, vmin=threshold, **kwargs)

    # Paint under-range cells white.  Work on a copy so a globally registered
    # colormap is never mutated for other plots, and apply it before the
    # colorbar is built so the "min" extension uses the same colour.
    cmap = copy.copy(im.get_cmap())
    cmap.set_under("white")
    im.set_cmap(cmap)

    # Create the colorbar in a slim axes appended to the right of ``ax``.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.35)
    colorbar_kw = {"extend": "min"}
    colorbar_kw.update(cbar_kw or {})
    cbar = fig.colorbar(im, cax=cax, **colorbar_kw)

    # Label the colorbar's lower and upper limits explicitly.  Locator ticks
    # outside the visible range are dropped so the list stays ordered, and
    # isclose avoids near-duplicate ticks from float round-off.
    lo, hi = sorted(cbar.ax.get_ylim())
    ticks = [t for t in cbar.ax.get_yticks() if lo <= t <= hi]
    if not ticks or not np.isclose(ticks[0], lo):
        ticks.insert(0, lo)
    if not np.isclose(ticks[-1], hi):
        ticks.append(hi)
    cbar.set_ticks(ticks)

    cax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    # Show a tick for every column and row...
    ax.set_xticks(np.arange(data.shape[1]))
    ax.set_yticks(np.arange(data.shape[0]))
    # ... and label them with the respective list entries.
    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)

    # Keep the column labels along the bottom edge, not the top.
    ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)

    # Rotate the column labels by -30 degrees, left-aligned at the tick anchor.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="left", rotation_mode="anchor")

    # Turn spines off and draw a white grid on the cell boundaries, using
    # minor ticks at half-integer positions (their tick marks are hidden).
    ax.spines[:].set_visible(False)
    ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True)
    ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True)
    ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im