in evals/elsuite/self_prompting/scripts/make_plots.py [0:0]
def main():
    """Turn self-prompting eval logs into plots and per-metric CSV tables.

    Command-line arguments:
        --log_dir / -d: directory handed to ``extract_metrics`` (required).
        --out_dir / -o: directory that receives every output file; it is
            created, including parents, if missing (default ``./outputs``).

    Files written to the output directory:
        subset_samples.csv: one record per
            (solver_path, taskname, tasker_model) group, for inspection.
        per_tasker_results_exact.png, per_tasker_results_fuzzy.png:
            produced by ``make_plot``.
        exact.csv, fuzzy.csv, violation.csv: one row per (model, solver)
            pair and one column per tasker_model, holding the mean of the
            corresponding metric column.

    The three metric tables are also printed to stdout.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--log_dir", "-d", type=str, required=True)
    cli.add_argument("--out_dir", "-o", type=str, default="./outputs")
    options = cli.parse_args()

    log_dir = Path(options.log_dir)
    out_dir = Path(options.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    metrics_df = extract_metrics(log_dir)

    # The reported numbers are averages over a specific set of tasks, so warn
    # loudly when the tasks found in the logs differ from `eval_list`.
    observed_tasks = set(metrics_df["taskname"].unique())
    if observed_tasks != set(eval_list):
        print(
            "WARNING: Task distribution changed, results and error bars will not be comparable to plots with the original task distribution."
        )

    # Dump a small sample for inspection, leaving out the rows whose
    # tasker_model is "mean".
    sample_keys = ["solver_path", "taskname", "tasker_model"]
    not_mean = metrics_df["tasker_model"] != "mean"
    # GroupBy.first() keeps, per group, the first non-null value of each column.
    samples = metrics_df[not_mean].groupby(sample_keys).first().reset_index()
    samples.to_csv(out_dir / "subset_samples.csv", quoting=csv.QUOTE_ALL, escapechar="\\")

    plots = (
        ("per_tasker_results_exact.png", "exact"),
        ("per_tasker_results_fuzzy.png", "fuzzy"),
    )
    for plot_name, metric in plots:
        make_plot(metrics_df, out_dir / plot_name, metric=metric)

    # (metric column, output file) pairs; the order fixes both the order in
    # which the means are computed and the order in which tables are emitted.
    tables = (
        ("exact", "exact.csv"),
        ("fuzzy", "fuzzy.csv"),
        ("prompt_rule_violation", "violation.csv"),
    )

    # Collect every row first so that no table is written before all the
    # per-group means have been computed.
    rows_by_metric = {column: [] for column, _ in tables}
    for _, group in metrics_df.groupby(["model", "solver"]):
        solver = group["solver"].iloc[0]
        model = group["model"].iloc[0]
        per_tasker = group.groupby("tasker_model")
        for column, _ in tables:
            means = per_tasker[column].mean()
            # Unpacking the Series adds one entry per tasker_model label.
            rows_by_metric[column].append({"model": model, "solver": solver, **means})

    # Write and print the results
    for column, filename in tables:
        table = pd.DataFrame(rows_by_metric[column])
        table.to_csv(out_dir / filename, quoting=csv.QUOTE_ALL, index=False)
        print(table)