def prepare_plot()

in prepare_results.py [0:0]


def prepare_plot(df, args):
    """Render a bar chart of mean run time per benchmark setting and save it.

    The figure is written to ``<plot_title with spaces replaced by "_">.png``
    in the current working directory, optionally uploaded to the dataset
    repository ``REPO_ID``, and finally displayed with ``plt.show()``.

    Args:
        df: DataFrame of benchmark results. It must contain the bookkeeping
            columns listed in ``columns_to_drop`` below, a ``"do_quant"``
            column and a ``"time (secs)"`` column. It is not modified.
        args: Namespace providing ``plot_title`` (used for the figure title
            and the output file name) and ``push_to_hub`` (when truthy, the
            saved image is uploaded).
    """
    # Drop the columns that are not needed
    columns_to_drop = [
        "batch_size",
        "num_inference_steps",
        "pipeline_cls",
        "ckpt_id",
        "upcast_vae",
        "memory (gbs)",
        "actual_gpu_memory (gbs)",
        "tag",
    ]
    df_filtered = df.drop(columns=columns_to_drop)

    # Expose "do_quant" as "quant", spelling missing values as the string
    # "None". The new column is appended last, which fixes its position in
    # the settings label built below.
    df_filtered["quant"] = df_filtered["do_quant"].fillna("None")
    df_filtered = df_filtered.drop(columns=["do_quant"])

    # Create a new column to consolidate settings into a readable format.
    # Every remaining column except the measured time is part of the label.
    setting_columns = [col for col in df_filtered.columns if col != "time (secs)"]
    df_filtered["settings"] = df_filtered.apply(
        lambda row: ", ".join(f"{col}-{row[col]}" for col in setting_columns), axis=1
    )
    df_filtered["formatted_settings"] = df_filtered["settings"].str.replace(", ", "\n", regex=False)
    # The row with index label 0 is shown as "default".
    # NOTE(review): presumably the first row is the baseline run — confirm with the caller.
    df_filtered.loc[0, "formatted_settings"] = "default"

    # Generating the plot with matplotlib directly for better control
    plt.figure(figsize=(12, 10))
    sns.set_style("whitegrid")

    # Unique settings in order of first appearance; one bar per setting
    unique_settings = df_filtered["formatted_settings"].unique()
    n_settings = len(unique_settings)
    bar_positions = range(n_settings)

    # Choose a color palette
    palette = sns.color_palette("husl", n_settings)

    # Plot each bar manually
    bar_width = 0.25  # Width of the bars
    for i, setting in enumerate(unique_settings):
        # Filter the dataframe for each setting and get the mean time
        mean_time = df_filtered[df_filtered["formatted_settings"] == setting]["time (secs)"].mean()
        plt.bar(i, mean_time, width=bar_width, align="center", color=palette[i])

        # Add the text above the bars
        plt.text(i, mean_time + 0.01, f"{mean_time:.2f}", ha="center", va="bottom", fontsize=14, fontweight="bold")

    # Set the x-ticks to correspond to the settings
    plt.xticks(bar_positions, unique_settings, rotation=45, ha="right", fontsize=10)

    plt.ylabel("Time in Seconds", fontsize=14, labelpad=15)
    plt.xlabel("Settings", fontsize=14, labelpad=15)
    plt.title(args.plot_title, fontsize=18, fontweight="bold", pad=20)

    # Adding horizontal gridlines for better readability
    plt.grid(axis="y", linestyle="--", linewidth=0.7, alpha=0.7)

    plt.tight_layout()
    plt.subplots_adjust(top=0.9, bottom=0.2)  # Adjust the top and bottom

    plot_path = args.plot_title.replace(" ", "_") + ".png"
    plt.savefig(plot_path, bbox_inches="tight", dpi=300)

    if args.push_to_hub:
        upload_file(repo_id=REPO_ID, path_in_repo=plot_path, path_or_fileobj=plot_path, repo_type="dataset")
        # Link to the path the file was actually uploaded to (path_in_repo).
        print(
            f"Plot successfully uploaded. Find it here: https://huggingface.co/datasets/{REPO_ID}/blob/main/{plot_path}"
        )

    # Show the plot
    plt.show()