# run_pipeline.py — CLI driver for the clustering pipeline (main entry point)
def _load_input_dataset(args):
    """Load the train split of ``args.input_dataset`` (optional ``data_subset``)."""
    dataset_args = (
        (args.input_dataset, args.data_subset)
        if args.data_subset
        else (args.input_dataset,)
    )
    return load_dataset(*dataset_args, split="train", token=True)


def _sample_indexes(args):
    """Return ``(indexes, suffix)``: the rows to process and a log suffix.

    An explicit window ``[start, end)`` is used when ``args.start`` is
    non-negative; otherwise the first ``args.n_samples`` rows are taken.
    """
    # NOTE(review): run mode previously checked `start > 0` while infer mode
    # checked `start >= 0`, so an explicit `--start 0` was ignored in run
    # mode. Unified on `>= 0` so both modes honor a zero start index.
    if args.start >= 0:
        return range(args.start, args.end), f" starting from {args.start}"
    return range(args.n_samples), ""


def _build_and_plot(cc, args):
    """Build/push the HF dataset, then plot distributions (single-topic only)."""
    build_and_push(cc, args)
    ds_path = f"{args.username}/{args.save_load_path.split('/')[-1]}"
    if args.topic_mode == "single_topic":
        plot_distributions(ds_path, image_path=args.save_load_path)
        print("📊 Saved plots for educational score and files distribution.")


def main():
    """CLI entry point: fit, infer with, or reload a clustering pipeline.

    ``args.mode`` selects the behavior:
      - ``"run"``:   fit a new ``ClusterClassifier`` on input texts, save it,
        and optionally build/push an HF dataset with plots.
      - ``"infer"``: load a saved pipeline, assign clusters to new texts, and
        push the labeled dataset to the hub.
      - anything else (``"load"``): optionally rebuild and push the HF dataset
        from an already-saved pipeline.
    """
    args = get_args()
    multi = args.topic_mode == "multiple_topics"
    template = TEMPLATE_MULTIPLE_TOPICS if multi else TEMPLATE_SINGLE_TOPIC
    instruction = INSTRUCTION_MULTIPLE_TOPICS if multi else INSTRUCTION_SINGLE_TOPIC
    print(f"Using {args.topic_mode} for topic labeling")
    cc = ClusterClassifier(
        embed_device=args.device,
        topic_mode=args.topic_mode,
        summary_template=template,
        summary_instruction=instruction,
        dbscan_eps=args.dbscan_eps,
        dbscan_min_samples=args.dbscan_min_samples,
    )

    if args.mode == "run":
        # Fit a fresh pipeline on a shuffled sample of the input dataset.
        ds = _load_input_dataset(args).shuffle(seed=42)
        print(ds)
        indexes, text_start = _sample_indexes(args)
        print(f"Processing {len(indexes)} samples{text_start}")
        texts = ds.select(indexes)[args.input_content]
        _, _, summaries = cc.fit(texts)
        print(f"10 example Summaries:\n{list(summaries.values())[:10]}")
        cc.save(args.save_load_path)
        print(f"Saved clusters in {args.save_load_path}.")
        if args.build_hf_ds:
            # NOTE(review): plotting reads the pushed hub repo, so it is now
            # guarded by build_hf_ds (run mode previously attempted it
            # unconditionally, which would target a repo that was never
            # pushed) — confirm plot_distributions has no local fallback.
            _build_and_plot(cc, args)
    elif args.mode == "infer":
        # Classify new texts with a previously saved pipeline.
        cc.load(args.save_load_path)
        indexes, text_start = _sample_indexes(args)
        print(
            f"Running inference on {len(indexes)} samples{text_start} of "
            f"{args.input_dataset} using clusters in {args.save_load_path}."
        )
        ds = _load_input_dataset(args)
        texts = ds.select(indexes)[args.input_content]
        start_time = time.time()
        cluster_labels, _ = cc.infer(texts, top_k=1)
        ds = build_hf_data_clusters(cc, texts, cluster_labels)
        print(f"Total time is {(time.time() - start_time)/60}min")
        target_repo = f"{args.username}/{args.inference_repo_name}"
        print(f"Samples with clusters: {ds}")
        print(f"Pushing to hub at {target_repo}...")
        ds.push_to_hub(target_repo, private=True)
    else:
        # "load" mode: reuse a saved pipeline, optionally rebuilding the
        # HF dataset and plots.
        if args.build_hf_ds:
            cc.load(args.save_load_path)
            _build_and_plot(cc, args)
        else:
            print("Using mode=load but build_hf_ds is False, nothing to be done.")
    print("Done 🎉")