in src/graphing/heatmap.py [0:0]
def __heatmap_make(data, row_labels, col_labels, ax=None, cbar_kw=None, cbarlabel="", fig=None, **kwargs):
    """
    Create a heatmap from a numpy array and two lists of labels.

    Cells whose value is below the colour-scale minimum (``vmin``, 0.5 unless
    overridden) are drawn in the colormap's "under" colour (white), and the
    colorbar gets a "min" extension to represent them.

    Parameters
    ----------
    data
        A 2D numpy array (or array-like) of shape (N, M).
    row_labels
        A list or array of length N with the labels for the rows.
    col_labels
        A list or array of length M with the labels for the columns.
    ax
        A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If
        not provided, use current axes or create a new one. Optional.
    cbar_kw
        A dictionary with arguments to `matplotlib.Figure.colorbar`. Optional.
        ``extend`` defaults to ``"min"`` unless given here.
    cbarlabel
        The label for the colorbar. Optional.
    fig
        Currently unused; accepted for backward compatibility.
    **kwargs
        All other arguments are forwarded to `imshow`. ``vmin`` defaults to
        0.5 and ``cmap`` defaults to ``"plasma"``.

    Returns
    -------
    matplotlib.image.AxesImage
        The image returned by `imshow`.
    """
    import copy  # local import: used to avoid mutating a shared/registered colormap

    if ax is None:
        ax = plt.gca()
    # Copy so neither a shared default nor the caller's dict is mutated.
    cbar_kw = {} if cbar_kw is None else dict(cbar_kw)

    # Values below vmin fall "under" the colour scale and are painted with the
    # colormap's under colour. Pop these from kwargs so a caller-supplied value
    # overrides the default instead of raising a duplicate-keyword TypeError.
    vmin = kwargs.pop("vmin", 0.5)
    # plt.get_cmap accepts a name, a Colormap instance, or None (rc default).
    # Copy it so set_under does not alter the globally registered colormap.
    cmap = copy.copy(plt.get_cmap(kwargs.pop("cmap", "plasma")))
    cmap.set_under("white")

    # Plot the heatmap
    im = ax.imshow(data, vmin=vmin, cmap=cmap, **kwargs)

    # Create the colorbar in an axes appended to the right of the heatmap.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.35)
    cbar_kw.setdefault("extend", "min")
    # Use the heatmap's own figure rather than pyplot's "current" figure.
    cbar = ax.figure.colorbar(im, cax=cax, **cbar_kw)

    # Add minimum (0) and maximum (top of the colorbar axis) to the tick labels.
    yticks = list(cbar.ax.get_yticks())
    if not yticks or yticks[0] != 0:
        yticks.insert(0, 0)
    top = cbar.ax.get_ylim()[1]
    if yticks[-1] != top:
        yticks.append(top)
    cbar.set_ticks(yticks)
    cax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    n_rows, n_cols = np.shape(data)[:2]

    # Show one major tick per cell and label it with the respective entry.
    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)

    # Keep the column labels below the heatmap (not on top).
    ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)

    # Rotate the column tick labels and anchor them so they do not overlap.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="left", rotation_mode="anchor")

    # Turn spines off and draw a white grid between cells using minor ticks
    # placed on the cell boundaries (offset by half a cell).
    ax.spines[:].set_visible(False)
    ax.set_xticks(np.arange(n_cols + 1) - 0.5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - 0.5, minor=True)
    ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im