scripts/merge_chunks.py (63 lines of code) (raw):

#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Merge revisions of a dataset into a single config.

Every selected revision is loaded, the results are concatenated, sanity
checked, and pushed back to the same Hub repository as one config. The config
is named after the part of the revision names that precedes ``--chunk``; all
selected revisions must share that prefix.

Usage:

# Merge all revisions of a dataset for a given seed
python scripts/merge_chunks.py \
    --dataset_name HuggingFaceH4/Llama-3.2-1B-Instruct-best-of-N-completions \
    --filter_strings seed-0

# Merge only revisions whose name contains all of "last", "T-0.0" and "seed-0"
python scripts/merge_chunks.py \
    --dataset_name HuggingFaceH4/Llama-3.2-1B-Instruct-best-of-N-completions \
    --filter_strings last T-0.0 seed-0
"""

from dataclasses import dataclass, field
from multiprocessing import Pool, cpu_count
from typing import List

from datasets import concatenate_datasets, load_dataset
from tqdm.auto import tqdm
from transformers import HfArgumentParser

from sal.utils.hub import get_dataset_revisions


@dataclass
class Args:
    """Command-line arguments for the merge script."""

    # Hub repository id whose revisions are merged (and where the result is pushed).
    dataset_name: str
    # Split loaded from each revision; also the split name of the merged config.
    dataset_split: str = "train"
    # Keep only revisions whose name contains *every* one of these substrings.
    filter_strings: List[str] = field(default_factory=list)
    # Passed as `private=` when pushing the merged dataset.
    hub_dataset_private: bool = False


def load_single_revision(args):
    """Load one split of one dataset revision.

    Takes a single tuple so it can be mapped directly with ``Pool.imap``.

    Args:
        args: ``(dataset_name, revision, dataset_split)``.

    Returns:
        The dataset returned by ``load_dataset`` for that revision and split.
    """
    dataset_name, revision, dataset_split = args
    return load_dataset(
        dataset_name,
        revision=revision,
        trust_remote_code=True,
        split=dataset_split,
        # Presumably forced so a stale local cache of a re-pushed revision is never reused.
        download_mode="force_redownload",
    )


def _filter_revisions(revisions, filter_strings):
    """Return the revisions whose name contains every string in ``filter_strings``.

    With no filter strings every revision is kept.
    """
    return [
        revision
        for revision in revisions
        if all(filter_string in revision for filter_string in filter_strings)
    ]


def _merged_config_name(revisions):
    """Return the config name shared by ``revisions`` (the text before ``--chunk``).

    Raises:
        ValueError: if the revisions do not all share the same prefix, since
            merging them would mix different configs under a single name.
    """
    config_names = {revision.split("--chunk")[0] for revision in revisions}
    if len(config_names) != 1:
        raise ValueError(
            f"Selected revisions belong to several configs {sorted(config_names)}; "
            "refine --filter_strings so they match a single config"
        )
    return config_names.pop()


def _check_merged_dataset(merged_dataset, merged_config):
    """Raise ``ValueError`` if the merged dataset fails a basic sanity check.

    Checks that problems are unique (when a ``problem`` column exists) and that
    the row count matches the known benchmark sizes for ``lighteval_MATH`` (5000)
    and ``MATH-500`` (500) configs.
    """
    if "problem" in merged_dataset.column_names and len(
        merged_dataset.unique("problem")
    ) != len(merged_dataset):
        raise ValueError("Found duplicate problems")
    if "lighteval_MATH" in merged_config and len(merged_dataset) != 5000:
        raise ValueError(f"Expected 5000 samples, got {len(merged_dataset)}")
    if "MATH-500" in merged_config and len(merged_dataset) != 500:
        raise ValueError(f"Expected 500 samples, got {len(merged_dataset)}")


def main():
    """Select, load, merge, check and push the dataset revisions given on the CLI."""
    parser = HfArgumentParser(Args)
    args = parser.parse_args_into_dataclasses()[0]

    revisions = _filter_revisions(
        get_dataset_revisions(args.dataset_name), args.filter_strings
    )
    if not revisions:
        raise ValueError(
            f"No revisions of {args.dataset_name} match filter strings "
            f"{args.filter_strings}"
        )
    merged_config = _merged_config_name(revisions)
    print(f"Merging {len(revisions)} revisions to create config `{merged_config}`")

    # One argument tuple per revision for load_single_revision.
    pool_args = [
        (args.dataset_name, revision, args.dataset_split) for revision in revisions
    ]
    # Load revisions in parallel; workers beyond the number of revisions would sit idle.
    with Pool(min(cpu_count(), len(revisions))) as pool:
        datasets = list(
            tqdm(
                pool.imap(load_single_revision, pool_args),
                total=len(revisions),
                desc="Loading datasets",
            )
        )

    # imap preserves input order, so rows follow the order of `revisions`.
    merged_dataset = concatenate_datasets(datasets)
    _check_merged_dataset(merged_dataset, merged_config)

    url = merged_dataset.push_to_hub(
        args.dataset_name,
        config_name=merged_config,
        split=args.dataset_split,
        private=args.hub_dataset_private,
    )
    print(f"Pushed merged dataset to {url}")


if __name__ == "__main__":
    main()