def plot_distributions()

in src/plot_utils.py [0:0]


def plot_distributions(ds_path, image_path="."):
    """Plot the distribution of educational scores and of samples across topics.

    Loads the ``train`` split of ``ds_path``, applies ``extract_score`` to every
    row, drops rows whose ``educational_score`` is the string ``"None"`` or
    empty, and writes three artifacts into ``image_path``:

    * ``educational_score.png`` -- 10-bin histogram of the numeric scores
      (values that cannot be parsed as numbers are skipped);
    * ``df_categories_count.csv`` -- per-``category`` count of entries after
      exploding the ``examples`` column, sorted in descending order;
    * ``topics_dist.png`` -- horizontal bar plot of those per-category counts.

    Args:
        ds_path: Dataset identifier/path passed to ``load_dataset``. The
            ``HF_TOKEN`` environment variable, if set, is forwarded as token.
        image_path: Directory the images and CSV are written to. It must
            already exist. Defaults to the current working directory.

    Returns:
        None. The category bar plot is also displayed via ``plt.show()``.
    """
    ds = load_dataset(ds_path, split="train", num_proc=2, token=os.getenv("HF_TOKEN"))
    # NOTE(review): extract_score is defined elsewhere; presumably it populates
    # the "educational_score" field filtered on below -- confirm.
    ds = ds.map(extract_score)
    ds = ds.filter(lambda x: x["educational_score"] not in ("None", ""))

    sns.set_theme(style="whitegrid")

    # --- distribution of educational scores ---
    score_df = ds.to_pandas()
    # Scores that are still not numeric become NaN and are dropped.
    scores = pd.to_numeric(score_df["educational_score"], errors="coerce").dropna()

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.histplot(scores, kde=False, bins=10, ax=ax)
    ax.set_title("Distribution of Educational Scores")
    ax.set_xlabel("Educational Score")
    ax.set_ylabel("Frequency")
    fig.savefig(os.path.join(image_path, "educational_score.png"), bbox_inches="tight")
    # Close it so it neither leaks nor gets re-displayed by plt.show() below.
    plt.close(fig)

    # --- distribution of samples across categories ---
    # explode() puts each element of "examples" on its own row, so group sizes
    # count examples rather than dataset rows (assumes list-like values).
    examples_df = ds.to_pandas().explode("examples")
    category_df = (
        examples_df.groupby(by="category")
        .size()
        .sort_values(ascending=False)
        .reset_index(name="number_files")
    )
    print(f"Saving csv in {image_path}!")
    category_df.to_csv(os.path.join(image_path, "df_categories_count.csv"))

    fig, ax = plt.subplots(figsize=(25, 20))
    # NOTE: seaborn >= 0.12 deprecates ``ci`` in favour of ``errorbar``; ``ci``
    # is kept here for compatibility with older seaborn versions.
    sns.barplot(
        x="number_files",
        y="category",
        data=category_df,
        palette="Blues_d",
        ci=None,
        ax=ax,
    )
    ax.set_xlabel("Number of Examples")
    ax.set_ylabel("Categories")
    ax.set_title("Histogram of Categories and their number of FW files")
    fig.tight_layout(pad=1.0)
    # Save before showing: with interactive backends, calling plt.savefig after
    # plt.show() can write a blank figure once the window has been closed.
    fig.savefig(os.path.join(image_path, "topics_dist.png"), bbox_inches="tight", dpi=200)
    plt.show()
    plt.close(fig)