in graspologic/plot/plot_matrix.py [0:0]
def draw_colors(ax, ax_type="x", meta=None, divider=None, color=None, palette="tab10"):
    """
    Draw a strip of class colors onto an axis to separate the data.

    The class labels are read from ``meta[color]``; each distinct label is
    assigned one color, and the labels are rendered as a one-cell-thick
    heatmap (a single row for ``ax_type="x"``, a single column for
    ``ax_type="y"``).

    Parameters
    ----------
    ax : matplotlib axes object
        Axes in which to draw the color strip.
    ax_type : {"x", "y"}, optional
        Which axis the strip annotates, by default "x".
    meta : pd.DataFrame, pd.Series, list of pd.Series or np.array, optional
        Metadata of the matrix such as class, cell type, etc. It is indexed
        as ``meta[color]``, so although the default is None a real value
        must be supplied.
    divider : AxesLocator, optional
        Divider used to add new axes to the plot. Currently unused by this
        function.
    color : str, list of str, or array_like, optional
        Key used to look up the class labels as ``meta[color]``; it is also
        used as the axis label. Although the default is None, a real value
        must be supplied.
    palette : str or dict, optional
        If a str, it is passed to ``sns.color_palette`` together with the
        number of distinct classes. If a dict, it must map every distinct
        class label to a color. By default "tab10".

    Returns
    -------
    ax : matplotlib axes object
        The axes the color strip was drawn in.

    Raises
    ------
    ValueError
        If ``ax_type`` is not "x" or "y", or if a dict ``palette`` has no
        entry for one or more of the classes.
    TypeError
        If ``palette`` is neither a str nor a dict.
    """
    # Validate before doing any work so a bad argument fails with a clear
    # message instead of an obscure error from deep inside the plotting call.
    if ax_type not in ("x", "y"):
        raise ValueError(f"ax_type must be 'x' or 'y', got {ax_type!r}")

    classes = meta[color]
    n_samples = len(classes)

    # `uni_classes` is sorted; `class_indicator` holds, for every sample, the
    # position of its class in `uni_classes` (also the index into the colormap).
    uni_classes, class_indicator = np.unique(classes, return_inverse=True)
    class_labels = uni_classes.tolist()

    # Build one color per class, ordered like `uni_classes`.
    if isinstance(palette, dict):
        missing = [label for label in class_labels if label not in palette]
        if missing:
            raise ValueError(f"palette is missing colors for classes: {missing}")
        # A plain list keeps each color (name or RGB(A) tuple) intact as one
        # entry. Stacking tuples into an array and guessing its orientation
        # from its length breaks when the number of classes equals the
        # number of color channels.
        class_colors = [palette[label] for label in class_labels]
    elif isinstance(palette, str):
        class_colors = list(sns.color_palette(palette, len(class_labels)))
    else:
        raise TypeError(
            f"palette must be a str or dict, got {type(palette).__name__}"
        )

    lc = ListedColormap(class_colors)

    # A single row annotates the x axis; a single column annotates the y axis.
    strip_shape = (1, n_samples) if ax_type == "x" else (n_samples, 1)
    class_indicator = np.asarray(class_indicator).reshape(strip_shape)

    sns.heatmap(
        class_indicator,
        cmap=lc,
        cbar=False,
        yticklabels=False,
        xticklabels=False,
        ax=ax,
        square=False,
    )

    if ax_type == "x":
        ax.set_xlabel(color, fontsize=20)
        ax.xaxis.set_label_position("top")
    else:
        ax.set_ylabel(color, fontsize=20)
    return ax