in leaderboard/plots.py [0:0]
def confusion_plot(filetypes: List[str], commit: bool = False, irt_model: str = "3PL") -> None:
    """Build and save the ``irt_confusion`` figure.

    The figure shows, for a hand-picked set of SQuAD submissions, the mean dev
    ``response`` within each item-difficulty bucket and each
    item-discriminability bucket.  Buckets are delimited by the 25/50/75/100%
    quantiles of the dev IRT item parameters.  A narrow column on the left
    tags each submission with a description label and its test EM value.

    Args:
        filetypes: Forwarded unchanged to ``save_chart``.
        commit: When true the chart is saved under ``COMMIT_AUTO_FIGS``,
            otherwise under ``AUTO_FIG``.
        irt_model: Key selecting the IRT model entry under
            ``conf["irt"]["squad"][<split>]["pyro"]``.
    """

    def format_name(name):
        """Return the portion of ``name`` before its first "("."""
        return name.split("(")[0]

    def quartile_cutoffs(values):
        """Map each bucket label to its upper quantile cutoff for ``values``."""
        return {
            "Low": values.quantile(0.25),
            "Med-Low": values.quantile(0.5),
            "Med-High": values.quantile(0.75),
            "High": values.quantile(1),
        }

    # --- Load IRT parameters, predictions and the dev->test subject mapping ---
    dev_irt_params = IrtParsed.from_irt_file(
        Path(conf["irt"]["squad"]["dev"]["pyro"][irt_model]["full"]) / "parameters.json"
    )
    dev_predictions = LeaderboardPredictions.parse_file(
        conf["squad"]["submission_predictions"]["dev"]
    )
    dev_id_to_subject = load_squad_submissions(dev_predictions)
    dev_to_test = read_json(conf["squad"]["dev_to_test"])["dev_to_test"]
    test_irt_params = IrtParsed.from_irt_file(
        Path(conf["irt"]["squad"]["test"]["pyro"][irt_model]["full"]) / "parameters.json"
    )
    # Only the per-subject stats lookup is needed here; the subject frame is unused.
    _, id_to_subject_stats = create_subject_df(
        dev_id_to_subject=dev_id_to_subject,
        dev_irt_params=dev_irt_params,
        dev_to_test=dev_to_test,
        test_irt_params=test_irt_params,
    )

    # --- Bucket items by difficulty / discriminability quantiles ---
    item_df = create_item_df(dev_irt_params)
    df = create_confusion_df(
        dev_predictions=dev_predictions,
        dev_irt_params=dev_irt_params,
        diff_quantiles=quartile_cutoffs(item_df["diff"]),
        disc_quantiles=quartile_cutoffs(item_df["disc"]),
    )

    # GroupBy.mean's first positional parameter is ``numeric_only``; spell it
    # out instead of passing a column name that is merely interpreted as truthy.
    by_diff = df.groupby(["subject_id", "diff_cat"]).mean(numeric_only=True).reset_index()
    by_diff["category"] = by_diff["diff_cat"]
    by_diff["parameter"] = "Difficulty"
    by_disc = df.groupby(["subject_id", "disc_cat"]).mean(numeric_only=True).reset_index()
    by_disc["category"] = by_disc["disc_cat"]
    by_disc["parameter"] = "Discriminability"
    combined = pd.concat([by_diff, by_disc])

    combined["name"] = combined["subject_id"].map(
        lambda sid: format_name(id_to_subject_stats[sid]["name"])
    )
    combined["test_em"] = combined["subject_id"].map(
        lambda sid: id_to_subject_stats[sid]["test_em"]
    )
    # ``response`` is scaled by 100 here to get a whole-number percentage.
    combined["percent"] = combined["response"].map(lambda p: round(100 * p))

    # --- Submissions shown in the figure, with their description labels ---
    selected_subject_ids = {
        # overfitting dev_em
        "0xa5265f8dbc424109a4573494c113235d": ["Overfit"],
        # top test skill
        "0x8978eb3bd032447a80f27b2b82ad3b80": ["Top Model"],
        # lowest dev_skill
        "0x8a3b01f4ded748df8b657684212372b4": ["Bottom Model"],
        "0xeb6fe173849a495b83eb4e56b172e02a": ["Hallmark"],
        "0x843f0d9f242f46b9803558614bff2f86": ["Hallmark"],
    }
    combined_filtered = combined[combined["subject_id"].isin(list(selected_subject_ids))]

    # One row per (subject, label); ``n`` is a constant used as the bar length.
    # NOTE(review): ``test_em`` is rounded without the x100 scaling applied to
    # ``response`` above -- presumably it is already on a 0-100 scale; confirm.
    label_rows = [
        {
            "subject_id": sid,
            "name": format_name(id_to_subject_stats[sid]["name"]),
            "label": label,
            "n": 1,
            "test_em": id_to_subject_stats[sid]["test_em"],
            "test_em_percent": round(id_to_subject_stats[sid]["test_em"]),
        }
        for sid, labels in selected_subject_ids.items()
        for label in labels
    ]
    label_df = pd.DataFrame(label_rows)
    # Rows of every sub-chart are ordered by descending test EM.
    subject_order = label_df.sort_values("test_em", ascending=False).name.tolist()

    # --- Left column: coloured description bar with the test EM printed on it ---
    label_bars = (
        alt.Chart(label_df)
        .mark_bar()
        .encode(
            x=alt.X(
                "n",
                title="",
                axis=alt.Axis(labels=False, grid=False, ticks=False, domain=False),
            ),
            y=alt.Y(
                "name",
                title="Name",
                sort=subject_order,
                axis=alt.Axis(labels=True, ticks=True, orient="left"),
            ),
            color=alt.Color(
                "label",
                title="Description",
                legend=alt.Legend(orient="left", offset=0),
                scale=alt.Scale(scheme="set2"),
            ),
        )
    )
    label_text = (
        alt.Chart(label_df)
        .mark_text(baseline="middle", dx=-12, color="black")
        .encode(
            x=alt.X("n", title="Test Acc", axis=alt.Axis(orient="top")),
            y=alt.Y("name", sort=subject_order),
            text=alt.Text("test_em_percent"),
        )
    )
    label_chart = (label_bars + label_text).properties(width=25)

    # --- Heatmap of dev accuracy per category, with the percentage overlaid ---
    # NOTE(review): ``order`` is not defined inside this function; presumably a
    # module-level list giving the category ordering -- confirm it exists.
    heatmap = (
        alt.Chart(combined_filtered)
        .mark_rect(stroke="black", strokeWidth=0.1)
        .encode(
            x=alt.X("category", title="", sort=order, axis=alt.Axis(labelAngle=-35)),
            y=alt.Y(
                "name",
                title="",
                sort=subject_order,
                axis=alt.Axis(titlePadding=10, labels=False),
                scale=alt.Scale(paddingInner=0.2),
            ),
            color=alt.Color(
                "response",
                title="Dev Acc",
                scale=alt.Scale(scheme="magma"),
                legend=alt.Legend(offset=0),
            ),
        )
    )
    heatmap_text = (
        alt.Chart(combined_filtered)
        .mark_text(baseline="middle")
        .encode(
            x=alt.X("category", sort=order),
            y=alt.Y("name", sort=subject_order),
            text=alt.Text("percent"),
            # Switch text colour above 70% so it stays readable on the colour scale.
            color=alt.condition(alt.datum.percent > 70, alt.value("black"), alt.value("white")),
        )
    )
    layered = heatmap + heatmap_text

    # One heatmap panel per IRT parameter, placed side by side.
    panels = alt.hconcat()
    for param in ["Difficulty", "Discriminability"]:
        panels |= layered.transform_filter(alt.datum.parameter == param).properties(title=param)

    figure = alt.hconcat(label_chart, panels, spacing=0).configure(padding=0)

    output_dir = COMMIT_AUTO_FIGS if commit else AUTO_FIG
    save_chart(figure, output_dir / "irt_confusion", filetypes)