def plot_summary_comparison()

in scripts/plot_summary_comparison.py [0:0]


def plot_summary_comparison(paths: List[str]) -> None:
    """Given a list of paths to output directories produced by e.g. `scripts/setfit/run_fewshot.py`,
    produce and save boxplots that compare the various results.

    One figure is produced per dataset, with one subplot per result column and,
    inside each subplot, one boxplot per input path. The figures are saved to
    scripts/images/v_{id}/{dataset}.png, where `id` is one higher than the
    largest existing `v_{id}` directory. The command line that was executed is
    stored next to the figures in `command.txt`.

    Args:
        paths (List[str]): List of paths to output directories, generally
            `scripts/{method_name}/results/{model_name}`

    Raises:
        FileExistsError: If the freshly numbered output directory already exists,
            e.g. because another run created it in the meantime.
    """

    # Parse the result paths
    dataset_to_df = defaultdict(pd.DataFrame)
    dataset_to_metric = {}
    for path_index, path in enumerate(paths):
        ds_to_metric, this_dataset_to_df = get_summary_df(path)
        for dataset, df in this_dataset_to_df.items():
            # Remember which input path each row came from, so that rows can be grouped per path later
            df["path_index"] = path_index
            dataset_to_df[dataset] = pd.concat((dataset_to_df[dataset], df))
        # Later paths take precedence when the same dataset is reported more than once
        dataset_to_metric.update(ds_to_metric)

    # Prepare folder for storing figures
    image_dir = Path("scripts") / "images"
    # parents=True so that this also works when `scripts/` does not exist in the working directory
    image_dir.mkdir(parents=True, exist_ok=True)
    existing_versions = (
        int(version_dir.name[2:]) for version_dir in image_dir.glob("v_*/") if version_dir.name[2:].isdigit()
    )
    new_version = max(existing_versions, default=0) + 1
    output_dir = image_dir / f"v_{new_version}"
    output_dir.mkdir()

    # Save a copy of the executed command in the output directory
    (output_dir / "command.txt").write_text("python " + " ".join(sys.argv))

    # Horizontal space reserved for each boxplot; smaller values place the boxplots closer together
    allotted_box_width = 0.2

    # Create one figure per dataset
    for dataset, df in dataset_to_df.items():
        columns = [column for column in df.columns if not column.startswith("path")]

        # These only depend on the dataset, not on the column, so compute them once per figure.
        # Positions are 0, 0.2, 0.4, ..., one position per boxplot.
        n_boxplots = df["path_index"].nunique()
        positions = [allotted_box_width * i for i in range(n_boxplots)]
        x_limits = (-allotted_box_width * 0.75, allotted_box_width * (n_boxplots - 0.25))

        # squeeze=False always yields a 2D array of axes, also when there is only a single column
        fig, axes = plt.subplots(ncols=len(columns), sharey=True, squeeze=False)
        for column_index, (ax, column) in enumerate(zip(axes[0], columns)):
            # Set the y label only for the first column
            if column_index == 0:
                ax.set_ylabel(dataset_to_metric[dataset])

            # Fix the x limits before plotting, so the view stays tight around the boxplots
            ax.set_xlim(*x_limits)

            df[[column, "path_index"]].groupby("path_index", sort=True).boxplot(
                subplots=False, ax=ax, column=column, positions=positions
            )

            # The part of the column name after the last dash is used as the number of shots
            k_shot = column.split("-")[-1]
            ax.set_xlabel(f"{k_shot}-shot")
            if n_boxplots > 1:
                # If there are multiple boxplots, override the labels at the bottom generated by pandas
                if n_boxplots <= 26:
                    ax.set_xticklabels(string.ascii_uppercase[:n_boxplots])
                else:
                    ax.set_xticklabels(range(n_boxplots))
            else:
                # Otherwise, just remove the xticks
                ax.tick_params(labelbottom=False)

        if n_boxplots > 1:
            fig.suptitle(
                f"Comparison between various baselines on the {dataset}\ndataset under various $K$-shot conditions"
            )
        else:
            fig.suptitle(f"Results on the {dataset} dataset under various $K$-shot conditions")
        fig.tight_layout()

        # Save this specific figure rather than relying on pyplot's implicit "current figure",
        # then close it: pyplot keeps every figure alive until it is explicitly closed.
        fig.savefig(str(output_dir / dataset))
        plt.close(fig)