in leaderboard/plots.py [0:0]
# Raw ablation feature-set keys that are deliberately left out of the plot.
_ABLATION_SKIPPED_FEATURES = frozenset(
    {
        "guids+stats",
        "guids+qwords",
        "topics_10",
        "topics_50",
        "topics_500",
        "topics_100",
        "LM - Title",
    }
)

# Display names for raw ablation feature-set keys; unlisted keys are capitalized.
_ABLATION_FEATURE_LABELS = {
    "irt": "IRT",
    "guids": "Subj & Item ID",
    "ex_id": "Item ID",
    "m_id": "Subject ID",
    "qwords": "Question",
    "cwords": "Context",
    "topics_1000": "Topics 1K",
    # Only reachable if "topics_100" is removed from _ABLATION_SKIPPED_FEATURES.
    "topics_100": "Topics 100",
}

# Display names for IRT model types; unlisted types are shown unchanged.
_IRT_MODEL_LABELS = {"1PL": "Base", "2PL": "Disc", "3PL": "Feas", "multidim": "Vec"}

# Left-to-right panel order; only metrics present in the data get a panel.
_ABLATION_METRIC_ORDER = [
    "ROC AUC",
    "Macro F1",
    "Macro Precision",
    "Macro Recall",
    "Accuracy",
]

# Legend order for the "features" colour channel. Entries must match the
# generated labels exactly; labels not listed here are drawn after these.
_ABLATION_LABEL_ORDER = [
    "IRT-Vec",
    "IRT-Feas",
    "IRT-Disc",
    "IRT-Base",
    "LM All",
    "LM +IRT",
    "LM +Subj & Item ID",  # was "...IDs", which never matched the "guids" label
    "LM +Item ID",
    "LM +Subject ID",
    "LM +Question",
    "LM +Context",
    "LM +Stats",
]

# Identifier columns kept when melting metric rows into long form.
_ABLATION_ID_VARS = ["features", "irt"]


def _ablation_metrics(report: dict) -> dict:
    """Pull the plotted metrics out of one parsed evaluation report.

    Reads ``report["roc_auc"]`` plus ``macro avg`` F1 and ``accuracy`` from
    ``report["classification_report"]``; a missing key raises ``KeyError``.
    """
    classification = report["classification_report"]
    return {
        "ROC AUC": report["roc_auc"],
        "Macro F1": classification["macro avg"]["f1-score"],
        "Accuracy": classification["accuracy"],
    }


def _ablation_lm_label(feature_set: str) -> str:
    """Return the legend label for a raw LM ablation feature-set key."""
    label = _ABLATION_FEATURE_LABELS.get(feature_set, feature_set.capitalize())
    return "LM All" if label == "All" else f"LM +{label}"


def _ablation_long_df(rows: List[dict]) -> pd.DataFrame:
    """Melt metric rows into ``features``/``irt``/``metric``/``value`` long form."""
    # Explicit columns keep melt() working even when ``rows`` is empty.
    columns = _ABLATION_ID_VARS + ["ROC AUC", "Macro F1", "Accuracy"]
    return pd.DataFrame(rows, columns=columns).melt(
        id_vars=_ABLATION_ID_VARS, var_name="metric"
    )


def ablation_plot(filetypes: List[str], commit: bool = False):
    """Plot LM feature-ablation metrics next to the held-out IRT baselines.

    Reads each report in ``ABLATION_FILES`` whose feature set is not skipped,
    and each ``"heldout"`` report in ``IRT_FILES``. Draws one bar panel per
    metric (ROC AUC, Macro F1, Accuracy), with bars coloured by feature label.

    Args:
        filetypes: passed through unchanged to ``save_chart``.
        commit: if True, save under ``COMMIT_AUTO_FIGS``; otherwise under
            ``AUTO_FIG``. The file stem is ``vw_ablation`` in both cases.

    Raises:
        ValueError: if any (features, irt, metric) combination does not hold
            exactly one value, e.g. two reports map to the same label.
    """
    lm_rows = []
    for (_, irt_type, feature_set), report_path in ABLATION_FILES.items():
        # Skip before reading so excluded reports need not exist on disk.
        if feature_set in _ABLATION_SKIPPED_FEATURES:
            continue
        lm_rows.append(
            {
                "features": _ablation_lm_label(feature_set),
                "irt": irt_type,
                **_ablation_metrics(read_json(report_path)),
            }
        )

    # NOTE(review): this writes into the module-level IRT_FILES dict (kept from
    # the original), so the multidim entry persists after this call returns.
    IRT_FILES[("multidim", "heldout")] = "data/irt/squad/dev/pyro/multidim_10d_heldout/report.json"
    irt_rows = []
    for (model_type, eval_type), path in IRT_FILES.items():
        # Only held-out evaluations are plotted, so don't read the others.
        if eval_type != "heldout":
            continue
        label = _IRT_MODEL_LABELS.get(model_type, model_type)
        irt_rows.append(
            {
                "features": f"IRT-{label}",
                "irt": label,
                **_ablation_metrics(read_json(path)),
            }
        )

    df = pd.concat([_ablation_long_df(irt_rows), _ablation_long_df(lm_rows)])

    # Each (features, irt, metric) cell must hold exactly one bar; anything
    # else means reports collided on a label (or there was no data at all).
    group_sizes = df.groupby(["features", "irt", "metric"]).count()["value"].unique()
    if len(group_sizes) != 1 or 1 not in group_sizes:
        raise ValueError(f"Bad group sizes: {group_sizes}")

    base = (
        alt.Chart(df)
        .mark_bar()
        .encode(
            color=alt.Color(
                "features",
                title="Features",
                scale=alt.Scale(scheme="category20"),
                sort=_ABLATION_LABEL_ORDER,
                legend=alt.Legend(symbolLimit=0, labelLimit=0),
            )
        )
    )
    metric_names = sorted(df["metric"].unique(), key=_ABLATION_METRIC_ORDER.index)
    # One panel per metric; x labels are hidden because the legend names the bars.
    panels = [
        base.encode(
            x=alt.X(
                "features",
                title="",
                axis=alt.Axis(labelAngle=-45, labels=False),
                sort="-y",
            ),
            y=alt.Y("value", title="", scale=alt.Scale(domain=(0, 1))),
        )
        .transform_filter(datum.metric == metric)
        .properties(title=metric)
        for metric in metric_names
    ]

    font_size = 25
    chart = (
        alt.hconcat(*panels)
        .configure_legend(columns=2, labelFontSize=font_size, titleFontSize=font_size)
        .configure_axis(labelFontSize=font_size, titleFontSize=font_size)
        .configure_header(labelFontSize=font_size, titleFontSize=font_size)
        .configure_title(fontSize=font_size)
    )
    out_dir = COMMIT_AUTO_FIGS if commit else AUTO_FIG
    save_chart(chart, out_dir / "vw_ablation", filetypes)