def create_irt_dist_chart()

in leaderboard/plots.py [0:0]


def _assign_feas_bin(feas):
    """Bucket a feasibility (lambda) value into "Low", "Mid" or "High".

    Values below 0.33 are "Low", values below 0.66 are "Mid", and everything
    else is "High". A NaN compares false against both thresholds and therefore
    falls through to "High".
    """
    if feas < 0.33:
        return "Low"
    if feas < 0.66:
        return "Mid"
    return "High"


def _irt_joint_chart(df):
    """Build the difficulty-vs-discriminability scatter with marginal histograms.

    Args:
        df: DataFrame with at least the columns "diff", "disc", "lambda" and "item_id".

    Returns:
        A vertically concatenated Altair chart: a histogram of "diff" on top, and
        below it the scatter plot next to a histogram of "disc".
    """
    # Round the axis extents outward to whole numbers. The scatter and its
    # marginal histograms share these scales so their axes line up exactly.
    diff_scale = alt.Scale(domain=(np.floor(df["diff"].min()), np.ceil(df["diff"].max())))
    disc_scale = alt.Scale(domain=(np.floor(df["disc"].min()), np.ceil(df["disc"].max())))

    ratio = 1.5
    side = ratio * BASE_SIZE

    points = (
        alt.Chart(df)
        .mark_point(filled=True)
        .encode(
            x=alt.X("diff", title="Difficulty (𝜃)", scale=diff_scale),
            y=alt.Y("disc", title="Discriminability (𝛾)", scale=disc_scale),
            color=alt.Color(
                "lambda",
                title="Feasibility (λ)",
                scale=alt.Scale(scheme="redyellowblue"),
                # The legend is placed at a fixed pixel offset. NOTE(review): this
                # offset does not scale with BASE_SIZE/ratio.
                legend=alt.Legend(
                    direction="horizontal",
                    orient="none",
                    legendX=240,
                    legendY=0,
                    gradientLength=80,
                ),
            ),
            size=alt.value(3),
            tooltip=alt.Tooltip(["item_id", "diff", "disc", "lambda"]),
        )
    ).properties(width=side, height=side)

    # Marginal distribution of difficulty, aligned with the scatter's x-axis.
    top_hist = (
        alt.Chart(df)
        .mark_area()
        .encode(
            x=alt.X("diff", bin=alt.Bin(maxbins=50), stack=None, title="", scale=diff_scale),
            y=alt.Y("count()", stack=True, title=""),
        )
        .properties(height=40, width=side)
    )

    # Marginal distribution of discriminability, aligned with the scatter's y-axis.
    right_hist = (
        alt.Chart(df)
        .mark_area()
        .encode(
            x=alt.X("count()", stack=True, title=""),
            y=alt.Y("disc", bin=alt.Bin(maxbins=50), stack=None, title="", scale=disc_scale),
        )
    ).properties(width=40, height=side)

    chart = top_hist & (points | right_hist)
    return chart.configure_concat(spacing=10)


def _lambda_dist_chart(df):
    """Build a histogram of feasibility (lambda) with count and percentage axes.

    Args:
        df: DataFrame with at least a "lambda" column.

    Returns:
        An Altair layer chart. The left axis shows the count per bin and the right
        axis shows the fraction of all items per bin. Both y scales are logarithmic
        and resolved independently.
    """
    base = (
        alt.Chart(df)
        # Attach the total row count to every row, so that summing 1/total
        # within a bin gives that bin's fraction of all items.
        .transform_joinaggregate(total="count(*)")
        .transform_calculate(pct="1 / datum.total")
        .encode(
            x=alt.X("lambda", title="Probability of Feasibility (λ)", bin=alt.Bin(maxbins=49),),
        )
    )
    counts = base.mark_bar().encode(
        y=alt.Y(
            "count()", title="Count", scale=alt.Scale(type="log"), axis=alt.Axis(orient="left"),
        )
    )
    pcts = base.mark_bar().encode(
        y=alt.Y(
            "sum(pct):Q",
            title="Percentage",
            scale=alt.Scale(type="log"),
            axis=alt.Axis(orient="right", format="%"),
        )
    )
    return alt.layer(counts, pcts).resolve_scale(y="independent")


def create_irt_dist_chart(irt_results: IrtParsed):
    """Build the IRT parameter distribution charts.

    Args:
        irt_results: Parsed IRT output. Each value of ``irt_results.example_stats``
            must expose ``disc``, ``diff``, ``lambda_`` and ``example_id``.

    Returns:
        A ``(chart, lambda_dist_chart)`` tuple:

        * ``chart`` is the difficulty-vs-discriminability scatter, colored by
          feasibility, with marginal histograms on the top and right.
        * ``lambda_dist_chart`` is a histogram of feasibility with count and
          percentage axes.

    Raises:
        ValueError: if ``irt_results.example_stats`` is empty (there is nothing to plot).
    """
    rows = [
        {
            "disc": example.disc,
            "diff": example.diff,
            "lambda": example.lambda_,
            "item_id": example.example_id,
        }
        for example in irt_results.example_stats.values()
    ]
    if not rows:
        # Without this check, an empty DataFrame has no columns and the
        # column lookups below fail with an opaque KeyError.
        raise ValueError("irt_results.example_stats is empty; nothing to plot")

    df = pd.DataFrame(rows)
    # No encoding in this module uses "feas_bin", but it is kept in the data
    # that both charts embed. NOTE(review): confirm whether downstream consumers
    # rely on it.
    df["feas_bin"] = df["lambda"].map(_assign_feas_bin)

    return _irt_joint_chart(df), _lambda_dist_chart(df)