def rank_stability_plot()

in leaderboard/plots.py [0:0]


def rank_stability_plot(filetypes: List[str], commit: bool = False):
    """Plot rank-stability correlations against development-set sample size.

    Loads the SQuAD test-set 3PL IRT parameters, the test-set submission
    predictions and the dev-to-test mapping (all located via ``conf``), builds
    the simulation table with ``create_rank_stability_df`` and renders one
    facet per experiment (``label_experiment``). Each facet shows, per
    correlation type, the mean of the correlation values (y axis titled
    "Kendall Rank Correlation") as a line plus a confidence-interval error
    band, over a log2-scaled sample-size axis.

    Args:
        filetypes: File types forwarded unchanged to ``save_chart``.
        commit: If True, save under ``COMMIT_AUTO_FIGS``; otherwise under
            ``AUTO_FIG``. The file stem is ``stability_simulation_corr`` in
            both cases.

    Raises:
        KeyError: If the simulation table has a correlation column (other
            than ``irt_corr``, which is dropped) with no display name below.
    """
    test_irt_parsed = IrtParsed.from_irt_file(
        Path(conf["irt"]["squad"]["test"]["pyro"]["3PL"]["full"]) / "parameters.json"
    )
    test_preds = LeaderboardPredictions.parse_file(conf["squad"]["submission_predictions"]["test"])
    dev_to_test = read_json(conf["squad"]["dev_to_test"])["dev_to_test"]
    df = create_rank_stability_df(
        dev_to_test=dev_to_test, test_preds=test_preds, test_irt_parsed=test_irt_parsed
    )

    # Display label for each correlation column. Several columns deliberately
    # share a label so they are drawn as a single colour series.
    names = {
        "abs_irt_corr": "IRT to IRT",
        "classical_corr": "Acc to Acc",
        "test_classical_sample_classical_corr": "Acc to Acc",
        "test_classical_sample_irt_corr": "IRT to Acc",
        "test_irt_sample_classical_corr": "Acc to IRT",
        "test_irt_sample_irt_corr": "IRT to IRT",
    }
    color_order = ["IRT to IRT", "Acc to Acc", "IRT to Acc", "Acc to IRT"]
    # These labels are mapped above but filtered out of this figure.
    excluded = {"IRT to Acc", "Acc to IRT"}

    # Long format: one row per (trial_size, trial_id, variable) with its value;
    # rows with any missing entry are dropped.
    melt_df = df.drop(columns=["irt_corr"]).melt(id_vars=["trial_size", "trial_id"]).dropna(axis=0)
    console.log(melt_df.head())

    # Fail loudly, naming the culprits, if a column has no display label.
    unknown = set(melt_df["variable"]) - names.keys()
    if unknown:
        raise KeyError(
            f"No display name for correlation column(s): {sorted(map(str, unknown))}"
        )
    melt_df["correlation"] = melt_df["variable"].map(names)
    # .copy() so the column added next lands on an independent frame rather
    # than a boolean-filtered slice (avoids pandas chained-assignment issues).
    melt_df = melt_df[~melt_df["correlation"].isin(excluded)].copy()
    melt_df["experiment"] = melt_df["variable"].map(label_experiment)

    base = alt.Chart(melt_df).encode(
        x=alt.X(
            "trial_size",
            title="Development Set Sample Size",
            scale=alt.Scale(type="log", base=2, domain=[16, 6000]),
        ),
        color=alt.Color(
            "correlation",
            title="Correlation",
            scale=alt.Scale(scheme="category10"),
            sort=color_order,
            legend=alt.Legend(
                symbolOpacity=1,
                symbolType="circle",
                symbolStrokeWidth=3,
                orient="none",
                legendX=570,
                legendY=105,
                fillColor="white",
                strokeColor="gray",
                padding=5,
            ),
        ),
    )
    y_title = "Kendall Rank Correlation"
    # Line: per-sample-size mean of the correlation values.
    line = base.mark_line(opacity=0.7).encode(
        y=alt.Y("mean(value):Q", scale=alt.Scale(zero=False), title=y_title),
    )
    # Band: confidence interval of the same values around that mean.
    band = base.mark_errorband(extent="ci").encode(
        y=alt.Y("value", title=y_title, scale=alt.Scale(zero=False)),
        color=alt.Color("correlation", sort=color_order),
    )
    font_size = 14
    chart = (
        (band + line)
        .properties(width=300, height=170)
        .facet(alt.Column("experiment", title=""))
        .configure_header(titleFontSize=font_size, labelFontSize=font_size)
        .resolve_axis(y="independent")
        .configure(padding=0)
    )
    out_dir = COMMIT_AUTO_FIGS if commit else AUTO_FIG
    save_chart(chart, out_dir / "stability_simulation_corr", filetypes)