in scripts/create_summary_table.py [0:0]
def create_summary_table(results_path: str) -> None:
    """Create a summary table of all datasets from per-split results.

    Every immediate subdirectory of the results directory is treated as a
    dataset and contributes one row holding its formatted per-sample-size
    values. A final "Average" row holds, for each sample size, the mean of the
    per-dataset averages and the mean of the per-dataset standard deviations.
    The table is printed to stdout and written to `summary_table.csv` inside
    the results directory.

    Args:
        results_path: path to per-split results: either
            `scripts/{method_name}/{results}/{model_name}`,
            or `final_results/{method_name}/{model_name}.tar.gz`
    """
    if results_path.endswith("tar.gz"):
        unzipped_path = extract_results(results_path)
    else:
        unzipped_path = results_path

    if "tfew" in unzipped_path:
        print("Computing medians for T-Few...")
        compute_tfew_medians(unzipped_path)

    sample_sizes = get_sample_sizes(unzipped_path)

    # Two columns per sample size: "<size>_avg" followed by "<size>_std".
    header_row = ["dataset", "measure"]
    for sample_size in sample_sizes:
        header_row.extend([f"{sample_size}_avg", f"{sample_size}_std"])
    csv_lines = [header_row]

    means, stds = defaultdict(list), defaultdict(list)
    # next(os.walk(...))[1] yields only the immediate subdirectory names.
    for dataset in next(os.walk(unzipped_path))[1]:
        # NOTE(review): `sample_sizes` is rebound to the helper's return value
        # on every iteration, while the header above was built from the initial
        # list. Presumably the helper returns it unchanged -- confirm.
        metric_name, formatted_metrics, exact_metrics, exact_stds, sample_sizes = get_formatted_ds_metrics(
            unzipped_path, dataset, sample_sizes
        )
        csv_lines.append([dataset, metric_name, *formatted_metrics])

        # Collect exact metrics for overall average and std calculation
        for sample_size in sample_sizes:
            means[sample_size].append(exact_metrics[sample_size])
            stds[sample_size].append(exact_stds[sample_size])

    # Generate row for overall average. The "std" entry is the mean of the
    # per-dataset standard deviations, not a deviation across datasets.
    formatted_average_row = []
    for sample_size in sample_sizes:
        overall_average = mean(means[sample_size])
        overall_std = mean(stds[sample_size])
        formatted_average_row.extend([f"{overall_average:.1f}", f"{overall_std:.1f}"])
    csv_lines.append(["Average", "N/A", *formatted_average_row])

    output_path = os.path.join(unzipped_path, "summary_table.csv")
    print("=" * 80)
    print("Summary table:\n")
    # Explicit encoding keeps the written file independent of the platform's
    # locale default.
    with open(output_path, "w", encoding="utf-8") as f:
        for line in csv_lines:
            # The file uses bare commas; the console echo adds a space for readability.
            f.write(",".join(line) + "\n")
            print(", ".join(line))
    print("=" * 80)
    print(f"Saved summary table to {output_path}")