def draw_colors()

in graspologic/plot/plot_matrix.py [0:0]


def draw_colors(ax, ax_type="x", meta=None, divider=None, color=None, palette="tab10"):
    r"""
    Draw colormap onto the axis to separate the data

    Each entry of ``meta[color]`` becomes one cell of a single-row
    (``ax_type="x"``) or single-column (``ax_type="y"``) heatmap, colored by
    its class label.

    Parameters
    ----------
    ax : matplotlib axes object
        Axes in which to draw the colormap
    ax_type : {"x", "y"}, optional
        Draw a horizontal strip ("x") or a vertical strip ("y"), by default "x"
    meta : pd.DataFrame, pd.Series, list of pd.Series or np.array, optional
        Metadata of the matrix such as class, cell type, etc., by default None.
        Must be indexable by ``color``.
    divider : AxesLocator, optional
        Divider used to add new axes to the plot. Not used by this function;
        accepted for interface compatibility.
    color : str, list of str, or array_like, optional
        Attribute in meta by which to draw colorbars, by default None
    palette : str, dict, or list of colors, optional
        Colormap of the colorbar, by default "tab10". A dict maps each class
        label to a color; any other value is passed to ``sns.color_palette``.

    Returns
    -------
    ax : matplotlib axes object
        Axes in which to draw the color map

    Raises
    ------
    ValueError
        If ``ax_type`` is not "x" or "y", or if a dict ``palette`` has no
        color for one of the classes.
    """
    if ax_type not in ("x", "y"):
        raise ValueError(f'ax_type must be "x" or "y", got {ax_type!r}')

    classes = meta[color]
    # uni_classes is sorted; class_indicator[i] is the position of classes[i]
    # in uni_classes, i.e. the colormap index used for that entry.
    uni_classes, class_indicator = np.unique(classes, return_inverse=True)

    # Create the color dictionary
    if isinstance(palette, dict):
        color_dict = palette
    else:
        # seaborn accepts a palette name, a list of colors, or None here.
        color_dict = dict(
            zip(uni_classes, sns.color_palette(palette, len(uni_classes)))
        )

    missing = [c for c in uni_classes if c not in color_dict]
    if missing:
        raise ValueError(f"palette has no color for class(es): {missing}")

    # Build the colors as a plain list in class order. np.vectorize over a
    # function returning RGB(A) tuples yields channel-major output, and the
    # previous "transpose if lengths differ" heuristic silently mis-oriented
    # it whenever the number of classes equalled the number of channels.
    lc = ListedColormap([color_dict[c] for c in uni_classes])

    if ax_type == "x":
        class_indicator = class_indicator.reshape(1, len(classes))
    else:
        class_indicator = class_indicator.reshape(len(classes), 1)
    sns.heatmap(
        class_indicator,
        cmap=lc,
        cbar=False,
        yticklabels=False,
        xticklabels=False,
        ax=ax,
        square=False,
    )
    if ax_type == "x":
        ax.set_xlabel(color, fontsize=20)
        ax.xaxis.set_label_position("top")
    else:
        ax.set_ylabel(color, fontsize=20)
    return ax