def main()

in scripts/merge_chunks.py
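
The excerpt calls an `Args` dataclass and two helpers, `get_dataset_revisions` and `load_single_revision`, that are defined elsewhere in the script. Below is a minimal sketch of what they might look like, inferred only from the call sites in `main()`; the field names, defaults, and the config-per-chunk layout are assumptions, not the actual implementation.

# Hypothetical sketch inferred from the call sites in main(); not the actual implementation.
# The imports cover both these helpers and main() below.
from dataclasses import dataclass, field
from multiprocessing import Pool, cpu_count
from typing import List, Optional

from datasets import Dataset, concatenate_datasets, get_dataset_config_names, load_dataset
from tqdm import tqdm
from transformers import HfArgumentParser


@dataclass
class Args:
    dataset_name: str = field(metadata={"help": "Hub dataset repo containing the chunked configs."})
    dataset_split: str = "train"
    filter_strings: Optional[List[str]] = None
    hub_dataset_private: bool = False


def get_dataset_revisions(dataset_name: str) -> List[str]:
    # Assumes each chunk was pushed as its own config, e.g. `my_config--chunk-0_8`,
    # so listing the config names yields one entry per chunk.
    return get_dataset_config_names(dataset_name)


def load_single_revision(pool_args: tuple) -> Dataset:
    # Unpack the (dataset_name, revision, split) tuple prepared for Pool.imap in main().
    dataset_name, revision, split = pool_args
    return load_dataset(dataset_name, revision, split=split)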


def main():
    parser = HfArgumentParser(Args)
    args = parser.parse_args_into_dataclasses()[0]
    revisions = get_dataset_revisions(args.dataset_name)

    if args.filter_strings:
        revisions = [
            revision
            for revision in revisions
            if all(filter_string in revision for filter_string in args.filter_strings)
        ]

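    # Each chunk's revision name shares a common prefix before `--chunk`;
    # stripping the suffix recovers the name of the merged config.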
    merged_config = revisions[0].split("--chunk")[0]
    print(f"Merging {len(revisions)} revisions to create config `{merged_config}`")

    # Prepare arguments for multiprocessing
    pool_args = [
        (args.dataset_name, revision, args.dataset_split) for revision in revisions
    ]

    # Use multiprocessing to load datasets in parallel
    with Pool(cpu_count()) as pool:
        datasets = list(
            tqdm(
                pool.imap(load_single_revision, pool_args),
                total=len(revisions),
                desc="Loading datasets",
            )
        )

    # Concatenate datasets
    merged_dataset = concatenate_datasets(datasets)

    # Sanity check
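    # The chunks should partition the source data exactly: no duplicate problems,
    # and known benchmarks must recover their full size after merging.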
    if "problem" in merged_dataset.column_names and len(
        merged_dataset.unique("problem")
    ) != len(merged_dataset):
        raise ValueError("Found duplicate problems")
    if "lighteval_MATH" in merged_config and len(merged_dataset) != 5000:
        raise ValueError(f"Expected 5000 samples, got {len(merged_dataset)}")
    if "MATH-500" in merged_config and len(merged_dataset) != 500:
        raise ValueError(f"Expected 500 samples, got {len(merged_dataset)}")

    # Push merged dataset to the hub
    url = merged_dataset.push_to_hub(
        args.dataset_name,
        config_name=merged_config,
        split=args.dataset_split,
        private=args.hub_dataset_private,
    )
    print(f"Pushed merged dataset to {url}")