def make_bar_plot()

in evals/elsuite/track_the_stat/scripts/make_plots.py [0:0]


def _draw_baseline(ax, stats: dict, color: str, linestyle: str, label: str) -> None:
    """Mark a baseline as a vertical line with a faint +/- one std-err band."""
    center = stats["mean"]
    spread = stats["std_err"]
    ax.axvline(center, color=color, linestyle=linestyle, label=label)
    ax.axvspan(center - spread, center + spread, color=color, alpha=0.05)


def make_bar_plot(results_dict: dict, task: str, stat: str, save_path: Path):
    """Save a horizontal grouped bar chart of one statistic for one task.

    Each model gets two bars (explicit vs. implicit state tracking) with
    standard-error whiskers. The random and human baselines are drawn as
    vertical lines with a shaded +/- std-err band instead of bars.

    Args:
        results_dict: Nested mapping indexed as
            ``results_dict[stat][task][model][kind]`` where ``kind`` is
            ``"explicit"`` or ``"implicit"`` and each leaf provides ``"mean"``
            and ``"std_err"``. The per-task mapping must also contain
            ``"random_baseline"`` and ``"human_baseline"`` entries with an
            ``"implicit"`` leaf.
        task: Key selecting the task inside ``results_dict[stat]``.
        stat: Key selecting the statistic; also looked up in
            ``STAT_TO_LABEL`` for the x-axis label.
        save_path: Destination of the rendered image (written at 300 dpi).

    Raises:
        KeyError: If any of the expected keys is missing (assuming plain
            dict inputs).
        ValueError: If there are no models to plot (``max`` of an empty
            sequence).

    Note:
        Calls ``sns.set_context("paper")`` and ``sns.set_style("whitegrid")``
        on every invocation. The figure is closed before returning so
        repeated calls do not accumulate open figures.
    """
    sns.set_context("paper")
    sns.set_style("whitegrid")

    data = results_dict[stat][task]

    # the random baseline and human baseline aren't plotted as bars
    # NOTE(review): assumes they are the last two entries of MODELS - confirm.
    models = MODELS[:-2]

    state_tracking_kinds = ["explicit", "implicit"]

    # One row per model; column 0 is "explicit", column 1 is "implicit".
    means = [
        [data[model][cat]["mean"] for cat in state_tracking_kinds] for model in models
    ]
    std_errs = [
        [data[model][cat]["std_err"] for cat in state_tracking_kinds]
        for model in models
    ]
    cmap = plt.get_cmap("Paired")
    colors = np.array([cmap(i) for i in range(len(state_tracking_kinds))])

    # Plotting
    x = np.arange(len(models))  # the label locations

    width = 0.4

    fig, ax = plt.subplots(1, 1, figsize=(8, 6), dpi=300)
    try:
        # The two bars of a model sit on either side of its tick position.
        explicit_bars = ax.barh(
            x + width / 2,
            [mean[0] for mean in means],
            width,
            xerr=[err[0] for err in std_errs],
            label="Explicitly tracked state baseline",
            color=colors[0],
        )
        implicit_bars = ax.barh(
            x - width / 2,
            [mean[1] for mean in means],
            width,
            xerr=[err[1] for err in std_errs],
            label="Implicitly tracked state",
            color=colors[1],
        )

        ax.set_xlabel(STAT_TO_LABEL[stat])
        # maximum x + xerr value times 1.2 (leaves room for the bar labels)
        x_max = (
            max(m for mean in means for m in mean)
            + max(e for err in std_errs for e in err)
        ) * 1.2
        ax.set_xlim([0, x_max])
        ax.set_yticks(x)
        ax.set_yticklabels(models)

        ax.bar_label(implicit_bars, padding=3, fmt="%.2f")
        ax.bar_label(explicit_bars, padding=3, fmt="%.2f")

        # plot random and human baselines
        _draw_baseline(
            ax,
            data["random_baseline"]["implicit"],
            color="red",
            linestyle="--",
            label="Random baseline",
        )
        _draw_baseline(
            ax,
            data["human_baseline"]["implicit"],
            color="#366a9d",
            linestyle=":",
            label="Human baseline (implicit)",
        )

        # get rid of horizontal grid lines; pass False explicitly, because
        # calling grid() without a visibility argument merely toggles the
        # current state instead of switching the lines off.
        ax.grid(False, axis="y", which="both")

        ax.legend()

        fig.tight_layout()

        # Save this figure specifically rather than whichever is "current".
        fig.savefig(save_path, bbox_inches="tight", dpi=300)
    finally:
        # Release the figure even if plotting/saving failed.
        plt.close(fig)