def sample_stability_plot()

in leaderboard/plots.py [0:0]


def sample_stability_plot(filetypes: List[str], commit: bool = False):
    """Plot how well a sampled development subset reproduces the test ranking.

    Three result tables are read from the directory configured at
    ``conf["stability"]["sampling"]``:

    * ``random_df.json`` -- random-sampling trials, drawn as a confidence-interval
      error band plus a line through the per-size mean;
    * ``irt_df.json`` and ``info_df.json`` -- the non-random selection methods,
      concatenated and drawn as plain lines.

    Every table is expected to provide ``variable``, ``value`` and
    ``trial_size`` columns. ``variable`` must be one of the method keys
    listed below; an unrecognised key raises ``KeyError`` while relabelling.
    The x axis (sample size) is log-scaled.

    Args:
        filetypes: Output formats, forwarded unchanged to ``save_chart``.
        commit: If true, the figure goes under ``COMMIT_AUTO_FIGS``;
            otherwise under ``AUTO_FIG``. The stem is ``sampling_rank``
            either way.
    """
    source_dir = Path(conf["stability"]["sampling"])
    random_df = pd.read_json(source_dir / "random_df.json")
    irt_df = pd.read_json(source_dir / "irt_df.json")
    info_df = pd.read_json(source_dir / "info_df.json")

    # Internal method identifier -> label shown in the legend.
    labels = {
        "dev_high_disc_to_test": "High Discrimination",
        "dev_high_diff_to_test": "High Difficulty",
        "dev_high_disc_diff_to_test": "High Disc + Diff",
        "dev_info_to_test": "High Information",
        "dev_random_to_test": "Random",
    }

    def with_labels(frame):
        # Strict lookup on purpose: an unknown identifier fails loudly with
        # KeyError rather than quietly turning into a missing label.
        def to_label(key):
            return labels[key]

        return frame.assign(sampling_method=frame["variable"].map(to_label))

    # Order in which methods appear in the legend.
    legend_order = [
        "High Disc + Diff",
        "High Information",
        "High Discrimination",
        "High Difficulty",
        "Random",
    ]

    # Encodings shared by every layer so all marks use the same axes and colours.
    x_enc = alt.X(
        "trial_size",
        title="Development Set Sample Size",
        scale=alt.Scale(type="log"),
    )
    shared_y_scale = alt.Scale(zero=False)
    raw_y = alt.Y("value", title="", scale=shared_y_scale)
    color_enc = alt.Color(
        "sampling_method",
        title="Sampling Method",
        legend=alt.Legend(
            orient="bottom-right", fillColor="white", padding=5, strokeColor="gray"
        ),
        sort=legend_order,
    )

    # Relabel separately for each random-sampling layer so the two layers
    # never hold the same DataFrame object.
    band = (
        alt.Chart(with_labels(random_df))
        .mark_errorband(extent="ci")
        .encode(x=x_enc, y=raw_y, color=color_enc)
    )
    mean_line = (
        alt.Chart(with_labels(random_df))
        .mark_line()
        .encode(
            x=x_enc,
            y=alt.Y(
                "mean(value)", scale=shared_y_scale, title="Correlation to Test Rank"
            ),
            color=color_enc,
        )
    )

    deterministic = with_labels(pd.concat([irt_df, info_df]))
    method_lines = (
        alt.Chart(deterministic)
        .mark_line()
        .encode(x=x_enc, y=raw_y, color=color_enc)
    )

    size = 18
    layered = band + mean_line + method_lines
    chart = (
        layered.configure_axis(labelFontSize=size, titleFontSize=size)
        .configure_legend(
            labelFontSize=size,
            titleFontSize=size,
            symbolLimit=0,
            labelLimit=0,
        )
        .configure_header(labelFontSize=size)
        .configure(padding=0)
    )

    out_root = COMMIT_AUTO_FIGS if commit else AUTO_FIG
    save_chart(chart, out_root / "sampling_rank", filetypes)