def make_bar_plot()

in evals/elsuite/already_said_that/scripts/make_plots.py [0:0]


def make_bar_plot(results_dict: dict, stat: str, save_path: Path) -> None:
    """Draw a grouped horizontal bar chart for ``stat`` and save it to disk.

    One group of bars is drawn per model (every entry of ``MODELS`` except the
    last), with one bar per distractor inside each group. Error bars show the
    standard error and every bar is annotated with its value.

    For every stat other than ``"avg_distractor_accuracy"`` an additional
    ``"distractorless"`` bar is drawn (in black), along with two vertical
    reference lines: the stat's maximum (``STAT_TO_MAX[stat]``) and the
    random baseline.

    Args:
        results_dict: Nested mapping indexed as
            ``results_dict[stat][distractor][model]`` whose leaves provide
            ``"mean"`` and ``"std_err"`` entries.
        stat: Key of the statistic to plot; must be present in
            ``results_dict``, ``STAT_TO_MAX`` and ``STAT_TO_LABEL``.
        save_path: Destination passed to ``Figure.savefig``.

    Raises:
        KeyError: If ``stat``, a distractor, a model, or one of the
            ``"mean"`` / ``"std_err"`` entries is missing.
    """
    sns.set_context("paper")
    sns.set_style("whitegrid")

    fig, ax = plt.subplots(1, 1, figsize=(8, 7), dpi=300)

    data = results_dict[stat]

    # the random baseline isn't plotted as bars
    # NOTE(review): this assumes the random baseline is the last entry of
    # MODELS -- confirm against the MODELS definition.
    models = MODELS[:-1]

    distractors = [
        "which-is-heavier",
        "ambiguous-sentences",
        "first-letters",
        "reverse-sort-words-eng",
    ]

    width = 0.15
    # Number of legend entries (vertical reference lines) that precede the
    # bar entries in the handles returned by the axes.
    n_reference_lines = 0
    if stat != "avg_distractor_accuracy":
        distractors.append("distractorless")
        ax.axvline(STAT_TO_MAX[stat], label="maximum", linestyle="--", color="grey")

        # random baseline is roughly the same for all distractors; pick one for simplicity
        random_baseline = data["first-letters"]["random_baseline"]["mean"]

        ax.axvline(
            random_baseline,
            label=MODEL_TO_LABEL["random_baseline"],
            linestyle="-.",
            color="black",
        )
        n_reference_lines = 2

    n_distractors = len(distractors)

    # Offsets that centre each group of bars on its model's tick, e.g.
    # [-2w, -w, 0, w, 2w] for five bars or [-1.5w, -0.5w, 0.5w, 1.5w] for four.
    diffs = [(i - (n_distractors - 1) / 2) * width for i in range(n_distractors)]

    # Keep the reference lines first, then list the bars in reverse so the
    # legend order matches the top-to-bottom order of the bars on the plot.
    legend_indices = list(range(n_reference_lines)) + list(
        range(n_reference_lines + n_distractors - 1, n_reference_lines - 1, -1)
    )

    cmap = plt.get_cmap("Set3")
    colors = np.array([cmap(i) for i in range(n_distractors)])

    x = np.arange(len(models))  # the label locations

    distractor_bars = []
    for i, distractor in enumerate(distractors):
        bar = ax.barh(
            x + diffs[i],
            [data[distractor][model]["mean"] for model in models],
            width,
            xerr=[data[distractor][model]["std_err"] for model in models],
            label=distractor,
            color=colors[i] if distractor != "distractorless" else "black",
        )
        distractor_bars.append(bar)

    ax.set_xlabel(STAT_TO_LABEL[stat])
    # leave 5% headroom past the maximum so edge labels are not clipped
    x_max = STAT_TO_MAX[stat] + 0.05 * STAT_TO_MAX[stat]
    ax.set_xlim([0, x_max])
    ax.set_yticks(x)
    ax.set_yticklabels([MODEL_TO_LABEL[model] for model in models])
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(
        [handles[i] for i in legend_indices],
        [labels[i] for i in legend_indices],
        loc="best",
    )

    for bar in distractor_bars:
        ax.bar_label(
            bar,
            label_type="edge",
            fmt="%.2f",
            fontsize=8,
        )

    # get rid of horizontal grid lines; pass False explicitly, because
    # omitting the visibility argument toggles the current state instead
    ax.grid(False, axis="y", which="both")

    fig.set_tight_layout(True)

    try:
        fig.savefig(save_path, bbox_inches="tight", dpi=300)
    finally:
        # release the figure so repeated calls don't accumulate open figures
        plt.close(fig)