in leaderboard/plots.py [0:0]
def _irt_to_latex_cell(num, places):
    r"""Render one table cell as content for a LaTeX ``$...$`` math span.

    String cells (e.g. model / evaluation names) are wrapped in ``\text{...}``
    so they render upright; every other value is formatted with
    ``to_precision(num, places)``.
    """
    if isinstance(num, str):
        return r"\text{" + num + r"}"
    return to_precision(num, places)


def _load_irt_reports() -> pd.DataFrame:
    """Collect the headline metrics of every IRT evaluation report.

    Returns:
        One row per ``(model_type, eval_type)`` entry of ``IRT_FILES`` (in that
        iteration order) with columns ``model``, ``evaluation``, ``ROC AUC``,
        ``Macro F1`` and ``Accuracy``.
    """
    rows = []
    for (model_type, eval_type), path in IRT_FILES.items():
        report = read_json(path)
        # NOTE(review): the "classification_report" layout ("macro avg",
        # "accuracy") looks like sklearn's classification_report(output_dict=True)
        # output — confirm. Further fields (macro precision/recall, weighted
        # averages) can be added here; the chart's facet ordering already lists
        # "Macro Precision" and "Macro Recall".
        cls_report = report["classification_report"]
        rows.append(
            {
                "model": model_type,
                "evaluation": eval_type,
                "ROC AUC": report["roc_auc"],
                "Macro F1": cls_report["macro avg"]["f1-score"],
                "Accuracy": cls_report["accuracy"],
            }
        )
    return pd.DataFrame(rows)


def _irt_heldout_latex_table(report_df: pd.DataFrame) -> str:
    """Return a LaTeX table of the heldout-evaluation metrics, one row per model."""
    heldout = report_df[report_df.evaluation == "heldout"]
    # Every cell (string columns included) becomes math-mode content, so the
    # table is emitted with escape=False to keep the LaTeX markup intact.
    # ``apply`` + ``Series.map`` is an element-wise map that works on all pandas
    # versions (``DataFrame.applymap`` is deprecated since pandas 2.1).
    formatted = heldout.apply(
        lambda col: col.map(lambda n: f"${_irt_to_latex_cell(n, 3)}$")
    )
    return (
        formatted.pivot(index="model", columns="evaluation")
        .reset_index()
        .to_latex(index=False, escape=False)
    )


def _irt_comparison_chart(report_df: pd.DataFrame):
    """Build the bar chart of heldout metrics, one facet column per metric."""
    # Facet (metric) display order; metrics absent from the data are ignored.
    metric_sort_order = [
        "ROC AUC",
        "Macro F1",
        "Macro Precision",
        "Macro Recall",
        "Accuracy",
    ]
    long_df = report_df.melt(id_vars=["model", "evaluation"], var_name="metric")
    heldout_df = long_df[long_df.evaluation == "heldout"]
    font_size = 18
    bars = (
        alt.Chart()
        .mark_bar()
        .encode(
            color=alt.Color(
                "model",
                title="IRT Model",
                scale=alt.Scale(scheme="category10"),
                legend=alt.Legend(orient="top"),
            ),
            # Models are identified by colour/legend, so x labels are hidden;
            # sort=None keeps them in data order within each facet.
            x=alt.X("model", title="", axis=alt.Axis(labels=False), sort=None),
            y=alt.Y(
                "value",
                title="Heldout Metric",
                scale=alt.Scale(zero=False, domain=[0.8, 1]),
            ),
            tooltip="value",
        )
        .properties(width=100, height=150)
    )
    # Value labels at each bar's top (dy=-7 nudges them upward); black
    # overrides the colour encoding inherited from `bars`.
    text = bars.mark_text(align="center", baseline="middle", dy=-7, fontSize=14).encode(
        text=alt.Text("value:Q", format=".2r"), color=alt.value("black")
    )
    return (
        alt.layer(bars, text, data=heldout_df)
        # The metric ordering belongs on the facet: it lists metric names, so it
        # had no effect when applied to the x (model) channel.
        .facet(column=alt.Column("metric", title="", sort=metric_sort_order))
        .configure_axis(labelFontSize=font_size, titleFontSize=font_size)
        .configure_legend(labelFontSize=font_size, titleFontSize=font_size)
        .configure_header(labelFontSize=font_size)
    )


def plot_irt_comparison(filetypes: List[str], commit: bool = False):
    """Compare IRT model variants on their heldout evaluation metrics.

    Prints a LaTeX table of the heldout metrics to stdout, then saves a faceted
    bar chart named ``irt_model_comparison``.

    Args:
        filetypes: File types forwarded to ``save_chart``.
        commit: If True, save under ``COMMIT_AUTO_FIGS``; otherwise under
            ``AUTO_FIG``.
    """
    report_df = _load_irt_reports()
    print(_irt_heldout_latex_table(report_df))
    chart = _irt_comparison_chart(report_df)
    fig_root = COMMIT_AUTO_FIGS if commit else AUTO_FIG
    save_chart(chart, fig_root / "irt_model_comparison", filetypes)