def main()

in evals/elsuite/self_prompting/scripts/make_plots.py
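
This excerpt shows only main(); the module-level preamble is not reproduced. A minimal sketch of the imports and helpers it relies on follows — the helper descriptions are inferred from how they are used below and are assumptions, not the actual definitions:

import argparse
import csv
from pathlib import Path

import pandas as pd

# Assumed to be defined elsewhere in make_plots.py (signatures inferred from usage):
#   extract_metrics(log_dir)             -> DataFrame of per-sample metrics
#   make_plot(df, out_path, metric=...)  -> writes a plot image to out_path
#   eval_list                            -> task names in the original task distribution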


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--log_dir", "-d", type=str, required=True)
    parser.add_argument("--out_dir", "-o", type=str, default="./outputs")
    args = parser.parse_args()
    log_dir = Path(args.log_dir)
    out_dir = Path(args.out_dir)

    out_dir.mkdir(exist_ok=True, parents=True)

    metrics_df = extract_metrics(log_dir)

    # Results are averaged over the original task distribution (eval_list);
    # warn if the logs cover a different set of tasks
    if set(metrics_df["taskname"].unique()) != set(eval_list):
        print(
            "WARNING: Task distribution changed, results and error bars will not be comparable to plots with the original task distribution."
        )

    # Sample a subset of the data for inspection
    subset_df = metrics_df[metrics_df["tasker_model"] != "mean"]
    # Take only the first row of each [solver_path, taskname, tasker_model] group
    subset_df = subset_df.groupby(["solver_path", "taskname", "tasker_model"]).first().reset_index()
    subset_df.to_csv(out_dir / "subset_samples.csv", quoting=csv.QUOTE_ALL, escapechar="\\")

    make_plot(metrics_df, out_dir / "per_tasker_results_exact.png", metric="exact")
    make_plot(metrics_df, out_dir / "per_tasker_results_fuzzy.png", metric="fuzzy")

    # Build summary tables (one row per model/solver), print them, and write CSVs
    exact_df_rows = []
    fuzzy_df_rows = []
    violation_df_rows = []
    for _, df_tasker in metrics_df.groupby(["model", "solver"]):
        solver = df_tasker["solver"].iloc[0]
        model = df_tasker["model"].iloc[0]

        # Average each metric per tasker model; unpacking the resulting Series
        # spreads one column per tasker_model into the summary row
        exact = df_tasker.groupby("tasker_model")["exact"].mean()
        exact_df_rows.append(
            {
                "model": model,
                "solver": solver,
                **exact,
            }
        )
        fuzzy = df_tasker.groupby("tasker_model")["fuzzy"].mean()
        fuzzy_df_rows.append(
            {
                "model": model,
                "solver": solver,
                **fuzzy,
            }
        )
        prompt_rule_violation = df_tasker.groupby("tasker_model")["prompt_rule_violation"].mean()
        violation_df_rows.append(
            {
                "model": model,
                "solver": solver,
                **prompt_rule_violation,
            }
        )

    exact_df = pd.DataFrame(exact_df_rows)
    exact_df.to_csv(out_dir / "exact.csv", quoting=csv.QUOTE_ALL, index=False)
    print(exact_df)
    fuzzy_df = pd.DataFrame(fuzzy_df_rows)
    fuzzy_df.to_csv(out_dir / "fuzzy.csv", quoting=csv.QUOTE_ALL, index=False)
    print(fuzzy_df)
    violation_df = pd.DataFrame(violation_df_rows)
    violation_df.to_csv(out_dir / "violation.csv", quoting=csv.QUOTE_ALL, index=False)
    print(violation_df)
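
A typical invocation, assuming the script is run directly (the __main__ guard below is illustrative and not part of the excerpt; the log directory path is a placeholder):

if __name__ == "__main__":
    # e.g. python make_plots.py --log_dir ./evallogs --out_dir ./outputs
    main()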