in evals/elsuite/track_the_stat/scripts/make_plots.py [0:0]
def make_bar_plot(results_dict: dict, task: str, stat: str, save_path: Path):
    """Draw a horizontal grouped bar chart of one stat for one task and save it.

    For every model in ``MODELS[:-2]`` two horizontal bars are drawn: one for
    the "explicit" and one for the "implicit" state-tracking variant, each with
    its ``std_err`` as a horizontal error bar and its mean as a text label.
    The "random_baseline" and "human_baseline" entries (their "implicit"
    variant) are drawn as vertical lines with a faint +/- ``std_err`` band.

    Args:
        results_dict: Nested mapping read as
            ``results_dict[stat][task][name][kind][field]``, where ``kind`` is
            "explicit" or "implicit" and ``field`` is "mean" or "std_err".
            ``name`` ranges over ``MODELS[:-2]`` plus "random_baseline" and
            "human_baseline" (only the "implicit" kind is read for those two).
        task: Task key to plot.
        stat: Stat key to plot; also used to look up the x-axis label in
            ``STAT_TO_LABEL``.
        save_path: Destination file for the rendered figure.
    """
    sns.set_context("paper")
    sns.set_style("whitegrid")
    data = results_dict[stat][task]
    # The random baseline and human baseline aren't plotted as bars; this
    # assumes they are the last two entries of MODELS.
    models = MODELS[:-2]
    state_tracking_kinds = ["explicit", "implicit"]
    means = [
        [data[model][cat]["mean"] for cat in state_tracking_kinds] for model in models
    ]
    std_errs = [
        [data[model][cat]["std_err"] for cat in state_tracking_kinds]
        for model in models
    ]
    cmap = plt.get_cmap("Paired")
    colors = np.array([cmap(i) for i in range(len(state_tracking_kinds))])

    # Plotting
    x = np.arange(len(models))  # the label locations
    width = 0.4
    fig, ax = plt.subplots(1, 1, figsize=(8, 6), dpi=300)
    # Explicit bars are offset above each model's tick, implicit bars below.
    explicit_bars = ax.barh(
        x + width / 2,
        [mean[0] for mean in means],
        width,
        xerr=[err[0] for err in std_errs],
        label="Explicitly tracked state baseline",
        color=colors[0],
    )
    implicit_bars = ax.barh(
        x - width / 2,
        [mean[1] for mean in means],
        width,
        xerr=[err[1] for err in std_errs],
        label="Implicitly tracked state",
        color=colors[1],
    )
    ax.set_xlabel(STAT_TO_LABEL[stat])
    ax.set_yticks(x)
    ax.set_yticklabels(models)
    ax.bar_label(implicit_bars, padding=3, fmt="%.2f")
    ax.bar_label(explicit_bars, padding=3, fmt="%.2f")

    # Plot random and human baselines as vertical lines with a +/- std_err band.
    random_baseline = data["random_baseline"]["implicit"]["mean"]
    random_err = data["random_baseline"]["implicit"]["std_err"]
    ax.axvline(random_baseline, color="red", linestyle="--", label="Random baseline")
    ax.axvspan(
        random_baseline - random_err,
        random_baseline + random_err,
        color="red",
        alpha=0.05,
    )
    human_baseline = data["human_baseline"]["implicit"]["mean"]
    human_err = data["human_baseline"]["implicit"]["std_err"]
    ax.axvline(
        human_baseline,
        color="#366a9d",
        linestyle=":",
        label="Human baseline (implicit)",
    )
    ax.axvspan(
        human_baseline - human_err,
        human_baseline + human_err,
        color="#366a9d",
        alpha=0.05,
    )

    # Upper x limit with 20% headroom. The bar extent uses max mean + max
    # std_err (an upper bound on any single bar's mean + error). The baseline
    # bands are included too, so a baseline larger than every bar is not
    # clipped off the right edge of the axes.
    bar_extent = max(m for row in means for m in row) + max(
        e for row in std_errs for e in row
    )
    x_max = (
        max(
            bar_extent,
            random_baseline + random_err,
            human_baseline + human_err,
        )
        * 1.2
    )
    ax.set_xlim([0, x_max])

    # Get rid of the horizontal grid lines enabled by the "whitegrid" style.
    # Pass False explicitly rather than relying on grid()'s toggle behaviour
    # when no visibility argument is given.
    ax.grid(False, axis="y", which="both")
    ax.legend()
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches="tight", dpi=300)
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)