def plot_partial_pooling_model()

in src/beanmachine/tutorials/utils/baseball.py [0:0]


def plot_partial_pooling_model(samples: MonteCarloSamples, df: pd.DataFrame) -> Figure:
    """
    Partial-pooling model plot.

    Plots, for every player, the posterior mean of the ``θ`` variables
    (with an 89% HDI whisker) against the observed ratio of
    ``Current hits`` to ``Current at-bats``. The population mean of that
    observed ratio is drawn as a horizontal line, with a band of one
    standard deviation on either side of it.

    :param samples: Bean Machine inference object.
    :type samples: MonteCarloSamples
    :param df: Dataframe of the model data. The columns ``Current hits``,
        ``Current at-bats`` and ``Name`` are read from it.
    :type df: pd.DataFrame
    :return: Bokeh figure of the partial-pooling model.
    :rtype: Figure
    """
    # Prepare data for the figure.
    diagnostics_data = _sample_data_prep(samples, df)
    hdi_df = az.hdi(diagnostics_data, hdi_prob=0.89).to_dataframe()
    hdi_df = hdi_df.T.rename(columns={"lower": "hdi_11%", "higher": "hdi_89%"})
    summary_df = az.summary(diagnostics_data, round_to=4).join(hdi_df)
    # Keep only the summary rows belonging to the per-player θ variables.
    # NOTE(review): assumes these rows are in the same order as the rows of
    # ``df`` so that they line up with ``x`` and ``name`` -- confirm against
    # ``_sample_data_prep``.
    theta_index = summary_df[
        summary_df.index.astype(str).str.contains("θ")
    ].index.values
    # Observed batting average; computed once and reused below.
    batting_average = df["Current hits"] / df["Current at-bats"]
    x = batting_average.values
    y = summary_df.loc[theta_index, "mean"]
    upper_hdi = summary_df.loc[theta_index, "hdi_89%"]
    lower_hdi = summary_df.loc[theta_index, "hdi_11%"]
    population_mean = batting_average.mean()
    # pandas' default (sample) standard deviation of the observed averages.
    population_std = batting_average.std()
    n_players = len(x)

    # Create the figure data source. The ``lower_std``/``upper_std`` columns
    # are constant across players; they exist so that the Band annotation
    # below has columns to read from this source.
    source = ColumnDataSource(
        {
            "x": x,
            "y": y,
            "upper_hdi": upper_hdi,
            "lower_hdi": lower_hdi,
            "lower_std": [population_mean - population_std] * n_players,
            "upper_std": [population_mean + population_std] * n_players,
            "name": df["Name"].values,
        }
    )

    # Create the figure.
    plot = figure(
        plot_width=500,
        plot_height=500,
        title="Partial pooling",
        x_axis_label="Observed hits / at-bats",
        y_axis_label="Predicted chance of a hit",
        x_range=[0.14, 0.41],
        y_range=[0.05, 0.55],
    )

    # Add the empirical mean at-bat hit chance to the figure.
    plot.line(
        x=[0, 1],
        y=[population_mean, population_mean],
        line_color="orange",
        line_width=3,
        level="underlay",
        legend_label="Population mean",
    )

    # Add the standard deviation of the mean at-bat hit chance to the figure.
    std_band = Band(
        base="x",
        lower="lower_std",
        upper="upper_std",
        source=source,
        level="underlay",
        fill_alpha=0.2,
        fill_color="orange",
        line_width=0.2,
        line_color="orange",
    )
    plot.add_layout(std_band)

    # Add the empirical at-bat hit chance to the figure. Both coordinates are
    # the observed average, so this traces the identity line (prediction
    # equal to observation) for reference.
    plot.line(
        x=x,
        y=x,
        line_color="grey",
        line_alpha=0.7,
        line_width=2.0,
        legend_label="Current hits / Current at-bats",
    )

    # Add the HDI whiskers to the figure.
    whiskers = Whisker(
        base="x",
        upper="upper_hdi",
        lower="lower_hdi",
        source=source,
        line_color="steelblue",
    )
    whiskers.upper_head.line_color = "steelblue"
    whiskers.lower_head.line_color = "steelblue"
    plot.add_layout(whiskers)

    # Add the partial-pooling model data to the figure.
    glyph = plot.circle(
        x="x",
        y="y",
        source=source,
        size=10,
        line_color="white",
        fill_color="steelblue",
        legend_label="Players",
    )
    # Every ``@field`` below must name a column of ``source``. The plotted
    # value is the posterior mean, stored in column ``y``.
    tooltips = HoverTool(
        renderers=[glyph],
        tooltips=[
            ("Name", "@name"),
            ("Posterior Upper HDI", "@upper_hdi{0.000}"),
            ("Posterior Mean", "@y{0.000}"),
            ("Posterior Lower HDI", "@lower_hdi{0.000}"),
        ],
    )
    plot.add_tools(tooltips)

    # Add a legend to the figure.
    plot.legend.location = "top_left"
    plot.legend.click_policy = "mute"

    # Style the figure.
    plots.style(plot)

    return plot