# Excerpt from run_pipeline.py: main() entry point.

def main():
    """Entry point: run, infer with, or load a topic-clustering pipeline.

    Dispatches on ``args.mode``:
      * ``run``   -- fit a new ``ClusterClassifier`` on a slice of the input
                     dataset, save it, and optionally build/push a HF dataset
                     and distribution plots.
      * ``infer`` -- load a saved pipeline, label a slice of the dataset, and
                     push the labeled samples to the Hub.
      * other     -- treated as "load": optionally rebuild/push the HF dataset
                     from a previously saved pipeline.
    """
    args = get_args()

    multiple = args.topic_mode == "multiple_topics"
    template = TEMPLATE_MULTIPLE_TOPICS if multiple else TEMPLATE_SINGLE_TOPIC
    instruction = (
        INSTRUCTION_MULTIPLE_TOPICS if multiple else INSTRUCTION_SINGLE_TOPIC
    )
    print(f"Using {args.topic_mode} for topic labeling")
    cc = ClusterClassifier(
        embed_device=args.device,
        topic_mode=args.topic_mode,
        summary_template=template,
        summary_instruction=instruction,
        dbscan_eps=args.dbscan_eps,
        dbscan_min_samples=args.dbscan_min_samples,
    )

    if args.mode == "run":
        # Fit a new pipeline on a (deterministically shuffled) dataset slice.
        ds = _load_train_split(args).shuffle(seed=42)
        print(ds)
        # NOTE(review): "run" requires start > 0 while "infer" accepts
        # start >= 0; with start == 0 and an explicit end, this branch
        # silently falls back to the first n_samples. Confirm whether the
        # strict comparison is intentional.
        indexes = _select_indexes(args, include_zero_start=False)
        text_start = f" starting from {args.start}" if args.start > 0 else ""
        print(f"Processing {len(indexes)} samples{text_start}")

        texts = ds.select(indexes)[args.input_content]

        _, _, summaries = cc.fit(texts)
        print(f"10 example Summaries:\n{list(summaries.values())[:10]}")

        cc.save(args.save_load_path)
        print(f"Saved clusters in {args.save_load_path}.")

        if args.build_hf_ds:
            build_and_push(cc, args)

        _maybe_plot_distributions(args)

    elif args.mode == "infer":
        # Label a dataset slice using an existing, previously saved pipeline.
        cc.load(args.save_load_path)
        indexes = _select_indexes(args, include_zero_start=True)
        text_start = f" starting from {args.start}" if args.start >= 0 else ""
        print(
            f"Running inference on {len(indexes)} samples{text_start} of {args.input_dataset} using clusters in {args.save_load_path}."
        )
        ds = _load_train_split(args)
        texts = ds.select(indexes)[args.input_content]

        start_time = time.time()
        cluster_labels, _ = cc.infer(texts, top_k=1)

        ds = build_hf_data_clusters(cc, texts, cluster_labels)
        print(f"Total time is {(time.time() - start_time)/60}min")
        target_repo = f"{args.username}/{args.inference_repo_name}"
        print(f"Samples with clusters: {ds}")
        print(f"Pushing to hub at {target_repo}...")
        ds.push_to_hub(target_repo, private=True)

    else:
        # "load" mode: rebuild/push the HF dataset from a saved pipeline.
        if args.build_hf_ds:
            cc.load(args.save_load_path)
            build_and_push(cc, args)
            _maybe_plot_distributions(args)
        else:
            print("Using mode=load but build_hf_ds is False, nothing to be done.")

    print("Done 🎉")


def _load_train_split(args):
    """Load the train split of ``args.input_dataset``.

    Passes ``args.data_subset`` as the dataset config name only when it is
    truthy, so datasets without subsets keep working.
    """
    dataset_args = (
        (args.input_dataset, args.data_subset)
        if args.data_subset
        else (args.input_dataset,)
    )
    return load_dataset(*dataset_args, split="train", token=True)


def _select_indexes(args, include_zero_start):
    """Return the sample index range for the current invocation.

    Uses ``range(start, end)`` when ``start`` qualifies (``>= 0`` if
    ``include_zero_start`` else ``> 0``); otherwise the first
    ``args.n_samples`` samples.
    """
    use_window = args.start >= 0 if include_zero_start else args.start > 0
    return range(args.start, args.end) if use_window else range(args.n_samples)


def _maybe_plot_distributions(args):
    """Save score/file distribution plots, single-topic mode only."""
    ds_path = f"{args.username}/{args.save_load_path.split('/')[-1]}"
    if args.topic_mode == "single_topic":
        plot_distributions(ds_path, image_path=args.save_load_path)
        print("📊 Saved plots for educational score and files distribution.")