in leaderboard/plots.py [0:0]
def rank_stability_plot(filetypes: List[str], commit: bool = False):
    """Plot dev-to-test Kendall rank correlation against development-set sample size.

    Loads the SQuAD test-set IRT parameters (the pyro / 3PL / "full" run) and the
    leaderboard test-set submission predictions. It builds the rank-stability
    table with ``create_rank_stability_df``. It then draws, faceted by experiment,
    the mean correlation per sample size as a line over a confidence-interval
    error band.

    Args:
        filetypes: Output formats forwarded to ``save_chart``.
        commit: When True, save under ``COMMIT_AUTO_FIGS``; otherwise under
            ``AUTO_FIG``. The file stem is ``stability_simulation_corr`` either way.
    """
    test_irt_parsed = IrtParsed.from_irt_file(
        Path(conf["irt"]["squad"]["test"]["pyro"]["3PL"]["full"]) / "parameters.json"
    )
    test_preds = LeaderboardPredictions.parse_file(
        conf["squad"]["submission_predictions"]["test"]
    )
    dev_to_test = read_json(conf["squad"]["dev_to_test"])["dev_to_test"]
    df = create_rank_stability_df(
        dev_to_test=dev_to_test, test_preds=test_preds, test_irt_parsed=test_irt_parsed
    )

    # Wide-column name -> legend label. Several columns deliberately share a
    # label so they are drawn as the same series.
    names = {
        "abs_irt_corr": "IRT to IRT",
        "classical_corr": "Acc to Acc",
        "test_classical_sample_classical_corr": "Acc to Acc",
        "test_classical_sample_irt_corr": "IRT to Acc",
        "test_irt_sample_classical_corr": "Acc to IRT",
        "test_irt_sample_irt_corr": "IRT to IRT",
    }
    # Explicit ordering of the "correlation" legend / color domain.
    color_order = ["IRT to IRT", "Acc to Acc", "IRT to Acc", "Acc to IRT"]
    # Labels present in the data but left out of this figure.
    excluded = ["IRT to Acc", "Acc to IRT"]

    # Long format: one row per (trial_size, trial_id, variable) with its value.
    # Rows with a missing value are dropped.
    melt_df = (
        df.drop(columns=["irt_corr"])
        .melt(id_vars=["trial_size", "trial_id"])
        .dropna(axis=0)
    )
    console.log(melt_df.head())
    # Index the dict directly (rather than Series.map(names)) so an unexpected
    # column fails loudly with KeyError instead of silently becoming NaN.
    melt_df["correlation"] = melt_df["variable"].map(lambda v: names[v])
    # Explicit copy: the filtered frame is independent of its parent, so the
    # column assignment below never touches a slice flagged by pandas'
    # chained-assignment checks.
    melt_df = melt_df[~melt_df["correlation"].isin(excluded)].copy()
    melt_df["experiment"] = melt_df["variable"].map(label_experiment)

    y_title = "Kendall Rank Correlation"
    base = alt.Chart(melt_df).encode(
        x=alt.X(
            "trial_size",
            title="Development Set Sample Size",
            scale=alt.Scale(type="log", base=2, domain=[16, 6000]),
        ),
        color=alt.Color(
            "correlation",
            title="Correlation",
            scale=alt.Scale(scheme="category10"),
            sort=color_order,
            # orient="none" + legendX/legendY pins the legend at a fixed position.
            legend=alt.Legend(
                symbolOpacity=1,
                symbolType="circle",
                symbolStrokeWidth=3,
                orient="none",
                legendX=570,
                legendY=105,
                fillColor="white",
                strokeColor="gray",
                padding=5,
            ),
        ),
    )
    # Mean correlation per trial size.
    line = base.mark_line(opacity=0.7).encode(
        y=alt.Y("mean(value):Q", scale=alt.Scale(zero=False), title=y_title),
    )
    # Confidence-interval band over the raw per-trial values.
    band = base.mark_errorband(extent="ci").encode(
        y=alt.Y("value", title=y_title, scale=alt.Scale(zero=False)),
        color=alt.Color("correlation", sort=color_order),
    )

    font_size = 14
    chart = (
        (band + line)
        .properties(width=300, height=170)
        .facet(alt.Column("experiment", title=""))
        .configure_header(titleFontSize=font_size, labelFontSize=font_size)
        .resolve_axis(y="independent")  # each experiment facet gets its own y-axis
        .configure(padding=0)
    )

    out_dir = COMMIT_AUTO_FIGS if commit else AUTO_FIG
    save_chart(chart, out_dir / "stability_simulation_corr", filetypes)