in scripts/plot_summary_comparison.py [0:0]
def plot_summary_comparison(paths: List[str]) -> None:
    """Given a list of paths to output directories produced by e.g. `scripts/setfit/run_fewshot.py`,
    produce and save boxplots that compare the various results.

    The plots are saved to scripts/images/v_{id}/{dataset}.png, i.e. one plot per dataset, where
    {id} is one more than the largest existing `v_{number}` directory. The command that was run is
    saved alongside the plots in `command.txt`.

    When a dataset has results from more than one path, its boxplots are labeled A, B, C, ...
    (or 0, 1, 2, ... when more than 26 paths are given) according to the position of the
    corresponding path in `paths`. Labels therefore stay correct and consistent across datasets,
    even if some paths lack results for a given dataset.

    Args:
        paths (List[str]): List of paths to output directories, generally
            `scripts/{method_name}/results/{model_name}`
    """
    # Parse the result paths, tagging every row with the index of the path it came from.
    # Frames are gathered per dataset and concatenated once, instead of repeatedly
    # concatenating onto an initially empty DataFrame.
    dataset_to_dfs = defaultdict(list)
    dataset_to_metric = {}
    for path_index, path in enumerate(paths):
        ds_to_metric, this_dataset_to_df = get_summary_df(path)
        for dataset, df in this_dataset_to_df.items():
            df["path_index"] = path_index
            dataset_to_dfs[dataset].append(df)
        # Metric names from later paths override earlier ones for the same dataset
        dataset_to_metric.update(ds_to_metric)
    # The original row indices are kept (no `ignore_index`), exactly as before
    dataset_to_df = {dataset: pd.concat(dfs) for dataset, dfs in dataset_to_dfs.items()}

    # Prepare folder for storing figures; a new version directory is used so earlier runs are kept
    image_dir = Path("scripts") / "images"
    image_dir.mkdir(parents=True, exist_ok=True)
    new_version = (
        max((int(path.name[2:]) for path in image_dir.glob("v_*/") if path.name[2:].isdigit()), default=0) + 1
    )
    output_dir = image_dir / f"v_{new_version}"
    output_dir.mkdir()
    # Save a copy of the executed command in the output directory
    (output_dir / "command.txt").write_text("python " + " ".join(sys.argv))

    # Boxplots are labeled by the position of their path in `paths`. Letters are only used when
    # every path can get its own letter, so the labeling scheme is the same for all datasets.
    use_letters = len(paths) <= len(string.ascii_uppercase)
    # Horizontal distance between consecutive boxplots; smaller values place them closer together
    allotted_box_width = 0.2

    # Create the plots per each dataset
    for dataset, df in dataset_to_df.items():
        columns = [column for column in df.columns if not column.startswith("path")]
        if not columns:
            # Nothing to plot, and plt.subplots(ncols=0) would raise
            continue

        # Path indices that actually have results for this dataset; this can be a strict subset
        # of range(len(paths)) when some paths lack this dataset.
        path_indices = sorted(int(index) for index in df["path_index"].unique())
        n_boxplots = len(path_indices)
        tick_labels = [string.ascii_uppercase[index] if use_letters else str(index) for index in path_indices]
        # Set positions to 0, 0.2, 0.4, ..., one position per boxplot
        positions = [allotted_box_width * i for i in range(n_boxplots)]

        # squeeze=False always yields a 2D array of axes, even for a single column
        fig, axes = plt.subplots(ncols=len(columns), sharey=True, squeeze=False)
        for column_index, column in enumerate(columns):
            ax = axes[0, column_index]
            # Set the y label only for the first column
            if column_index == 0:
                ax.set_ylabel(dataset_to_metric[dataset])
            ax.set_xlim(-allotted_box_width * 0.75, allotted_box_width * (n_boxplots - 0.25))
            # One boxplot per path index, in ascending path-index order (sort=True)
            df[[column, "path_index"]].groupby("path_index", sort=True).boxplot(
                subplots=False, ax=ax, column=column, positions=positions
            )
            k_shot = column.split("-")[-1]
            ax.set_xlabel(f"{k_shot}-shot")
            if n_boxplots > 1:
                # Override the labels at the bottom generated by pandas with the path labels
                ax.set_xticklabels(tick_labels)
            else:
                # Otherwise, just remove the xticks
                ax.tick_params(labelbottom=False)

        if n_boxplots > 1:
            fig.suptitle(
                f"Comparison between various baselines on the {dataset}\ndataset under various $K$-shot conditions"
            )
        else:
            fig.suptitle(f"Results on the {dataset} dataset under various $K$-shot conditions")
        fig.tight_layout()
        # Explicit extension: a dataset name containing a dot must not be mistaken for a file format
        fig.savefig(output_dir / f"{dataset}.png")
        # Release the figure; pyplot otherwise keeps every created figure alive
        plt.close(fig)