in leaderboard/plots.py [0:0]
def plot_irt_acc(filetypes: List[str], commit: bool = False):
    """Plot per-example accuracy against IRT difficulty and save the chart.

    Accuracy of an example is the mean of its ``exact_match`` scores over
    every model found in the SQuAD dev submission predictions.  IRT
    parameters are read from the ``parameters.json`` file of the pyro
    1PL/2PL/3PL "full" runs named in ``conf``.

    Args:
        filetypes: Passed through unchanged to ``save_chart``.
        commit: When true the figure is written under ``COMMIT_AUTO_FIGS``,
            otherwise under ``AUTO_FIG``.  The file stem is ``irt_acc_dist``
            in both cases.
    """
    # The 3PL run is consulted here only for its list of example ids.
    reference_irt = IrtParsed.from_irt_file(
        Path(conf["irt"]["squad"]["dev"]["pyro"]["3PL"]["full"]) / "parameters.json"
    )
    example_ids = reference_irt.example_ids
    predictions = LeaderboardPredictions.parse_file(
        conf["squad"]["submission_predictions"]["dev"]
    )

    # Sum exact-match scores per example across models, then normalise.
    # A defaultdict is kept so that later lookups of unseen ids yield 0.0.
    example_accuracy = defaultdict(float)
    for per_model in predictions.scored_predictions.values():
        for ex_id in example_ids:
            example_accuracy[ex_id] += per_model["exact_match"][ex_id]
    n_models = len(predictions.scored_predictions)
    for ex_id in example_accuracy:
        example_accuracy[ex_id] /= n_models

    def _rows_for(irt_type):
        # Yield one record per example for the given IRT model variant.
        parsed = IrtParsed.from_irt_file(
            Path(conf["irt"]["squad"]["dev"]["pyro"][irt_type]["full"])
            / "parameters.json"
        )
        for ex_id in example_ids:
            stats = parsed.example_stats[ex_id]
            yield {
                "disc": stats.disc,
                "diff": stats.diff,
                "acc": example_accuracy[ex_id],
                "irt": irt_type,
            }

    rows = [
        row
        for irt_type in ("1PL", "2PL", "3PL")
        for row in _rows_for(irt_type)
    ]
    df = pd.DataFrame(rows)

    chart = alt.hconcat()
    # Rows exist for all three model variants, but only the 2PL panel is
    # drawn; add "1PL" / "3PL" to this tuple to render the other panels.
    for irt_type in ("2PL",):
        has_disc = irt_type != "1PL"

        encodings = {
            "x": alt.X("acc", title="Accuracy"),
            "y": alt.Y("diff", title="IRT Difficulty"),
        }
        if has_disc:
            # Colour is only added for variants other than 1PL.
            encodings["color"] = alt.Color(
                "disc",
                title="IRT Discriminability",
                scale=alt.Scale(scheme="cividis"),
                legend=alt.Legend(orient="top", fillColor="white"),
            )

        title = irt_type if has_disc else "1PL (No Discriminability Parameter)"
        panel = (
            alt.Chart(df)
            .mark_point()
            .encode(**encodings)
            .transform_filter(datum.irt == irt_type)
            .properties(title=f"IRT Model: {title}")
        )
        chart |= panel

    chart = chart.resolve_scale(color="independent")

    out_dir = COMMIT_AUTO_FIGS if commit else AUTO_FIG
    save_chart(chart, out_dir / "irt_acc_dist", filetypes)