def plot_regression()

in project/paperbench/experiments/pbcd_correlation/plot.py [0:0]


def plot_regression(
    code_only_scores,
    normal_scores,
    model,
    *,
    output_path="experiments/pbcd_correlation/correlation_plot.pdf",
    x_limits=(0.2, 0.7),
):
    """
    Create and save a scatter plot with regression line.

    The scatter shows ``normal_scores`` against ``code_only_scores``; the red
    line is ``model.predict`` evaluated on 50 evenly spaced points across
    ``x_limits``. A text box in the top-left corner of the axes reports the
    fitted equation (from ``model.coef_[0]`` and ``model.intercept_``) and the
    R² value returned by ``model.score`` on the supplied data.

    Args:
        code_only_scores (np.ndarray): Scores for code-only condition
            (x-axis). Any 1-D array-like is accepted.
        normal_scores (np.ndarray): Scores for normal condition (y-axis).
            Any 1-D array-like is accepted.
        model (LinearRegression): Fitted regression model
        output_path (str | os.PathLike): Where the figure is written.
            Defaults to the original hard-coded location, which is resolved
            relative to the current working directory.
        x_limits (tuple[float, float]): ``(start, stop)`` of the x-range over
            which the regression line is drawn. Defaults to the original
            hard-coded ``(0.2, 0.7)``.
    """
    # Accept lists/tuples as well as arrays; `.reshape` below needs an ndarray.
    x_values = np.asarray(code_only_scores)
    y_values = np.asarray(normal_scores)

    x_start, x_stop = x_limits
    x_range = np.linspace(x_start, x_stop, 50)
    y_fit = model.predict(x_range.reshape(-1, 1))

    # Line of best fit equation and R² value shown in the annotation box.
    r_squared = model.score(x_values.reshape(-1, 1), y_values)
    equation = f"y = {model.coef_[0]:.3f}x + {model.intercept_:.3f}"
    r2_text = f"R² = {r_squared:.3f}"

    # Scope the font size to this figure instead of mutating the global
    # rcParams. The context stays open through savefig because some artists
    # (e.g. tick labels) may only be created when the figure is drawn.
    with plt.rc_context({"font.size": 7}):
        fig, ax = plt.subplots(figsize=(6.75133 / 1.5, 2.75))
        try:
            ax.scatter(x_values, y_values, s=10)
            ax.plot(x_range, y_fit, "r-")

            ax.text(
                0.05,
                0.95,
                f"{equation}\n{r2_text}",
                transform=ax.transAxes,
                verticalalignment="top",
                bbox=dict(facecolor="white", alpha=0.8),
            )

            ax.set_xlabel("PaperBench Code-Dev Performance")
            ax.set_ylabel("PaperBench Performance")
            ax.grid(True, alpha=0.3)

            fig.tight_layout()
            fig.savefig(
                output_path,
                bbox_inches="tight",
                dpi=300,
                pad_inches=0.01,
            )
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)