def ablation_plot()

in leaderboard/plots.py [0:0]


# Ablation feature sets that are excluded from the ablation figure.
_ABLATION_SKIPPED_FEATURE_SETS = frozenset(
    {
        "guids+stats",
        "guids+qwords",
        "topics_10",
        "topics_50",
        "topics_100",
        "topics_500",
        "LM - Title",
    }
)

# Display names for ablation feature sets; names not listed here are
# shown capitalized (e.g. "all" -> "All", "stats" -> "Stats").
_ABLATION_FEATURE_LABELS = {
    "irt": "IRT",
    "guids": "Subj & Item ID",
    "ex_id": "Item ID",
    "m_id": "Subject ID",
    "qwords": "Question",
    "cwords": "Context",
    "topics_1000": "Topics 1K",
}

# Display names for IRT model types; names not listed here are kept verbatim.
_ABLATION_IRT_MODEL_LABELS = {
    "1PL": "Base",
    "2PL": "Disc",
    "3PL": "Feas",
    "multidim": "Vec",
}

# Left-to-right order of the per-metric panels.
_ABLATION_METRIC_ORDER = [
    "ROC AUC",
    "Macro F1",
    "Macro Precision",
    "Macro Recall",
    "Accuracy",
]

# Legend order. Entries must match the "features" strings built below exactly;
# values not listed are placed after the listed ones by the sort.
_ABLATION_LABEL_ORDER = [
    "IRT-Vec",
    "IRT-Feas",
    "IRT-Disc",
    "IRT-Base",
    "LM All",
    "LM +IRT",
    "LM +Subj & Item ID",
    "LM +Item ID",
    "LM +Subject ID",
    "LM +Question",
    "LM +Context",
    "LM +Stats",
    "LM +Topics 1K",
]


def _ablation_report_metrics(report) -> dict:
    """Return the plotted summary metrics from a parsed evaluation report.

    Expects ``report`` to carry ``roc_auc`` and a ``classification_report``
    with a ``macro avg`` entry and an ``accuracy`` entry.
    """
    classification = report["classification_report"]
    return {
        "ROC AUC": report["roc_auc"],
        "Macro F1": classification["macro avg"]["f1-score"],
        "Accuracy": classification["accuracy"],
    }


def _ablation_melt(rows) -> pd.DataFrame:
    """Reshape wide metric rows into long (features, irt, metric, value) form."""
    return pd.DataFrame(rows).melt(id_vars=["features", "irt"], var_name="metric")


def _ablation_lm_rows() -> list:
    """Build one wide metric row per plotted LM feature-ablation run."""
    rows = []
    # The first key element is not used by this plot.
    for (_, irt_type, feature_set), report_path in ABLATION_FILES.items():
        # Skip before reading so excluded runs need no report file on disk.
        if feature_set in _ABLATION_SKIPPED_FEATURE_SETS:
            continue
        label = _ABLATION_FEATURE_LABELS.get(feature_set, feature_set.capitalize())
        name = "LM All" if label == "All" else f"LM +{label}"
        report = read_json(report_path)
        rows.append({"features": name, "irt": irt_type, **_ablation_report_metrics(report)})
    return rows


def _ablation_irt_heldout_rows() -> list:
    """Build one wide metric row per IRT model evaluated on the held-out split."""
    rows = []
    for (model_type, eval_type), path in IRT_FILES.items():
        # Only held-out evaluations are plotted; avoid reading the others.
        if eval_type != "heldout":
            continue
        label = _ABLATION_IRT_MODEL_LABELS.get(model_type, model_type)
        report = read_json(path)
        rows.append(
            {"features": f"IRT-{label}", "irt": label, **_ablation_report_metrics(report)}
        )
    return rows


def ablation_plot(filetypes: List[str], commit: bool = False):
    """Plot LM feature-ablation and held-out IRT metrics as bar-chart panels.

    One panel is drawn per metric (ROC AUC, Macro F1, Accuracy). Each bar is
    one feature set / IRT model, colored by its label. The figure is written
    via ``save_chart`` under the name ``vw_ablation``.

    Args:
        filetypes: Forwarded unchanged to ``save_chart`` (presumably the
            output file formats — confirm against ``save_chart``).
        commit: If True, save under ``COMMIT_AUTO_FIGS``; otherwise under
            ``AUTO_FIG``.

    Raises:
        ValueError: If any (features, irt, metric) combination does not have
            exactly one value, e.g. two reports mapping to the same label.
    """
    lm_df = _ablation_melt(_ablation_lm_rows())

    # NOTE(review): this mutates the module-level IRT_FILES mapping (kept for
    # backward compatibility with any code relying on the entry afterwards).
    IRT_FILES[("multidim", "heldout")] = "data/irt/squad/dev/pyro/multidim_10d_heldout/report.json"
    irt_df = _ablation_melt(_ablation_irt_heldout_rows())

    df = pd.concat([irt_df, lm_df])

    # Every bar must come from exactly one report value.
    group_sizes = df.groupby(["features", "irt", "metric"]).count().value.unique()
    if len(group_sizes) != 1 or 1 not in group_sizes:
        raise ValueError(f"Bad group sizes: {group_sizes}")

    base = (
        alt.Chart(df)
        .mark_bar()
        .encode(
            color=alt.Color(
                "features",
                title="Features",
                scale=alt.Scale(scheme="category20"),
                sort=_ABLATION_LABEL_ORDER,
                legend=alt.Legend(symbolLimit=0, labelLimit=0),
            )
        )
    )

    chart = alt.hconcat()
    for metric in sorted(df.metric.unique(), key=_ABLATION_METRIC_ORDER.index):
        chart |= (
            base.encode(
                # Bars are sorted by descending value; x labels are hidden and
                # the legend identifies each bar instead.
                x=alt.X(
                    "features", title="", axis=alt.Axis(labelAngle=-45, labels=False), sort="-y",
                ),
                y=alt.Y("value", title="", scale=alt.Scale(domain=(0, 1))),
            )
            .transform_filter(datum.metric == metric)
            .properties(title=metric)
        )

    font_size = 25
    chart = (
        chart.configure_legend(columns=2, labelFontSize=font_size, titleFontSize=font_size)
        .configure_axis(labelFontSize=font_size, titleFontSize=font_size)
        .configure_header(labelFontSize=font_size, titleFontSize=font_size)
        .configure_title(fontSize=font_size)
    )
    out_dir = COMMIT_AUTO_FIGS if commit else AUTO_FIG
    save_chart(chart, out_dir / "vw_ablation", filetypes)