in evals/elsuite/already_said_that/scripts/make_plots.py [0:0]
def make_bar_plot(results_dict: dict, stat: str, save_path: Path):
    """Draw a horizontal grouped bar chart for one statistic and save it.

    One group of bars is drawn per model in ``MODELS[:-1]``. The last entry
    of ``MODELS`` is the random baseline; it is not drawn as bars. Within each
    group there is one bar per distractor task, with its standard error as an
    error bar and its mean printed at the bar's edge.

    Args:
        results_dict: Nested mapping indexed as
            ``results_dict[stat][distractor][model]``; each leaf must provide
            ``"mean"`` and ``"std_err"`` entries.
        stat: Statistic to plot. It is also used to look up the axis label
            (``STAT_TO_LABEL``) and the axis maximum (``STAT_TO_MAX``). For
            any stat other than ``"avg_distractor_accuracy"``, the following
            are added:
            - an extra ``"distractorless"`` bar per model (drawn in black);
            - a "maximum" reference line;
            - a random-baseline reference line.
        save_path: File the rendered figure is written to.
    """
    sns.set_context("paper")
    sns.set_style("whitegrid")
    fig, ax = plt.subplots(1, 1, figsize=(8, 7), dpi=300)
    try:
        data = results_dict[stat]
        # the random baseline isn't plotted as bars
        models = MODELS[:-1]
        distractors = [
            "which-is-heavier",
            "ambiguous-sentences",
            "first-letters",
            "reverse-sort-words-eng",
        ]
        # Vertical reference lines; they are listed first in the legend.
        reference_lines = []
        if stat != "avg_distractor_accuracy":
            distractors.append("distractorless")
            reference_lines.append(
                ax.axvline(
                    STAT_TO_MAX[stat], label="maximum", linestyle="--", color="grey"
                )
            )
            # random baseline is roughly the same for all distractors;
            # pick one for simplicity
            random_baseline = data["first-letters"]["random_baseline"]["mean"]
            reference_lines.append(
                ax.axvline(
                    random_baseline,
                    label=MODEL_TO_LABEL["random_baseline"],
                    linestyle="-.",
                    color="black",
                )
            )

        width = 0.15
        n_bars = len(distractors)
        # Center each model's group of bars on its tick: bar i is shifted by
        # width * (i - (n_bars - 1) / 2), so later distractors sit higher.
        offsets = (np.arange(n_bars) - (n_bars - 1) / 2) * width
        y = np.arange(len(models))  # one y tick per model
        cmap = plt.get_cmap("Set3")

        distractor_bars = []
        for i, distractor in enumerate(distractors):
            bars = ax.barh(
                y + offsets[i],
                [data[distractor][model]["mean"] for model in models],
                width,
                xerr=[data[distractor][model]["std_err"] for model in models],
                label=distractor,
                color="black" if distractor == "distractorless" else cmap(i),
            )
            distractor_bars.append(bars)

        ax.set_xlabel(STAT_TO_LABEL[stat])
        x_max = STAT_TO_MAX[stat] + 0.05 * STAT_TO_MAX[stat]
        ax.set_xlim([0, x_max])
        ax.set_yticks(y)
        ax.set_yticklabels([MODEL_TO_LABEL[model] for model in models])

        # Build the legend from the artists themselves. Bars are listed in
        # reverse plotting order because later bars are drawn higher, so the
        # legend reads top-to-bottom like the bars in each group.
        ax.legend(handles=reference_lines + distractor_bars[::-1], loc="best")

        for bars in distractor_bars:
            ax.bar_label(bars, label_type="edge", fmt="%.2f", fontsize=8)

        # Hide horizontal (y-axis) grid lines. Pass visible=False explicitly:
        # calling ax.grid without it toggles visibility instead of disabling.
        ax.grid(False, axis="y", which="both")
        fig.set_tight_layout(True)
        fig.savefig(save_path, bbox_inches="tight", dpi=300)
    finally:
        # Always release the figure so repeated calls don't accumulate open
        # figures (and their memory) in pyplot's figure manager.
        plt.close(fig)