in scripts/merge_chunks.py [0:0]
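# Imports used by `main()` below (a sketch assuming the usual sources; `Args`,
# `get_dataset_revisions`, and `load_single_revision` are project-local helpers
# assumed to be defined elsewhere in this script):
from multiprocessing import Pool, cpu_count

from datasets import concatenate_datasets
from tqdm import tqdm
from transformers import HfArgumentParser
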
def main():
    parser = HfArgumentParser(Args)
    args = parser.parse_args_into_dataclasses()[0]
    revisions = get_dataset_revisions(args.dataset_name)
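    # Keep only revisions whose names contain every requested filter string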
    if args.filter_strings:
        revisions = [
            revision
            for revision in revisions
            if all(filter_string in revision for filter_string in args.filter_strings)
        ]
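    # Chunk revisions are named `<config>--chunk...`, so the shared prefix is the merged config name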
    merged_config = revisions[0].split("--chunk")[0]
    print(f"Merging {len(revisions)} revisions to create config `{merged_config}`")
    # Prepare arguments for multiprocessing
    pool_args = [
        (args.dataset_name, revision, args.dataset_split) for revision in revisions
    ]
    # Use multiprocessing to load datasets in parallel
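    # (`imap` yields results in the order of `pool_args`, so chunk order is preserved)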
    with Pool(cpu_count()) as pool:
        datasets = list(
            tqdm(
                pool.imap(load_single_revision, pool_args),
                total=len(revisions),
                desc="Loading datasets",
            )
        )
    # Concatenate datasets
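    # (assumes every chunk shares the same columns/features)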
    merged_dataset = concatenate_datasets(datasets)
    # Sanity checks: no duplicate problems, and the expected sample counts for the MATH configs
    if "problem" in merged_dataset.column_names and len(
        merged_dataset.unique("problem")
    ) != len(merged_dataset):
        raise ValueError("Found duplicate problems")
    if "lighteval_MATH" in merged_config and len(merged_dataset) != 5000:
        raise ValueError(f"Expected 5000 samples, got {len(merged_dataset)}")
    if "MATH-500" in merged_config and len(merged_dataset) != 500:
        raise ValueError(f"Expected 500 samples, got {len(merged_dataset)}")
    # Push merged dataset to the hub
    url = merged_dataset.push_to_hub(
        args.dataset_name,
        config_name=merged_config,
        split=args.dataset_split,
        private=args.hub_dataset_private,
    )
    print(f"Pushed merged dataset to {url}")