def main()

in ultravox/tools/ds_tool/ds_tool.py [0:0]


def main(args: DatasetToolArgs):
    """Load a dataset, optionally shuffle/subsample each split, then process and upload it.

    Args:
        args: Parsed tool arguments. Fields read directly here are
            ``dataset_name``, ``dataset_subset``, ``dataset_split``,
            ``dataset_version``, ``num_workers``, ``upload_split``,
            ``shuffle``, ``shuffle_seed``, ``num_samples`` and ``task``
            (``task`` is only printed). The whole object is also passed to
            ``DatasetChunkProcessor``.

    Raises:
        ValueError: if more than one split was loaded while ``upload_split``
            is set, because several source splits cannot be written to a
            single destination split.
    """
    ds_name = args.dataset_name
    print(f'Loading dataset "{ds_name}" for task {args.task}')
    download_config = datasets.DownloadConfig(num_proc=args.num_workers, max_retries=2)
    data_dict: datasets.DatasetDict = datasets.load_dataset(
        ds_name,
        args.dataset_subset,
        split=args.dataset_split,
        download_config=download_config,
        revision=args.dataset_version,
    )

    # `load_dataset` returns a bare `Dataset` when a single split is requested.
    # Normalize to a DatasetDict so the loop below treats both cases the same.
    # NOTE(review): this keys the split by `args.upload_split`; confirm it is
    # always set (e.g. defaulted from `dataset_split`) whenever a split is given.
    if isinstance(data_dict, datasets.Dataset):
        data_dict = datasets.DatasetDict({args.upload_split: data_dict})

    if len(data_dict) > 1 and args.upload_split:
        raise ValueError("Cannot upload multiple splits to a single split")

    ds_chunk_proc = DatasetChunkProcessor(args)

    for split_name, ds_split in data_dict.items():
        print(
            f"Processing dataset: {ds_name}, subset {args.dataset_subset}, split {split_name}, containing {len(ds_split)} samples"
        )
        if args.shuffle:
            ds_split = ds_split.shuffle(seed=args.shuffle_seed)
        if args.num_samples:
            # Clamp so that a split smaller than `num_samples` is kept whole.
            # Without the clamp, `select` is asked for out-of-range indices.
            num_to_keep = min(args.num_samples, len(ds_split))
            ds_split = ds_split.select(range(num_to_keep))

        # Hand the full [0, len) index range of this split to the processor.
        ds_chunk_proc.process_and_upload_split_rescursive(
            split_name, ds_split, 0, len(ds_split)
        )