def plot_matrix()

in msticpy/vis/matrix_plot.py [0:0]


def plot_matrix(data: pd.DataFrame, **kwargs) -> LayoutDOM:
    """
    Plot data as an intersection matrix.

    Parameters
    ----------
    data : pd.DataFrame
        The data to plot.
    x : str
        Column to plot on the x (horizontal) axis
    x_col : str
        Alias for 'x'
    y : str
        Column to plot on the y (vertical) axis
    y_col : str
        Alias for 'y'
    title : str, optional
        Custom title, default is 'Intersection plot'
    value_col : str, optional
        Column from the DataFrame used to size the intersection points.
    dist_count : bool, optional
        Calculates a count of distinct values (from `value_col`) and uses
        this to size the intersection points.
        Requires `value_col` to be specified.
    log_size : bool, optional
        Takes the log of the size value before calculating the intersection
        display point size.
        Can be combined with `invert`.
    invert : bool, optional
        Takes the inverse of the size value as the basis for calculating
        the intersection display point size. This is useful for highlighting
        rare interactions.
        Can be combined with `log_size`.
    intersect : bool, optional
        Plots points of a fixed size, rather than using a sizing value. This
        is useful for just showing the presence/absence of an interaction.
    height : int, optional
        The plot height. Default is 700
    width : int, optional
        The plot width. Default is 900
    color : str, optional
        The color of the plotted points, default is "red"
    sort : Union[str, bool], optional
        Sorts the labels of both axes, default is None.
        Acceptable values are:
        'asc' (or string starting with 'asc') - Sort ascending
        'desc' (or string starting with 'desc') - Sort descending
        False or None (no sort)
        True  - Sort ascending
    sort_x : Union[str, bool], optional
        Sorts the labels of the x axis (takes precedence over `sort`),
        default is None.
        Acceptable values are:
        'asc' (or string starting with 'asc') - Sort ascending
        'desc' (or string starting with 'desc') - Sort descending
        False or None (no sort)
        True  - Sort ascending
    sort_y : Union[str, bool], optional
        Sorts the labels of the y axis (takes precedence over `sort`),
        default is None.
        Acceptable values are:
        'asc' (or string starting with 'asc') - Sort ascending
        'desc' (or string starting with 'desc') - Sort descending
        False or None (no sort)
        True  - Sort ascending
    hide : bool, optional
        Creates and returns but does not display the plot, default
        is False.
    font_size : int, optional
        Manually specify the font size for axis labels, in points,
        the default is to automatically calculate a size based on the
        number of items in each axis.
    max_label_font_size : int, optional
        The maximum size, in points, of the X and Y labels, default is 11.

    Returns
    -------
    LayoutDOM
        The Bokeh plot

    Raises
    ------
    ValueError
        If the `x` and `y` column parameters are not both supplied.

    """
    # Process/extract parameters (unknown keyword names are checked
    # against the PlotParams field list before construction).
    check_kwargs(kwargs, PlotParams.field_list())
    param = PlotParams(**kwargs)

    if not param.x_column or not param.y_column:
        raise ValueError("Must supply `x` and `y` column parameters.")

    reset_output()
    output_notebook()

    plot_data = _prep_data(data, param)

    # Axis-specific sort options take precedence over the shared `sort`.
    x_range = _sort_labels(plot_data, param.x_column, param.sort_x or param.sort)
    y_range = _sort_labels(
        plot_data, param.y_column, param.sort_y or param.sort, invert=True
    )

    # Rescale the size so that the largest value maps to a point size of 10.
    max_size = plot_data["size"].max()
    plot_data["plt_size"] = plot_data["size"] * 10 / max_size
    source = ColumnDataSource(data=plot_data)

    plot = figure(
        title=param.title,
        plot_width=param.width,
        plot_height=param.height,
        x_range=x_range,
        y_range=y_range,
        tools=["wheel_zoom", "box_zoom", "pan", "reset", "save"],
        toolbar_location="above",
    )

    tool_tips = [
        (param.x_column, f"@{param.x_column}"),
        (param.y_column, f"@{param.y_column}"),
        ("value", "@size"),
    ]
    plot.add_tools(HoverTool(tooltips=tool_tips))

    if param.intersect:
        # Fixed-size markers: only presence/absence of an interaction matters.
        plot.circle_cross(
            x=param.x_column,
            y=param.y_column,
            source=source,
            fill_alpha=0.6,
            line_color=param.color,
            size=5,
        )
    else:
        plot.circle(
            x=param.x_column,
            y=param.y_column,
            source=source,
            fill_alpha=0.6,
            fill_color=param.color,
            size="plt_size",
        )
    _set_plot_params(plot)

    # Use the explicit font size if given, otherwise fit the labels to the
    # axis length (`or` short-circuits, so the auto size is only computed
    # when needed).
    x_label_pt_size = param.font_size or _calc_label_font_size(
        param.width,
        plot_data[param.x_column].nunique(),
        param.max_label_font_size,
    )
    y_label_pt_size = param.font_size or _calc_label_font_size(
        param.height,
        plot_data[param.y_column].nunique(),
        param.max_label_font_size,
    )
    plot.xaxis.major_label_text_font_size = f"{x_label_pt_size}pt"
    plot.yaxis.major_label_text_font_size = f"{y_label_pt_size}pt"
    plot.xaxis.axis_label = param.x_column
    plot.yaxis.axis_label = param.y_column

    if not param.hide:
        show(plot)
    return plot


def _calc_label_font_size(axis_length: int, n_labels: int, max_font_size: int) -> int:
    """
    Return an axis label font size (in points) that fits `n_labels` labels.

    The size is 60% of the axis length divided by the number of labels,
    clamped to the range [5, `max_font_size`].

    Parameters
    ----------
    axis_length : int
        Length of the axis (the plot width or height).
    n_labels : int
        Number of distinct labels on the axis.
    max_font_size : int
        Upper bound for the returned size.

    Returns
    -------
    int
        The font size in points.

    """
    # `nunique()` is 0 for empty data or an all-NaN column (NaN is dropped
    # by default); treat that as one label to avoid ZeroDivisionError.
    n_labels = max(n_labels, 1)
    return max(5, min(max_font_size, int(axis_length * 0.6 / n_labels)))