in leaderboard/plots.py [0:0]
def sample_stability_plot(filetypes: List[str], commit: bool = False):
    """Plot dev-set sample size against correlation to the test-set rank.

    Reads ``random_df.json``, ``irt_df.json`` and ``info_df.json`` from the
    directory ``conf["stability"]["sampling"]``. Each frame is expected to
    have ``variable`` (a sampling-method key from ``method_names`` below),
    ``trial_size`` and ``value`` columns.

    Random sampling is drawn as a ``mean(value)`` line plus a
    confidence-interval error band. The IRT- and information-based methods
    (``irt_df`` + ``info_df``) are drawn as plain lines of ``value``. The x
    axis is log-scaled.

    Args:
        filetypes: Output formats, passed through to ``save_chart``.
        commit: If True, save under ``COMMIT_AUTO_FIGS``; otherwise under
            ``AUTO_FIG``. The figure name is ``sampling_rank`` in both cases.
    """
    input_dir = Path(conf["stability"]["sampling"])
    random_df = pd.read_json(input_dir / "random_df.json")
    irt_df = pd.read_json(input_dir / "irt_df.json")
    info_df = pd.read_json(input_dir / "info_df.json")

    method_names = {
        "dev_high_disc_to_test": "High Discrimination",
        "dev_high_diff_to_test": "High Difficulty",
        "dev_high_disc_diff_to_test": "High Disc + Diff",
        "dev_info_to_test": "High Information",
        "dev_random_to_test": "Random",
    }

    def format_df(dataframe):
        # Attach a human-readable method label. Indexing the dict (rather than
        # passing it to .map) makes an unknown key raise KeyError instead of
        # silently producing NaN labels.
        return dataframe.assign(
            sampling_method=dataframe["variable"].map(lambda v: method_names[v])
        )

    x_enc = alt.X(
        "trial_size",
        title="Development Set Sample Size",
        scale=alt.Scale(type="log"),
    )
    # Don't force the correlation axis to include 0.
    y_scale = alt.Scale(zero=False)
    color_enc = alt.Color(
        "sampling_method",
        title="Sampling Method",
        legend=alt.Legend(
            orient="bottom-right", fillColor="white", padding=5, strokeColor="gray"
        ),
        sort=[
            "High Disc + Diff",
            "High Information",
            "High Discrimination",
            "High Difficulty",
            "Random",
        ],
    )

    # Both random-sampling layers use the same data; format it once.
    random_plot_df = format_df(random_df)
    random_line = (
        alt.Chart(random_plot_df)
        .mark_line()
        .encode(
            x=x_enc,
            y=alt.Y("mean(value)", scale=y_scale, title="Correlation to Test Rank"),
            color=color_enc,
        )
    )
    random_band = (
        alt.Chart(random_plot_df)
        .mark_errorband(extent="ci")
        .encode(x=x_enc, y=alt.Y("value", title="", scale=y_scale), color=color_enc)
    )

    # IRT- and information-based selections, drawn as plain lines.
    determ_df = pd.concat([irt_df, info_df])
    determ_line = (
        alt.Chart(format_df(determ_df))
        .mark_line()
        .encode(x=x_enc, y=alt.Y("value", title="", scale=y_scale), color=color_enc)
    )

    font_size = 18
    chart = (
        (random_band + random_line + determ_line)
        # configure() assigns a brand-new Config to the chart, so it must run
        # BEFORE the configure_* calls. When it ran last, it discarded all of
        # the axis/legend/header font settings below.
        .configure(padding=0)
        .configure_axis(labelFontSize=font_size, titleFontSize=font_size)
        .configure_legend(
            labelFontSize=font_size,
            titleFontSize=font_size,
            symbolLimit=0,
            labelLimit=0,
        )
        .configure_header(labelFontSize=font_size)
    )

    out_dir = COMMIT_AUTO_FIGS if commit else AUTO_FIG
    save_chart(chart, out_dir / "sampling_rank", filetypes)