prepare_results.py (81 lines of code) (raw):
import argparse
import glob
import os
import sys
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from huggingface_hub import upload_file
# Add the current working directory to the import search path so the
# project-local ``utils`` package can be imported below.
# NOTE(review): this presumably expects the script to be launched from the
# repository root — confirm.
sys.path.append(".")
from utils.benchmarking_utils import collate_csv # noqa: E402

# Hugging Face Hub repository (used with ``repo_type="dataset"``) that receives
# the uploaded CSV and plot files.
REPO_ID = "sayakpaul/sample-datasets"
def prepare_plot(df: pd.DataFrame, args: argparse.Namespace) -> None:
    """Plot the mean run time of each benchmark configuration as a bar chart.

    The figure is saved as a PNG named after ``args.plot_title`` (spaces
    replaced by underscores), optionally uploaded to the ``REPO_ID`` dataset
    repository on the Hugging Face Hub, and finally displayed.

    Args:
        df: Collated benchmark results. Must contain ``do_quant``,
            ``time (secs)`` and every column listed in ``columns_to_drop``.
        args: Parsed command-line arguments; ``plot_title`` and
            ``push_to_hub`` are read.

    Raises:
        KeyError: If ``df`` lacks one of the required columns.
    """
    # Drop the columns that are not needed
    columns_to_drop = [
        "batch_size",
        "num_inference_steps",
        "pipeline_cls",
        "ckpt_id",
        "upcast_vae",
        "memory (gbs)",
        "actual_gpu_memory (gbs)",
        "tag",
    ]
    # ``drop`` returns a new frame, so the caller's ``df`` is left untouched.
    df_filtered = df.drop(columns=columns_to_drop)
    # A missing ``do_quant`` value is displayed as the literal label "None".
    df_filtered["quant"] = df_filtered["do_quant"].fillna("None")
    df_filtered = df_filtered.drop(columns=["do_quant"])

    # Create a new column to consolidate settings into a readable format.
    # Every remaining column except the measured time is part of the label.
    setting_columns = [col for col in df_filtered.columns if col != "time (secs)"]
    df_filtered["settings"] = df_filtered.apply(
        lambda row: ", ".join(f"{col}-{row[col]}" for col in setting_columns), axis=1
    )
    # One setting per line keeps the x-tick labels narrow.
    df_filtered["formatted_settings"] = df_filtered["settings"].str.replace(", ", "\n", regex=False)
    # The row with index label 0 is shown as "default" instead of its full
    # settings string (presumably the baseline run — confirm against the
    # ordering produced by the collation step).
    df_filtered.loc[0, "formatted_settings"] = "default"

    # Generating the plot with matplotlib directly for better control
    plt.figure(figsize=(12, 10))
    sns.set_style("whitegrid")

    # Distinct labels in first-seen order, computed once and reused for the
    # bar positions, the colors and the x-tick labels.
    setting_labels = df_filtered["formatted_settings"].unique()
    n_settings = len(setting_labels)
    bar_positions = range(n_settings)

    # Choose a color palette with one color per bar
    palette = sns.color_palette("husl", n_settings)

    # Plot each bar manually
    bar_width = 0.25  # Width of the bars
    for i, setting in enumerate(setting_labels):
        # Mean time over all rows that share this settings label
        mean_time = df_filtered.loc[df_filtered["formatted_settings"] == setting, "time (secs)"].mean()
        plt.bar(i, mean_time, width=bar_width, align="center", color=palette[i])
        # Add the text above the bars
        plt.text(i, mean_time + 0.01, f"{mean_time:.2f}", ha="center", va="bottom", fontsize=14, fontweight="bold")

    # Set the x-ticks to correspond to the settings
    plt.xticks(bar_positions, setting_labels, rotation=45, ha="right", fontsize=10)
    plt.ylabel("Time in Seconds", fontsize=14, labelpad=15)
    plt.xlabel("Settings", fontsize=14, labelpad=15)
    plt.title(args.plot_title, fontsize=18, fontweight="bold", pad=20)

    # Adding horizontal gridlines for better readability
    plt.grid(axis="y", linestyle="--", linewidth=0.7, alpha=0.7)
    plt.tight_layout()
    plt.subplots_adjust(top=0.9, bottom=0.2)  # Adjust the top and bottom

    plot_path = args.plot_title.replace(" ", "_") + ".png"
    plt.savefig(plot_path, bbox_inches="tight", dpi=300)

    if args.push_to_hub:
        upload_file(repo_id=REPO_ID, path_in_repo=plot_path, path_or_fileobj=plot_path, repo_type="dataset")
        # Report the path that was actually uploaded; the argument parser
        # defines no ``plot_file_path`` attribute on ``args``.
        print(
            f"Plot successfully uploaded. Find it here: https://huggingface.co/datasets/{REPO_ID}/blob/main/{plot_path}"
        )

    # Show the plot
    plt.show()
def main(args: argparse.Namespace) -> None:
    """Collate the benchmark CSVs in ``args.base_path`` and optionally publish them.

    All ``*.csv`` files directly inside ``args.base_path`` are merged into
    ``args.final_csv_filename`` via ``collate_csv``. With ``args.push_to_hub``
    the merged file is uploaded to the ``REPO_ID`` dataset repository, and
    when ``args.plot_title`` is set the merged results are plotted.

    Args:
        args: Parsed command-line arguments; ``base_path``,
            ``final_csv_filename``, ``push_to_hub`` and ``plot_title`` are read.

    Raises:
        FileNotFoundError: If ``args.base_path`` contains no CSV files.
    """
    # glob already returns every match prefixed with ``base_path``, so the
    # paths are used as-is; joining them with ``base_path`` again would
    # duplicate the directory (e.g. "results/results/a.csv").
    all_csvs = sorted(glob.glob(os.path.join(args.base_path, "*.csv")))
    if not all_csvs:
        raise FileNotFoundError(f"No CSV files found in '{args.base_path}'.")

    # The first (alphabetically smallest) path decides whether the results
    # are treated as PixArt-alpha runs.
    is_pixart = "PixArt-alpha" in all_csvs[0]
    collate_csv(all_csvs, args.final_csv_filename, is_pixart=is_pixart)

    if args.push_to_hub:
        upload_file(
            repo_id=REPO_ID,
            path_in_repo=args.final_csv_filename,
            path_or_fileobj=args.final_csv_filename,
            repo_type="dataset",
        )
        print(
            f"CSV successfully uploaded. Find it here: https://huggingface.co/datasets/{REPO_ID}/blob/main/{args.final_csv_filename}"
        )

    if args.plot_title is not None:
        df = pd.read_csv(args.final_csv_filename)
        prepare_plot(df, args)
if __name__ == "__main__":
    # Script entry point: build the command-line interface, parse it and
    # hand the resulting namespace to ``main``.
    cli = argparse.ArgumentParser()
    # String-valued options paired with their default values.
    for flag, default in (
        ("--base_path", "."),
        ("--final_csv_filename", "collated_results.csv"),
        ("--plot_title", None),
    ):
        cli.add_argument(flag, type=str, default=default)
    # Boolean switch: upload the generated artifacts when present.
    cli.add_argument("--push_to_hub", action="store_true")
    args = cli.parse_args()
    main(args)