def make_plot()

in evals/elsuite/theory_of_mind/scripts/make_plots.py [0:0]


def _accuracy_table(ds, model_order):
    """Pivot one dataset's rows into a per-model table of accuracy and bootstrap std.

    The result is indexed by model and has MultiIndex columns
    ``(value, prompt_type)`` for value in ``accuracy``/``bootstrap_std`` and
    prompt_type in ``simple``/``cot``. Rows follow ``model_order``; models not
    listed there are appended in order of first appearance instead of being
    turned into NaN. A missing (model, prompt_type) combination yields NaN, so
    every bar stays aligned with its model label.

    Raises:
        ValueError: (from ``DataFrame.pivot``) if a (model, prompt_type) pair
            occurs more than once, which would make the bar height ambiguous.
    """
    present = list(pd.unique(ds["model"]))
    order = [m for m in model_order if m in present]
    order += [m for m in present if m not in model_order]
    table = ds.pivot(
        index="model", columns="prompt_type", values=["accuracy", "bootstrap_std"]
    )
    columns = pd.MultiIndex.from_product([["accuracy", "bootstrap_std"], ["simple", "cot"]])
    return table.reindex(index=order, columns=columns)


def make_plot(df, out_dir, model_order=None):
    """Draw, save and show one grouped accuracy bar chart per dataset.

    For every distinct value of ``df["dataset"]`` a figure is drawn with one pair
    of bars per model: "simple" prompting and "chain-of-thought" (``cot``)
    prompting, each with ``bootstrap_std`` error bars. For the ``socialiqa``
    dataset a dashed human-baseline line at 0.881 is added. Each figure is saved
    to ``out_dir / f"accuracy_{dataset.lower()}.png"``, shown, then closed.

    Args:
        df: DataFrame with columns ``dataset``, ``model``, ``prompt_type``,
            ``accuracy`` and ``bootstrap_std``, holding at most one row per
            (dataset, model, prompt_type).
        out_dir: Output directory; must support ``/`` with a str (e.g. a
            ``pathlib.Path``).
        model_order: Left-to-right order of models on the x axis. Defaults to
            ``["gpt-3.5-turbo", "gpt-4-base", "gpt-4"]``; models not listed are
            appended after these.

    Raises:
        ValueError: if a (model, prompt_type) pair appears more than once
            within a dataset.
    """
    if model_order is None:
        model_order = ["gpt-3.5-turbo", "gpt-4-base", "gpt-4"]

    sns.set_theme(style="whitegrid")
    sns.set_palette("dark")
    pastel = sns.color_palette("pastel")
    bar_width = 0.35

    for dataset in df["dataset"].unique():
        # Match the raw value: lowercasing only the right-hand side would select
        # no rows for any dataset name containing uppercase letters.
        table = _accuracy_table(df[df["dataset"] == dataset], model_order)
        xs = table.index
        x_indices = np.arange(len(xs))

        fig, ax1 = plt.subplots()
        fig.suptitle(f"Accuracy on {dataset} dataset")
        ax1.set_xlabel("Model")
        ax1.set_ylabel("Accuracy")

        # Side-by-side bars for 'simple' and 'cot', aligned to the model index.
        ax1.bar(
            x_indices,
            table[("accuracy", "simple")].to_numpy(),
            width=bar_width,
            color=pastel[0],
            yerr=table[("bootstrap_std", "simple")].to_numpy(),
            label="simple",
        )
        ax1.bar(
            x_indices + bar_width,
            table[("accuracy", "cot")].to_numpy(),
            width=bar_width,
            color=pastel[1],
            yerr=table[("bootstrap_std", "cot")].to_numpy(),
            label="chain-of-thought",
        )

        if dataset.lower() == "socialiqa":
            # Dashed horizontal reference line for the human baseline.
            human_baseline = 0.881
            ax1.axhline(y=human_baseline, color="gray", linestyle="--", linewidth=1)
            ax1.text(
                0.01, human_baseline, "human baseline", va="center", ha="left", backgroundcolor="w"
            )

        # Centre each tick between its model's two bars.
        ax1.set_xticks(x_indices + bar_width / 2)
        ax1.set_xticklabels(list(xs), rotation=45)
        ax1.set_ylim(0, 1)
        ax1.legend(loc="upper right", bbox_to_anchor=(1, 1))

        # Lay out before saving, otherwise the adjustment never reaches the PNG.
        fig.tight_layout()
        fig.savefig(out_dir / f"accuracy_{dataset.lower()}.png", bbox_inches="tight")
        plt.show()
        # Release the figure so repeated iterations do not accumulate open figures.
        plt.close(fig)