# utils/hf_dataset_subsampling.py
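"""Subsample a Hugging Face dataset to one or more fractions of its total size in bytes.

For each requested ratio, the script writes out the longest prefix of the (optionally
pre-shuffled) dataset whose cumulative UTF-8 byte size stays below that fraction of the
total, as a JSON-lines file.

Illustrative invocation (the file path and subset names are examples, not defaults):

    python utils/hf_dataset_subsampling.py --name json --data_files data/train.jsonl --ratios 0.5 0.25 --names half quarter
"""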
import argparse
import os
import shlex
import subprocess
from typing import Dict, List

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc  # explicit import so the compute module is always available
from datasets import Dataset, concatenate_datasets, load_dataset
from tqdm import tqdm

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', type=str, required=True,
                        help="Dataset name or path on the HF Hub. Pass e.g. `csv` or `json` together with `--data_files` to load local files")
    parser.add_argument('--subset', type=str, default=None, help="Subset (configuration) of the dataset, if needed")
    parser.add_argument('--data_files', type=str, default=None, help="Path to the dataset on disk when loading local files")
    parser.add_argument('--ratios', nargs='+', type=float, required=True, help="Subsampling ratios")
    parser.add_argument('--names', nargs='+', type=str, required=False, help="Names for the produced subsets")
    parser.add_argument('--pre_shuffle', action="store_true", help="Whether to shuffle the dataset before subsampling")
    parser.add_argument('--shuffle_seed', type=int, default=0, help="Shuffling seed")
    return parser.parse_args()

def get_size_per_example(texts: List[str]) -> Dict:
    # UTF-8 encoded byte length of every text in the batch.
    size_values = [len(text.encode()) for text in texts]
    examples = {"bytes_len": size_values}
    return examples


def get_total_byte_size(dataset):
    # Total number of text bytes in the dataset, from the precomputed `bytes_len` column.
    return pc.sum(dataset.data["bytes_len"]).as_py()

def output_path(args, ratio, name):
    if name is None:
        name = f"{ratio}_subsample"
    if args.data_files is not None:
        # Drop the extension of the input file (assumes there is one) and append the subset name.
        path = ".".join(args.data_files.split(".")[:-1])
        path += f"_{name}"
        path += ".jsonl"
    else:
        path = f"{args.name}_{args.subset}_{name}.jsonl"
    return os.path.abspath(path)
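
# Illustrative examples of the naming scheme above (hypothetical inputs; paths shown before
# `os.path.abspath` is applied):
#   --data_files data/train.jsonl --ratios 0.5          -> data/train_0.5_subsample.jsonl
#   --name c4 --subset en --ratios 0.25 --names quarter -> c4_en_quarter.jsonl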

if __name__ == "__main__":
    args = get_args()
    if args.names is None:
        args.names = [None] * len(args.ratios)
    else:
        assert len(args.names) == len(args.ratios)

    dataset = load_dataset(args.name, args.subset, data_files=args.data_files, num_proc=os.cpu_count(), split="train")
    # Annotate every example with its byte length so subsets can be cut by data size rather than example count.
    dataset = dataset.map(
        get_size_per_example,
        batched=True,
        num_proc=os.cpu_count(),
        batch_size=1024,
        input_columns=["text"],
    )
    if args.pre_shuffle:
        # This is going to be incredibly slow on large datasets.
        dataset = dataset.shuffle(args.shuffle_seed)
        dataset = dataset.flatten_indices(num_proc=os.cpu_count())

    # Cumulative byte sizes let us find the cutoff index for each ratio with a binary search.
    cumsum_sizes = pc.cumulative_sum(dataset.data["bytes_len"])
    cumsum_ds = Dataset(pa.Table.from_arrays([cumsum_sizes], names=["cumsum_sizes"]))
    dataset = concatenate_datasets([dataset, cumsum_ds], axis=1)
    total_size = dataset[-1]["cumsum_sizes"]
    dataset = dataset.with_format("numpy")

    # Process ratios from largest to smallest so each smaller subset is a prefix of the previously written file.
    ratios_and_names = sorted(zip(args.ratios, args.names), key=lambda x: x[0], reverse=True)
    base_file = args.data_files
    assert dataset._indices is None
    for ratio, name in tqdm(ratios_and_names):
        # Index of the first example whose cumulative size reaches `total_size * ratio`.
        cutoff_point = np.searchsorted(dataset["cumsum_sizes"], total_size * ratio)
        if base_file is None:
            subset = dataset.select(range(cutoff_point)).remove_columns(["bytes_len", "cumsum_sizes"])
            assert subset._indices is None
            subset.to_json(output_path(args, ratio, name), num_proc=64, batch_size=100_000)
            base_file = output_path(args, ratio, name)
        else:
            # Reuse `head` to take the first `cutoff_point` lines of the line-delimited base file
            # (the original `--data_files` input or the previously written, larger subset); this
            # assumes the dataset row order matches the file's line order.
            with open(output_path(args, ratio, name), "w") as out_file:
                subprocess.run(shlex.split(f"head -{cutoff_point} {base_file}"), stdout=out_file, check=True)