def confusion_plot()

in leaderboard/plots.py [0:0]


# Quartile bins for an IRT item parameter as (label, upper quantile) pairs.
# The labels double as the x-axis category order of the confusion heatmap.
_CONFUSION_QUARTILES = (("Low", 0.25), ("Med-Low", 0.5), ("Med-High", 0.75), ("High", 1))


def _quartile_upper_bounds(values: "pd.Series") -> dict:
    """Return ``{label: values.quantile(q)}`` for every bin in ``_CONFUSION_QUARTILES``."""
    return {label: values.quantile(q) for label, q in _CONFUSION_QUARTILES}


def _subject_label_chart(label_df: "pd.DataFrame", subject_order: list):
    """Build the left-hand label column.

    Each selected subject gets one bar coloured by its description and
    annotated with its rounded test EM.
    """
    bars = (
        alt.Chart(label_df)
        .mark_bar()
        .encode(
            x=alt.X(
                "n", title="", axis=alt.Axis(labels=False, grid=False, ticks=False, domain=False),
            ),
            y=alt.Y(
                "name",
                title="Name",
                sort=subject_order,
                axis=alt.Axis(labels=True, ticks=True, orient="left"),
            ),
            color=alt.Color(
                "label",
                title="Description",
                legend=alt.Legend(orient="left", offset=0),
                scale=alt.Scale(scheme="set2"),
            ),
        )
    )
    test_em_text = (
        alt.Chart(label_df)
        .mark_text(baseline="middle", dx=-12, color="black")
        .encode(
            x=alt.X("n", title="Test Acc", axis=alt.Axis(orient="top")),
            y=alt.Y("name", sort=subject_order),
            text=alt.Text("test_em_percent"),
        )
    )
    return (bars + test_em_text).properties(width=25)


def _confusion_heatmap(data: "pd.DataFrame", subject_order: list, category_order: list):
    """Build the heatmap layer.

    Cells are coloured by mean dev ``response`` per (subject, category), and the
    rounded percentage is printed inside each cell.
    """
    rects = (
        alt.Chart(data)
        .mark_rect(stroke="black", strokeWidth=0.1)
        .encode(
            x=alt.X("category", title="", sort=category_order, axis=alt.Axis(labelAngle=-35)),
            y=alt.Y(
                "name",
                title="",
                sort=subject_order,
                axis=alt.Axis(titlePadding=10, labels=False),
                scale=alt.Scale(paddingInner=0.2),
            ),
            color=alt.Color(
                "response",
                title="Dev Acc",
                scale=alt.Scale(scheme="magma"),
                legend=alt.Legend(offset=0),
            ),
        )
    )
    cell_text = (
        alt.Chart(data)
        .mark_text(baseline="middle")
        .encode(
            x=alt.X("category", sort=category_order),
            y=alt.Y("name", sort=subject_order),
            text=alt.Text("percent"),
            # Light cells (high values in "magma") get black text; dark cells get white text.
            color=alt.condition(alt.datum.percent > 70, alt.value("black"), alt.value("white")),
        )
    )
    return rects + cell_text


def confusion_plot(filetypes: List[str], commit: bool = False, irt_model: str = "3PL") -> None:
    """Plot dev accuracy of selected SQuAD submissions, binned by IRT item parameters.

    Dev items are bucketed into quartiles of their IRT difficulty (``diff``) and
    discriminability (``disc``). For a hand-picked set of submissions, the mean
    dev ``response`` in each bucket is drawn as a heatmap, with one panel per
    parameter. A label column beside the panels shows each submission's
    description and its test EM. The chart is saved as ``irt_confusion``.

    Args:
        filetypes: File types forwarded to ``save_chart``.
        commit: If True, save under ``COMMIT_AUTO_FIGS``; otherwise under ``AUTO_FIG``.
        irt_model: Key of the IRT model whose dev/test parameters are read from
            ``conf`` (default ``"3PL"``).
    """
    dev_irt_params = IrtParsed.from_irt_file(
        Path(conf["irt"]["squad"]["dev"]["pyro"][irt_model]["full"]) / "parameters.json"
    )
    dev_predictions = LeaderboardPredictions.parse_file(
        conf["squad"]["submission_predictions"]["dev"]
    )
    dev_id_to_subject = load_squad_submissions(dev_predictions)
    dev_to_test = read_json(conf["squad"]["dev_to_test"])["dev_to_test"]
    test_irt_params = IrtParsed.from_irt_file(
        Path(conf["irt"]["squad"]["test"]["pyro"][irt_model]["full"]) / "parameters.json"
    )
    # Only the per-subject stats (name, test_em) are needed here.
    _, id_to_subject_stats = create_subject_df(
        dev_id_to_subject=dev_id_to_subject,
        dev_irt_params=dev_irt_params,
        dev_to_test=dev_to_test,
        test_irt_params=test_irt_params,
    )
    item_df = create_item_df(dev_irt_params)

    diff_quantiles = _quartile_upper_bounds(item_df["diff"])
    disc_quantiles = _quartile_upper_bounds(item_df["disc"])
    # NOTE(review): assumes create_confusion_df labels items with these quantile
    # keys, so the keys also give the x-axis order — confirm.
    category_order = list(diff_quantiles)
    df = create_confusion_df(
        dev_predictions=dev_predictions,
        dev_irt_params=dev_irt_params,
        diff_quantiles=diff_quantiles,
        disc_quantiles=disc_quantiles,
    )

    # Average the numeric columns (e.g. "response") per subject and bin. Both
    # parameters share the "category" and "parameter" columns so they can be stacked.
    per_parameter = []
    for cat_col, parameter in (("diff_cat", "Difficulty"), ("disc_cat", "Discriminability")):
        grouped = df.groupby(["subject_id", cat_col]).mean(numeric_only=True).reset_index()
        grouped["category"] = grouped[cat_col]
        grouped["parameter"] = parameter
        per_parameter.append(grouped)
    combined = pd.concat(per_parameter)

    def format_name(name):
        # Drop any parenthesised suffix (and the space before it) from a display name.
        return name.split("(", 1)[0].strip()

    combined["name"] = combined["subject_id"].map(
        lambda sid: format_name(id_to_subject_stats[sid]["name"])
    )
    combined["test_em"] = combined["subject_id"].map(
        lambda sid: id_to_subject_stats[sid]["test_em"]
    )
    combined["percent"] = combined["response"].map(lambda p: round(100 * p))

    # Hand-picked submissions to display, mapped to their description label(s).
    selected_subject_ids = {
        "0xa5265f8dbc424109a4573494c113235d": ["Overfit"],  # overfitting dev_em
        "0x8978eb3bd032447a80f27b2b82ad3b80": ["Top Model"],  # top test skill
        "0x8a3b01f4ded748df8b657684212372b4": ["Bottom Model"],  # lowest dev_skill
        "0xeb6fe173849a495b83eb4e56b172e02a": ["Hallmark"],
        "0x843f0d9f242f46b9803558614bff2f86": ["Hallmark"],
    }
    combined_filtered = combined[combined["subject_id"].isin(list(selected_subject_ids))]

    label_rows = []
    for sid, labels in selected_subject_ids.items():
        stats = id_to_subject_stats[sid]
        for label in labels:
            label_rows.append(
                {
                    "subject_id": sid,
                    "name": format_name(stats["name"]),
                    "label": label,
                    "n": 1,
                    "test_em": stats["test_em"],
                    "test_em_percent": round(stats["test_em"]),
                }
            )
    label_df = pd.DataFrame(label_rows)
    # Rows are ordered by descending test EM. De-duplicate so that a subject
    # with several labels appears only once in the sort order.
    subject_order = (
        label_df.sort_values("test_em", ascending=False)["name"].drop_duplicates().tolist()
    )

    heatmap = _confusion_heatmap(combined_filtered, subject_order, category_order)
    panels = alt.hconcat()
    for param in ["Difficulty", "Discriminability"]:
        panels |= heatmap.transform_filter(alt.datum.parameter == param).properties(title=param)
    chart = alt.hconcat(
        _subject_label_chart(label_df, subject_order), panels, spacing=0
    ).configure(padding=0)

    out_dir = COMMIT_AUTO_FIGS if commit else AUTO_FIG
    save_chart(chart, out_dir / "irt_confusion", filetypes)