in leaderboard/visualize.py [0:0]
def create_irt_dist_chart(irt_results: IrtParsed):
    """Build the item-parameter distribution chart for a set of IRT results.

    Every entry of ``irt_results.example_stats`` becomes one point in a
    scatter plot of difficulty (x) against discriminability (y), colored by
    feasibility. A histogram of difficulty sits above the scatter plot and a
    histogram of discriminability sits to its right, each sharing the
    corresponding scatter axis scale.

    Args:
        irt_results: Parsed IRT output; its ``example_stats`` values must
            expose ``example_id``, ``disc``, ``diff`` and ``lambda_``.

    Returns:
        The concatenated Altair chart (top histogram over scatter + right
        histogram) with a concat spacing of 10.
    """
    base_px = 300
    scale_factor = 1.5
    main_side = scale_factor * base_px

    squad_items = load_cached_squad()
    em_by_item, f1_by_item = get_item_accuracy()

    def build_row(stats):
        """Flatten one example's IRT parameters and SQuAD fields into a dict."""
        key = stats.example_id
        squad_entry = squad_items[key]
        return {
            "disc": stats.disc,
            "diff": stats.diff,
            "lambda": stats.lambda_,
            "avg EM": em_by_item[key],
            "avg F1": f1_by_item[key],
            "item_id": key,
            "title": squad_entry["title"],
            "question": squad_entry["text"],
            "is_impossible": squad_entry["is_impossible"],
            "answer": squad_entry["answers"],
            "context": squad_entry["context"],
        }

    def feasibility_bucket(value):
        """Map a feasibility value onto a coarse Low / Mid / High label."""
        if value < 0.33:
            return "Low"
        if value < 0.66:
            return "Mid"
        return "High"

    df = pd.DataFrame(
        [build_row(stats) for stats in irt_results.example_stats.values()]
    )
    df["feas_bin"] = df["lambda"].map(feasibility_bucket)

    def whole_number_scale(column):
        """Scale whose domain is the column's range widened to whole numbers."""
        lower = np.floor(df[column].min())
        upper = np.ceil(df[column].max())
        return alt.Scale(domain=(lower, upper))

    # The same scale objects are reused by the scatter plot and by the
    # marginal histograms so that their axes line up.
    diff_scale = whole_number_scale("diff")
    disc_scale = whole_number_scale("disc")

    tooltip_fields = [
        "item_id",
        "diff",
        "disc",
        "lambda",
        "avg EM",
        "avg F1",
        "title",
        "question",
        "is_impossible",
        "answer",
    ]

    # NOTE(review): IRT texts usually write difficulty as beta and reserve
    # theta for subject ability -- confirm the x-axis symbol is intended.
    scatter = alt.Chart(df).mark_point().encode(
        x=alt.X("diff", title="Difficulty (𝜃)", scale=diff_scale),
        y=alt.Y("disc", title="Discriminability (𝛾)", scale=disc_scale),
        color=alt.Color(
            "lambda",
            title="Feasibility (λ)",
            scale=alt.Scale(scheme="redyellowblue"),
        ),
        tooltip=alt.Tooltip(tooltip_fields),
    )
    scatter = scatter.properties(width=main_side, height=main_side)

    # Marginal distribution of difficulty, drawn above the scatter plot.
    difficulty_hist = alt.Chart(df).mark_area().encode(
        x=alt.X(
            "diff",
            bin=alt.Bin(maxbins=50),
            stack=None,
            title="",
            scale=diff_scale,
        ),
        y=alt.Y("count()", stack=True, title=""),
    )
    difficulty_hist = difficulty_hist.properties(height=40, width=main_side)

    # Marginal distribution of discriminability, drawn right of the scatter.
    discriminability_hist = alt.Chart(df).mark_area().encode(
        x=alt.X("count()", stack=True, title=""),
        y=alt.Y(
            "disc",
            bin=alt.Bin(maxbins=50),
            stack=None,
            title="",
            scale=disc_scale,
        ),
    )
    discriminability_hist = discriminability_hist.properties(
        width=40, height=main_side
    )

    layout = difficulty_hist & (scatter | discriminability_hist)
    return layout.configure_concat(spacing=10)