scripts/per_dataset_script/clean_libritts_r.py (47 lines of code) (raw):

"""Clean a LibriTTS-R dataset.

Two filters are applied to every split of the loaded dataset:

1. rows whose speaker id is in a fixed exclusion list are dropped;
2. rows whose ``id`` column value is listed in one of the "bad_sample"
   text files found in ``bad_samples_folder`` are dropped.

The result can optionally be saved to disk and/or pushed to the hub.
"""

import argparse
import os
from os import listdir

import pandas as pd
from datasets import load_dataset
from multiprocess import set_start_method


def filter_speakers(speaker, speakers_to_remove):
    """Return True when ``speaker`` (cast to int) is not in ``speakers_to_remove``."""
    return int(speaker) not in speakers_to_remove


def filter_samples(sample_id, samples_to_filter):
    """Return True when ``sample_id`` is not in ``samples_to_filter``."""
    return sample_id not in samples_to_filter


def _sample_id_from_line(line):
    """Extract a sample id from one line of a bad-sample list.

    The id is the last "/"-separated component of the stripped line, cut at
    its first ".". An empty or whitespace-only line yields "".
    """
    return line.strip().split("/")[-1].split(".")[0]


def load_bad_sample_ids(bad_samples_folder):
    """Collect the sample ids listed in the bad-sample files of a folder.

    Args:
        bad_samples_folder: Directory to scan. Only regular files whose name
            contains "bad_sample" are read.

    Returns:
        A set of non-empty sample ids (see ``_sample_id_from_line``).

    Raises:
        OSError: If the folder cannot be listed or a file cannot be read.
    """
    sample_ids = set()
    for entry in listdir(bad_samples_folder):
        if "bad_sample" not in entry:
            continue
        path = os.path.join(bad_samples_folder, entry)
        # A sub-directory with a matching name would make open() raise.
        if not os.path.isfile(path):
            continue
        with open(path, "r", encoding="utf-8") as file:
            for line in file:
                sample_id = _sample_id_from_line(line)
                # Blank lines would otherwise add "" and inflate the count.
                if sample_id:
                    sample_ids.add(sample_id)
    return sample_ids


def _build_parser():
    """Build the command-line argument parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_name", type=str, help="Repo id or local path.")
    parser.add_argument("bad_samples_folder", type=str, help="Path to LibriTTS-R bad folder samples.")
    parser.add_argument("--configuration", default=None, type=str, help="Dataset configuration to use.")
    parser.add_argument("--output_dir", default=None, type=str, help="If specified, save the dataset on disk.")
    parser.add_argument("--repo_id", default=None, type=str, help="If specified, push the dataset to the hub.")
    parser.add_argument("--speaker_id_column_name", default="speaker_id", type=str, help="Speaker id column name.")
    parser.add_argument(
        "--cpu_num_workers",
        default=1,
        type=int,
        help="Number of CPU workers for transformations that don't use GPUs or if no GPU are available.",
    )
    return parser


def main():
    """Parse arguments, filter the dataset, then save and/or push it."""
    set_start_method("spawn")
    args = _build_parser().parse_args()

    # Read the bad-sample lists first so that a wrong folder path fails
    # before the dataset is loaded and filtered.
    samples_to_filter = load_bad_sample_ids(args.bad_samples_folder)

    if args.configuration:
        dataset = load_dataset(args.dataset_name, args.configuration)
    else:
        dataset = load_dataset(args.dataset_name)

    speaker_id_column_name = args.speaker_id_column_name

    # speakers to exclude because of mixed gender detection
    # cf: https://github.com/line/LibriTTS-P/blob/main/data/excluded_spk_list.txt
    speakers_to_remove = {2074, 4455, 6032, 3546, 2262, 8097, 1734, 3793, 8295}

    print(dataset)
    dataset = dataset.filter(
        filter_speakers,
        input_columns=speaker_id_column_name,
        num_proc=args.cpu_num_workers,
        fn_kwargs={"speakers_to_remove": speakers_to_remove},
    )
    print(dataset)

    print(len(samples_to_filter))
    dataset = dataset.filter(
        filter_samples,
        input_columns="id",
        num_proc=args.cpu_num_workers,
        fn_kwargs={"samples_to_filter": samples_to_filter},
    )
    print(dataset)

    if args.output_dir:
        dataset.save_to_disk(args.output_dir)

    if args.repo_id:
        if args.configuration:
            dataset.push_to_hub(args.repo_id, args.configuration)
        else:
            dataset.push_to_hub(args.repo_id)


if __name__ == "__main__":
    main()