def plot_irt_acc()

in leaderboard/plots.py [0:0]


def plot_irt_acc(filetypes: List[str], commit: bool = False):
    """Plot per-example accuracy against IRT difficulty and save the chart.

    Accuracy of an example is the mean of its ``exact_match`` score over all
    models in the dev submission predictions. Difficulty and discriminability
    are read from the fitted pyro IRT parameter files named in ``conf``.

    Args:
        filetypes: Output formats, forwarded unchanged to ``save_chart``.
        commit: When true the chart is written under ``COMMIT_AUTO_FIGS``,
            otherwise under ``AUTO_FIG``.
    """

    def load_irt_params(model_name: str):
        # Each IRT variant keeps its fitted parameters in "parameters.json"
        # inside the "full" directory configured for that variant.
        param_dir = Path(conf["irt"]["squad"]["dev"]["pyro"][model_name]["full"])
        return IrtParsed.from_irt_file(param_dir / "parameters.json")

    # The 3PL fit supplies the list of example ids used for every variant.
    example_ids = load_irt_params("3PL").example_ids

    predictions = LeaderboardPredictions.parse_file(
        conf["squad"]["submission_predictions"]["dev"]
    )

    # Sum exact-match scores per example across all models ...
    match_totals = defaultdict(float)
    for scores in predictions.scored_predictions.values():
        for example_id in example_ids:
            match_totals[example_id] += scores["exact_match"][example_id]

    # ... then turn the sums into mean accuracies. Kept as a defaultdict so
    # that an example without any accumulated score still reads as 0.0.
    n_models = len(predictions.scored_predictions)
    example_accuracy = defaultdict(
        float,
        {example_id: total / n_models for example_id, total in match_totals.items()},
    )

    # One row per (IRT variant, example); the charts filter on the "irt" column.
    rows = []
    for irt_type in ("1PL", "2PL", "3PL"):
        fitted = load_irt_params(irt_type)
        rows.extend(
            {
                "disc": fitted.example_stats[example_id].disc,
                "diff": fitted.example_stats[example_id].diff,
                "acc": example_accuracy[example_id],
                "irt": irt_type,
            }
            for example_id in example_ids
        )
    df = pd.DataFrame(rows)

    # Only the 2PL panel is drawn at the moment; the 1PL handling below is
    # kept so that ("1PL", "2PL", "3PL") can be restored without other changes.
    plotted_irt_types = ("2PL",)

    chart = alt.hconcat()
    for irt_type in plotted_irt_types:
        is_one_pl = irt_type == "1PL"

        encodings = {
            "x": alt.X("acc", title="Accuracy"),
            "y": alt.Y("diff", title="IRT Difficulty"),
        }
        if not is_one_pl:
            # Discriminability is only colour-encoded for the non-1PL variants.
            encodings["color"] = alt.Color(
                "disc",
                title="IRT Discriminability",
                scale=alt.Scale(scheme="cividis"),
                legend=alt.Legend(orient="top", fillColor="white"),
            )

        title = "1PL (No Discriminability Parameter)" if is_one_pl else irt_type
        scatter = (
            alt.Chart(df)
            .mark_point()
            .encode(**encodings)
            .transform_filter(datum.irt == irt_type)
            .properties(title=f"IRT Model: {title}")
        )
        chart |= scatter

    # Each panel gets its own colour scale rather than a shared one.
    chart = chart.resolve_scale(color="independent")

    output_dir = COMMIT_AUTO_FIGS if commit else AUTO_FIG
    save_chart(chart, output_dir / "irt_acc_dist", filetypes)