in src/beanmachine/tutorials/utils/baseball.py [0:0]
def plot_partial_pooling_model(samples: MonteCarloSamples, df: pd.DataFrame) -> Figure:
"""
Partial-pooling model plot.
:param samples: Bean Machine inference object.
:type samples: MonteCarloSamples
:param df: Dataframe of the model data.
:type df: pd.DataFrame
:return: Bokeh figure of the partial-pooling model.
:rtype: Figure
"""
# Prepare data for the figure.
diagnostics_data = _sample_data_prep(samples, df)
hdi_df = az.hdi(diagnostics_data, hdi_prob=0.89).to_dataframe()
hdi_df = hdi_df.T.rename(columns={"lower": "hdi_11%", "higher": "hdi_89%"})
summary_df = az.summary(diagnostics_data, round_to=4).join(hdi_df)
theta_index = summary_df[
summary_df.index.astype(str).str.contains("θ")
].index.values
x = (df["Current hits"] / df["Current at-bats"]).values
y = summary_df.loc[theta_index, "mean"]
upper_hdi = summary_df.loc[theta_index, "hdi_89%"]
lower_hdi = summary_df.loc[theta_index, "hdi_11%"]
population_mean = (df["Current hits"] / df["Current at-bats"]).mean()
# Create the figure data source.
source = ColumnDataSource(
{
"x": x,
"y": y,
"upper_hdi": upper_hdi,
"lower_hdi": lower_hdi,
"name": df["Name"].values,
}
)
# Create the figure.
plot = figure(
plot_width=500,
plot_height=500,
title="Partial pooling",
x_axis_label="Observed hits / at-bats",
y_axis_label="Predicted chance of a hit",
x_range=[0.14, 0.41],
y_range=[0.05, 0.55],
)
# Add the empirical mean at-bat hit chance to the figure.
plot.line(
x=[0, 1],
y=[population_mean, population_mean],
line_color="orange",
line_width=3,
level="underlay",
legend_label="Population mean",
)
# Add the standard deviation of the mean at-bat hit chance to the figure.
std_band = Band(
base="x",
lower="lower_std",
upper="upper_std",
source=source,
level="underlay",
fill_alpha=0.2,
fill_color="orange",
line_width=0.2,
line_color="orange",
)
plot.add_layout(std_band)
# Add the empirical at-bat hit chance to the figure.
plot.line(
x=x,
y=(df["Current hits"] / df["Current at-bats"]).values,
line_color="grey",
line_alpha=0.7,
line_width=2.0,
legend_label="Current hits / Current at-bats",
)
# Add the HDI whiskers to the figure.
whiskers = Whisker(
base="x",
upper="upper_hdi",
lower="lower_hdi",
source=source,
line_color="steelblue",
)
whiskers.upper_head.line_color = "steelblue"
whiskers.lower_head.line_color = "steelblue"
plot.add_layout(whiskers)
# Add the partial-pooling model data to the figure.
glyph = plot.circle(
x="x",
y="y",
source=source,
size=10,
line_color="white",
fill_color="steelblue",
legend_label="Players",
)
tooltips = HoverTool(
renderers=[glyph],
tooltips=[
("Name", "@name"),
("Posterior Upper HDI", "@upper_hdi{0.000}"),
("Posterior Mode", "@mode{0.000}"),
("Posterior Lower HDI", "@lower_hdi{0.000}"),
],
)
plot.add_tools(tooltips)
# Add a legend to the figure.
plot.legend.location = "top_left"
plot.legend.click_policy = "mute"
# Style the figure.
plots.style(plot)
return plot