experiments/arena/scripts/diffusion_db_downloader.py (50 lines of code) (raw):

""" Load and process the Metadata file of the DiffusionDB dataset. This script is adapted from DiffusionDB github repo: https://github.com/poloclub/diffusiondb?tab=readme-ov-file """ import json import pandas as pd from urllib.request import urlretrieve import os from collections import defaultdict SAFETY_RATIO = 0.03 METADATA_URL = 'https://huggingface.co/datasets/poloclub/diffusiondb/resolve/main/metadata.parquet' METADATA_FILE = 'metadata.parquet' FILTERED_METADATA_FILE = 'diffusiondb_metadata.json' PROMPTS_IDS_FILE = 'prompt_image_names.json' def download_metadata(url: str, filename: str) -> None: """Downloads the metadata file from the given URL.""" print("Downloading the metadata table...") urlretrieve(url, filename) print("Download complete!") def load_metadata(filename: str) -> pd.DataFrame: """Loads the metadata table into a Pandas DataFrame.""" print("Loading the metadata table...") return pd.read_parquet(filename) def filter_metadata(df: pd.DataFrame, safety_ratio: float) -> pd.DataFrame: """Filters the metadata DataFrame based on NSFW ratios.""" filtered_df = df[ (df['image_nsfw'] < safety_ratio) & (df['prompt_nsfw'] < safety_ratio) ] print("Filtering complete!") print(f"Total number of images: {len(df)}") print(f"Number of images after filtering: {len(filtered_df)}") return filtered_df def map_unique_prompts_to_image_ids(df: pd.DataFrame) -> list: """ Creates a list of tuples (prompt, [image_ids]) from unique prompts. Args: df: The Pandas DataFrame containing the data. Returns: A list of tuples, where each tuple contains a unique prompt and a list of image_ids. """ prompt_to_image_ids = defaultdict(list) for prompt, image_id in zip(df['prompt'], df['image_name']): prompt_to_image_ids[prompt].append(image_id) print(f"Number of unique prompts: {len(prompt_to_image_ids)}") return list(prompt_to_image_ids.items()) def save_prompt_ids_to_json(prompt_ids_list: list, filename: str) -> None: """Saves the list of tuples (prompt, [image_name]) to a JSON file.""" print("Saving unique prompts to image_name...") data = {"stable_diffusion": prompt_ids_list} with open(filename, 'w') as f: json.dump(data, f, indent=4) def save_filtered_metadata(df: pd.DataFrame, filename: str) -> None: """Saves the filtered DataFrame to a JSON file.""" print("Saving the filtered metadata table...") df.to_json(filename, orient='records', indent=4) def main(): """Main function to orchestrate the metadata processing.""" if not os.path.exists(METADATA_FILE): download_metadata(METADATA_URL, METADATA_FILE) metadata_df = load_metadata(METADATA_FILE) filtered_df = filter_metadata(metadata_df, SAFETY_RATIO) prompt_image_id_list = map_unique_prompts_to_image_ids(filtered_df) save_prompt_ids_to_json(prompt_image_id_list, PROMPTS_IDS_FILE) save_filtered_metadata(filtered_df, FILTERED_METADATA_FILE) # Clean up the downloaded metadata file os.remove(METADATA_FILE) if __name__ == "__main__": main()