run_pipeline.py

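# Example invocations (a sketch based on the flags defined in get_args() below;
# dataset and repo names other than the defaults are illustrative placeholders):
#
#   # cluster and label 100k samples from the default dataset, then save locally
#   python run_pipeline.py --mode run --save_load_path ./cc_100k --build_hf_ds
#
#   # reuse a saved clustering to label texts from another dataset
#   python run_pipeline.py --mode infer --save_load_path ./cc_100k \
#       --input_dataset <other_hf_dataset> --inference_repo_name infer_fw_on_ultrachat
#
#   # only build & push the visualization datasets from a saved run
#   python run_pipeline.py --mode load --save_load_path ./cc_100k --build_hf_ds
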
import argparse
import textwrap
import time

import numpy as np
import pandas as pd
from datasets import Dataset, load_dataset

from src.plot_utils import plot_distributions
from src.text_clustering import ClusterClassifier

INSTRUCTION_SINGLE_TOPIC = (
    "The examples below are web samples from the same cluster. Identify the topic they have in common, "
    "for example: Philosophy, Lifestyle, Linear Algebra, Biochemistry, Economics... "
    "Additionally, determine if the topics in the examples are broadly suitable as college/school material, "
    "while being mindful to exclude any sensitive/inappropriate/irrelevant content, including but not limited to "
    "sex, explicit violence, ads & scams, and other non-academic subjects. Consider a wide range of content, "
    "including scientific, educational, historical, cultural, and practical applications, and give a rating of "
    "how educational these topics could be from 1 to 10, 1 being extremely un-educational and inappropriate for "
    "an educational setting and 10 being highly educational. The output format should be like this: "
    "Topic: the_topic, Educational value rating: score."
)

INSTRUCTION_MULTIPLE_TOPICS = (
    "Use three words total (comma separated) to describe general topics in the above texts. "
    "Under no circumstances use enumeration. Example format: Tree, Cat, Fireman"
)

TEMPLATE_MULTIPLE_TOPICS = "<s>[INST]{examples}\n\n{instruction}[/INST]"
TEMPLATE_SINGLE_TOPIC = (
    "<s>[INST]{instruction}\n\nExamples:\n{examples}\n"
    "Remember that the output format should be like this: "
    "Topic: the_topic, Educational value rating: score.[/INST]"
)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_samples", type=int, default=100_000)
    parser.add_argument("--start", type=int, default=-1)
    parser.add_argument("--end", type=int, default=100_000)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--save_load_path", type=str, default="./cc_100k")
    parser.add_argument(
        "--input_dataset",
        type=str,
        default="HuggingFaceFW/FW-12-12-2023-CC-2023-06",
        help="dataset with the samples to use for clustering",
    )
    parser.add_argument(
        "--data_subset",
        type=str,
        default=None,
        help="dataset subset",
    )
    parser.add_argument("--input_content", type=str, default="content")
    parser.add_argument(
        "--topic_mode",
        type=str,
        choices=["single_topic", "multiple_topics"],
        default="multiple_topics",
        help="Specify 'single_topic' to generate only one topic and score its educational value, "
        "or 'multiple_topics' to generate the 3 most relevant topics in the cluster.",
    )
    parser.add_argument(
        "--dbscan_eps",
        type=float,
        default=0.08,
        help="The maximum distance between two samples for them to be considered as in the neighborhood of each other.",
    )
    parser.add_argument(
        "--dbscan_min_samples",
        type=int,
        default=50,
        help="The number of samples in a neighborhood for a point to be considered as a core point.",
    )
    parser.add_argument(
        "--mode",
        choices=["run", "load", "infer"],
        default="run",
        help="Run the pipeline from scratch, load an existing model to build HF datasets, or infer on new texts.",
    )
    parser.add_argument(
        "--inference_repo_name",
        type=str,
        default="infer_fw_on_ultrachat",
        help="HF repo name for the clusters dataset in inference mode",
    )
    parser.add_argument(
        "--build_hf_ds",
        action="store_true",
        help="Builds HF datasets used for space visualization and pushes them to the hub",
    )
    parser.add_argument("--username", type=str, default="loubnabnl")
    return parser.parse_args()


def extract_res(example):
    """Parse a single-topic cluster summary into a topic category and an educational score."""
    summary = example["summary"]
    category = summary.split(". Educational")[0].strip()
    score = summary.split(" Educational score: ")[1].strip()
    return {"category": category, "educational_score": score}
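
# Illustration (hypothetical input; the expected summary shape is implied by the
# splits in extract_res, not guaranteed by the model output):
#   extract_res({"summary": "Topic: Linear Algebra. Educational score: 8"})
#   -> {"category": "Topic: Linear Algebra", "educational_score": "8"}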


def build_hf_data_clusters(cc, texts=None, labels=None):
    """
    Build an HF dataset containing information on each cluster.

    Args:
        cc: ClusterClassifier object.
        texts: list of texts used for inference mode.
        labels: list of cluster labels corresponding to the texts for inference mode.

    If `texts` and `labels` are not provided, the function will use the data available
    in `cc` to construct the dataset. Otherwise it will run in inference mode on texts.
    """
    cluster_data = []
    for cluster_id in cc.label2docs.keys():
        if cluster_id == -1:
            continue

        # inference mode: gather the texts assigned to this cluster
        if texts is not None and labels is not None:
            labels_array = np.array(labels)
            files_in_cluster = np.where(labels_array == cluster_id)[0]
            examples = [texts[doc_id] for doc_id in files_in_cluster]
        else:
            doc_ids = cc.label2docs[cluster_id]
            examples = [cc.texts[doc_id] for doc_id in doc_ids]

        cluster_info = {
            "cluster_id": cluster_id,
            "summary": cc.cluster_summaries[cluster_id],
            "examples": examples,
        }

        if not texts:
            cluster_info["position"] = cc.cluster_centers[cluster_id]

        cluster_data.append(cluster_info)
    return Dataset.from_pandas(pd.DataFrame(cluster_data))


def build_hf_data_files(cc):
    """
    Build an HF dataset containing information on each file and the cluster it belongs to.
    """
    df = pd.DataFrame(
        data={
            "X": cc.projections[:, 0],
            "Y": cc.projections[:, 1],
            "labels": cc.cluster_labels,
            "content_display": [textwrap.fill(txt[:1024], 64) for txt in cc.texts],
        }
    )
    return Dataset.from_pandas(df)


def build_and_push(cc, args):
    """Build HF files & clusters datasets and push them to the hub."""
    print("Building HF datasets...")
    ds = build_hf_data_clusters(cc)  # per-cluster dataset (summaries + examples)
    ds = ds.map(extract_res)  # adds category/educational_score parsed from each summary
    data_clusters = build_hf_data_files(cc)  # per-file dataset (2D projections for plotting)
    print(f"Clusters dataset {ds}\nFiles dataset {data_clusters}")

    repo_name = args.save_load_path.split("/")[-1]
    print(f"Pushing to the hub at {repo_name}...")
    ds.push_to_hub(f"{args.username}/{repo_name}", private=True)
    data_clusters.push_to_hub(f"{args.username}/{repo_name}_clusters", private=True)
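
# For reference, the two pushed datasets have the following columns (as built above;
# "position" is only present when building from a fitted ClusterClassifier rather
# than in inference mode):
#   build_hf_data_clusters -> cluster_id, summary, examples[, position], plus the
#                             category/educational_score columns added by extract_res
#   build_hf_data_files    -> X, Y, labels, content_display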


def main():
    args = get_args()

    template = (
        TEMPLATE_MULTIPLE_TOPICS
        if args.topic_mode == "multiple_topics"
        else TEMPLATE_SINGLE_TOPIC
    )
    instruction = (
        INSTRUCTION_MULTIPLE_TOPICS
        if args.topic_mode == "multiple_topics"
        else INSTRUCTION_SINGLE_TOPIC
    )
    print(f"Using {args.topic_mode} for topic labeling")

    cc = ClusterClassifier(
        embed_device=args.device,
        topic_mode=args.topic_mode,
        summary_template=template,
        summary_instruction=instruction,
        dbscan_eps=args.dbscan_eps,
        dbscan_min_samples=args.dbscan_min_samples,
    )

    if args.mode == "run":
        # Run a new pipeline on texts
        dataset_args = (
            (args.input_dataset, args.data_subset)
            if args.data_subset
            else (args.input_dataset,)
        )
        ds = load_dataset(*dataset_args, split="train", token=True).shuffle(seed=42)
        print(ds)
        # Use the explicit [start, end) window when a start index is given,
        # otherwise take the first n_samples (matches the check in infer mode).
        indexes = (
            range(args.start, args.end) if args.start >= 0 else range(args.n_samples)
        )
        text_start = f" starting from {args.start}" if args.start >= 0 else ""
        print(f"Processing {len(indexes)} samples{text_start}")
        texts = ds.select(indexes)[args.input_content]

        _, _, summaries = cc.fit(texts)
        print(f"10 example summaries:\n{list(summaries.values())[:10]}")
        cc.save(args.save_load_path)
        print(f"Saved clusters in {args.save_load_path}.")

        if args.build_hf_ds:
            build_and_push(cc, args)
        ds_path = f"{args.username}/{args.save_load_path.split('/')[-1]}"
        if args.topic_mode == "single_topic":
            plot_distributions(ds_path, image_path=args.save_load_path)
            print("📊 Saved plots for educational score and files distribution.")

    elif args.mode == "infer":
        # Run inference mode on texts using an existing pipeline
        cc.load(args.save_load_path)
        indexes = (
            range(args.start, args.end) if args.start >= 0 else range(args.n_samples)
        )
        text_start = f" starting from {args.start}" if args.start >= 0 else ""
        print(
            f"Running inference on {len(indexes)} samples{text_start} of {args.input_dataset} "
            f"using clusters in {args.save_load_path}."
        )
        dataset_args = (
            (args.input_dataset, args.data_subset)
            if args.data_subset
            else (args.input_dataset,)
        )
        ds = load_dataset(*dataset_args, split="train", token=True)
        texts = ds.select(indexes)[args.input_content]

        start_time = time.time()
        cluster_labels, _ = cc.infer(texts, top_k=1)
        ds = build_hf_data_clusters(cc, texts, cluster_labels)
        print(f"Total time is {(time.time() - start_time) / 60:.1f}min")

        target_repo = f"{args.username}/{args.inference_repo_name}"
        print(f"Samples with clusters: {ds}")
        print(f"Pushing to hub at {target_repo}...")
        ds.push_to_hub(target_repo, private=True)

    else:
        # Load existing pipeline
        if args.build_hf_ds:
            cc.load(args.save_load_path)
            build_and_push(cc, args)
            ds_path = f"{args.username}/{args.save_load_path.split('/')[-1]}"
            if args.topic_mode == "single_topic":
                plot_distributions(ds_path, image_path=args.save_load_path)
                print("📊 Saved plots for educational score and files distribution.")
        else:
            print("Using mode=load but build_hf_ds is False, nothing to be done.")

    print("Done 🎉")


if __name__ == "__main__":
    main()