def create_irt_dist_chart()

in leaderboard/visualize.py [0:0]


def create_irt_dist_chart(irt_results: IrtParsed):
    """Plot per-item IRT parameters as a scatter with marginal histograms.

    Each entry of ``irt_results.example_stats`` is joined with its SQuAD
    record (from ``load_cached_squad()``) and its average EM / F1 (from
    ``get_item_accuracy()``), all looked up by the entry's ``example_id``.

    The resulting figure has three panels:

    * a central scatter of ``diff`` (x) against ``disc`` (y), colored by
      ``lambda`` and carrying a tooltip with the item's details;
    * a histogram of ``diff`` stacked above the scatter;
    * a histogram of ``disc`` placed to the right of the scatter.

    Both histograms reuse the scatter's axis scales, whose domains are the
    observed min/max rounded outward to whole numbers.

    Args:
        irt_results: Parsed IRT output. Every value of its ``example_stats``
            mapping must expose ``example_id``, ``disc``, ``diff`` and
            ``lambda_``.

    Returns:
        The concatenated Altair chart, configured with a concat spacing
        of 10.
    """
    BASE_SIZE = 300
    squad = load_cached_squad()
    item_em, item_f1 = get_item_accuracy()

    def _as_row(stat):
        # One flat record per item; key order fixes the DataFrame column order.
        item_id = stat.example_id
        squad_item = squad[item_id]
        return {
            "disc": stat.disc,
            "diff": stat.diff,
            "lambda": stat.lambda_,
            "avg EM": item_em[item_id],
            "avg F1": item_f1[item_id],
            "item_id": item_id,
            "title": squad_item["title"],
            "question": squad_item["text"],
            "is_impossible": squad_item["is_impossible"],
            "answer": squad_item["answers"],
            "context": squad_item["context"],
        }

    def _feasibility_bucket(value):
        # Thresholds split the feasibility value into three coarse labels.
        if value < 0.33:
            return "Low"
        if value < 0.66:
            return "Mid"
        return "High"

    frame = pd.DataFrame([_as_row(stat) for stat in irt_results.example_stats.values()])
    # "feas_bin" (like "context") is not referenced by any encoding below, but
    # it is kept as a column of the DataFrame handed to Altair.
    frame["feas_bin"] = frame["lambda"].map(_feasibility_bucket)

    def _whole_number_scale(column):
        # Round the observed range outward so the axis starts and ends on integers.
        lower = np.floor(frame[column].min())
        upper = np.ceil(frame[column].max())
        return alt.Scale(domain=(lower, upper))

    diff_scale = _whole_number_scale("diff")
    disc_scale = _whole_number_scale("disc")

    # Side length of the square scatter; the histograms match it on their long edge.
    ratio = 1.5
    side = ratio * BASE_SIZE
    margin_thickness = 40

    tooltip_fields = [
        "item_id",
        "diff",
        "disc",
        "lambda",
        "avg EM",
        "avg F1",
        "title",
        "question",
        "is_impossible",
        "answer",
    ]

    base = alt.Chart(frame)

    # NOTE(review): the x-axis title pairs difficulty with the symbol 𝜃, which
    # IRT notation often reserves for subject ability — confirm this is intended.
    x_difficulty = alt.X("diff", title="Difficulty (𝜃)", scale=diff_scale)
    y_discriminability = alt.Y("disc", title="Discriminability (𝛾)", scale=disc_scale)
    feasibility_color = alt.Color(
        "lambda", title="Feasibility (λ)", scale=alt.Scale(scheme="redyellowblue"),
    )
    scatter = (
        base.mark_point()
        .encode(
            x=x_difficulty,
            y=y_discriminability,
            color=feasibility_color,
            tooltip=alt.Tooltip(tooltip_fields),
        )
        .properties(width=side, height=side)
    )

    # Marginal distribution of difficulty, drawn above the scatter.
    difficulty_hist = (
        base.mark_area()
        .encode(
            x=alt.X("diff", bin=alt.Bin(maxbins=50), stack=None, title="", scale=diff_scale),
            y=alt.Y("count()", stack=True, title=""),
        )
        .properties(height=margin_thickness, width=side)
    )

    # Marginal distribution of discriminability, drawn to the right of the scatter.
    discriminability_hist = (
        base.mark_area()
        .encode(
            x=alt.X("count()", stack=True, title=""),
            y=alt.Y("disc", bin=alt.Bin(maxbins=50), stack=None, title="", scale=disc_scale),
        )
        .properties(width=margin_thickness, height=side)
    )

    layout = difficulty_hist & (scatter | discriminability_hist)
    return layout.configure_concat(spacing=10)