in msticpy/vis/matrix_plot.py [0:0]
def plot_matrix(data: pd.DataFrame, **kwargs) -> LayoutDOM:
    """
    Plot data as an intersection matrix.

    Parameters
    ----------
    data : pd.DataFrame
        The data to plot.
    x : str
        Column to plot on the x (horizontal) axis
    x_col : str
        Alias for 'x'
    y : str
        Column to plot on the y (vertical) axis
    y_col : str
        Alias for 'y'
    title : str, optional
        Custom title, default is 'Intersection plot'
    value_col : str, optional
        Column from the DataFrame used to size the intersection points.
    dist_count : bool, optional
        Calculates a count of distinct values (from `value_col`) and uses
        this to size the intersection points.
        Requires `value_col` to be specified.
    log_size : bool, optional
        Takes the log of the size value before calculating the intersection
        display point size.
        Can be combined with `invert`.
    invert : bool, optional
        Takes the inverse of the size value as the basis for calculating
        the intersection display point size. This is useful for highlighting
        rare interactions.
        Can be combined with `log_size`.
    intersect : bool, optional
        Plots points of a fixed size, rather than using a sizing value. This
        is useful for just showing the presence/absence of an interaction.
    height : int, optional
        The plot height. Default is 700
    width : int, optional
        The plot width. Default is 900
    color : str, optional
        The color of the plotted points, default is "red"
    sort : Union[str, bool], optional
        Sorts the labels of both axes, default is None.
        Acceptable values are:
        'asc' (or string starting with 'asc') - Sort ascending
        'desc' (or string starting with 'desc') - Sort descending
        False or None (no sort)
        True - Sort ascending
    sort_x : Union[str, bool], optional
        Sorts the labels of the x axis (takes precedence over `sort`),
        default is None.
        Acceptable values are:
        'asc' (or string starting with 'asc') - Sort ascending
        'desc' (or string starting with 'desc') - Sort descending
        False (no sort, even if `sort` is set)
        None (use the value of `sort`)
        True - Sort ascending
    sort_y : Union[str, bool], optional
        Sorts the labels of the y axis (takes precedence over `sort`),
        default is None.
        Acceptable values are:
        'asc' (or string starting with 'asc') - Sort ascending
        'desc' (or string starting with 'desc') - Sort descending
        False (no sort, even if `sort` is set)
        None (use the value of `sort`)
        True - Sort ascending
    hide : bool, optional
        Creates and returns but does not display the plot, default
        is False.
    font_size : int, optional
        Manually specify the font size for axis labels, in points,
        the default is to automatically calculate a size based on the
        number of items in each axis.
    max_label_font_size : int, optional
        The maximum size, in points, of the X and Y labels, default is 11.

    Returns
    -------
    LayoutDOM
        The Bokeh plot

    Raises
    ------
    ValueError
        If the x or y column parameter is not supplied.

    Notes
    -----
    Bokeh output is reset and redirected to the notebook on every call,
    including when `hide` is True.

    """
    # Process/extract parameters
    check_kwargs(kwargs, PlotParams.field_list())
    param = PlotParams(**kwargs)

    if not param.x_column or not param.y_column:
        raise ValueError("Must supply `x` and `y` column parameters.")

    reset_output()
    output_notebook()

    plot_data = _prep_data(data, param)

    # An explicit per-axis setting (including False, i.e. "do not sort") wins;
    # only fall back to the shared `sort` value when the axis option is unset.
    x_sort = param.sort if param.sort_x is None else param.sort_x
    y_sort = param.sort if param.sort_y is None else param.sort_y
    x_range = _sort_labels(plot_data, param.x_column, x_sort)
    y_range = _sort_labels(plot_data, param.y_column, y_sort, invert=True)

    # Rescale the size so that it matches the graph
    # (the largest point is drawn with a size of 10).
    max_size = plot_data["size"].max()
    plot_data["plt_size"] = plot_data["size"] * 10 / max_size
    source = ColumnDataSource(data=plot_data)

    plot = figure(
        title=param.title,
        plot_width=param.width,
        plot_height=param.height,
        x_range=x_range,
        y_range=y_range,
        tools=["wheel_zoom", "box_zoom", "pan", "reset", "save"],
        toolbar_location="above",
    )

    # Wrap field names in braces ("@{name}") so that column names containing
    # spaces or other non-identifier characters resolve in the hover tooltip.
    tool_tips = [
        (param.x_column, f"@{{{param.x_column}}}"),
        (param.y_column, f"@{{{param.y_column}}}"),
        ("value", "@size"),  # the unscaled size value, not "plt_size"
    ]
    plot.add_tools(HoverTool(tooltips=tool_tips))

    if param.intersect:
        # Fixed-size markers: only presence/absence of an intersection is shown.
        plot.circle_cross(
            x=param.x_column,
            y=param.y_column,
            source=source,
            fill_alpha=0.6,
            line_color=param.color,
            size=5,
        )
    else:
        # Marker size is driven by the rescaled "plt_size" column.
        plot.circle(
            x=param.x_column,
            y=param.y_column,
            source=source,
            fill_alpha=0.6,
            fill_color=param.color,
            size="plt_size",
        )

    _set_plot_params(plot)

    def _auto_font_size(extent: int, column: str) -> int:
        """Return a label size (pt) scaled to the number of labels on an axis."""
        # Guard against an empty column, which would otherwise divide by zero.
        label_count = max(1, plot_data[column].nunique())
        return max(
            5,
            min(param.max_label_font_size, int(extent * 0.6 / label_count)),
        )

    # Calculate appropriate font size for labels
    # (an explicit `font_size` overrides the automatic calculation).
    x_label_pt_size = param.font_size or _auto_font_size(
        param.width, param.x_column
    )
    y_label_pt_size = param.font_size or _auto_font_size(
        param.height, param.y_column
    )
    plot.xaxis.major_label_text_font_size = f"{x_label_pt_size}pt"
    plot.yaxis.major_label_text_font_size = f"{y_label_pt_size}pt"
    plot.xaxis.axis_label = param.x_column
    plot.yaxis.axis_label = param.y_column

    if not param.hide:
        show(plot)
    return plot