def plot_paired_ttest_nsamples()

in src/alpaca_eval/plotting.py [0:0]


def plot_paired_ttest_nsamples(df):
    """Plot the smallest sub-sample size at which each pairwise t-test value drops below 0.05.

    The t-test table is recomputed with ``_get_ttest_df`` for sub-sample sizes
    50, 100, 150, ... (strictly below the number of unique instructions). For
    every cell, the smallest size whose value is ``< 0.05`` is kept; cells that
    never go below 0.05 are set to NaN. The strictly lower triangle of the
    resulting matrix is drawn as an annotated heatmap, and the NaN cells are
    covered with a translucent dark overlay.

    Args:
        df: DataFrame with at least an ``"instruction"`` column; it is forwarded
            unchanged to ``_get_ttest_df``.
            NOTE(review): presumably the long-format annotations expected by
            ``_get_ttest_df`` -- confirm against that helper.

    Returns:
        The matplotlib Axes returned by ``sns.heatmap`` for the annotated layer.

    Raises:
        ValueError: if ``df`` has 50 or fewer unique instructions, so that no
            sub-sample size can be evaluated.

    Side effects:
        Creates a new figure and calls ``plt.show()``.
    """
    df_ttest = _get_ttest_df(df)

    n_instructions = len(df["instruction"].unique())
    sample_sizes = range(50, n_instructions, 50)
    if not sample_sizes:
        # np.minimum.reduce on an empty list raises an opaque ValueError; fail with a clearer one instead.
        raise ValueError(
            f"Need more than 50 unique instructions to sub-sample, but got {n_instructions}."
        )

    # Same seed and same row/column ordering for every size, so the tables are aligned cell by cell.
    # NOTE(review): assumes each table has the same shape/order as `df_ttest` -- confirm in `_get_ttest_df`.
    all_sub_ttest_df = {
        n: _get_ttest_df(df, n_samples=n, random_state=123, sorted_idx=list(df_ttest.index)) for n in sample_sizes
    }

    # Per size: n where the value is below 0.05, +inf elsewhere. The element-wise minimum over all
    # sizes is then the smallest n reaching the threshold (+inf if none does).
    arr_min_samples = np.minimum.reduce(
        [np.where(sub_df < 0.05, n, float("inf")) for n, sub_df in all_sub_ttest_df.items()]
    )
    arr_min_samples[np.isinf(arr_min_samples)] = np.nan
    df_min_samples = pd.DataFrame(arr_min_samples, index=df_ttest.index, columns=df_ttest.index)

    # Hide the diagonal and everything above it (k=0 is also np.triu's default).
    upper_triangle = np.triu(np.ones_like(df_ttest, dtype=bool), k=0)
    is_missing = df_min_samples.isnull()

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 15))
    with plot_config(font_scale=0.55):
        # Overlay layer: only the lower-triangle cells that never went below 0.05 are left unmasked.
        # `ax` is passed explicitly so both layers are guaranteed to land on the same Axes.
        sns.heatmap(
            is_missing,
            cbar=False,
            color="black",
            alpha=0.5,
            mask=~is_missing | upper_triangle,
            ax=ax,
        )
        # Annotated layer: minimal sample size per pair, on a fixed 0-1000 colour scale.
        g = sns.heatmap(
            df_min_samples,
            annot=True,
            fmt=".0f",
            cbar=False,
            square=True,
            xticklabels=False,
            ax=ax,
            vmin=0,
            vmax=1000,
            cmap=sns.color_palette("rocket_r", as_cmap=True),
            mask=upper_triangle,
        )

        g.set(xlabel="", ylabel="")

    plt.show()

    return g