datasets/spd_datasets.py (180 lines of code) (raw):

import argparse
import copy
import glob
import os

from moviepy.editor import AudioFileClip
from pydub import AudioSegment

from diarizers import SpeakerDiarizationDataset

# Supported AMI annotation setups.
AMI_SETUPS = ("only_words", "mini")
# Audio-file suffix used by the AMI corpus for each microphone setting.
AMI_AUDIO_SUFFIXES = {"ihm": "Mix-Headset", "sdm": "Array1-01"}
# Split names returned by get_ami_files -> folder names in AMI-diarization-setup.
AMI_SPLITS = {"train": "train", "validation": "dev", "test": "test"}

CALLHOME_LANGUAGES = ["eng", "jpn", "spa", "zho", "deu"]
# NOTE(review): the original list contained "eng-s" twice; the first entry is
# assumed to have been meant as "eng-n" (the other CallFriend English subset) — confirm.
CALLFRIEND_LANGUAGES = ["eng-n", "eng-s", "fra-q", "jpn", "spa", "spa-c", "zho-m"]

DATASETS = ["ami", "callhome", "simsamu", "voxconverse", "sakura", "callfriend"]


def MP4ToMP3(mp4, mp3):
    """Extract the audio track of the video at ``mp4`` and write it to ``mp3``.

    Args:
        mp4: path of the source video file.
        mp3: path of the audio file to write.
    """
    clip = AudioFileClip(mp4)
    try:
        clip.write_audiofile(mp3)
    finally:
        # Release the underlying reader even if writing fails.
        clip.close()


def _stem(path):
    """Return the file name of ``path`` without its directory and final extension."""
    return os.path.splitext(os.path.basename(path))[0]


def _pair_audio_with_cha(audio_paths, cha_dir):
    """Pair every audio file with the CHA transcript of the same stem in ``cha_dir``.

    The CHA paths are built, not checked for existence.

    Returns:
        ``(audio_files, cha_files)``: two dicts with a single ``"data"`` key whose
        lists are index-aligned (sorted by audio path).
    """
    audio_paths = sorted(audio_paths)
    cha_paths = [os.path.join(cha_dir, _stem(path) + ".cha") for path in audio_paths]
    return {"data": audio_paths}, {"data": cha_paths}


def get_ami_files(path_to_ami, setup="only_words", hm_type="ihm"):
    """Collect paired AMI audio and RTTM files for each split.

    Args:
        path_to_ami: directory containing the ``AMI-diarization-setup`` folder.
        setup: annotation setup, ``"only_words"`` or ``"mini"``.
        hm_type: ``"ihm"`` (``<meeting>.Mix-Headset.wav``) or
            ``"sdm"`` (``<meeting>.Array1-01.wav``).

    Returns:
        ``(audio_files, rttm_files)``: two dicts keyed by ``"train"``,
        ``"validation"`` and ``"test"`` whose lists are index-aligned. Meetings
        whose audio file does not exist are dropped from both.

    Raises:
        ValueError: if ``setup`` or ``hm_type`` is not supported.
    """
    if setup not in AMI_SETUPS:
        raise ValueError("setup must be one of {}, got {!r}".format(AMI_SETUPS, setup))
    if hm_type not in AMI_AUDIO_SUFFIXES:
        raise ValueError(
            "hm_type must be one of {}, got {!r}".format(tuple(AMI_AUDIO_SUFFIXES), hm_type)
        )

    root = os.path.join(path_to_ami, "AMI-diarization-setup")
    suffix = AMI_AUDIO_SUFFIXES[hm_type]
    audio_files = {}
    rttm_files = {}
    for subset, folder in AMI_SPLITS.items():
        audio_files[subset] = []
        rttm_files[subset] = []
        for rttm in sorted(glob.glob(os.path.join(root, setup, "rttms", folder, "*.rttm"))):
            meeting = _stem(rttm)
            audio = os.path.join(
                root, "pyannote", "amicorpus", meeting, "audio", "{}.{}.wav".format(meeting, suffix)
            )
            # Keep the two lists aligned: skip annotations without audio.
            if os.path.exists(audio):
                audio_files[subset].append(audio)
                rttm_files[subset].append(rttm)
    return audio_files, rttm_files


def get_callhome_files(path_to_callhome, langage="jpn"):
    """Pair CallHome ``<lang>/*.mp3`` files with their ``<lang>/<stem>.cha`` transcripts."""
    lang_dir = os.path.join(path_to_callhome, "callhome", langage)
    return _pair_audio_with_cha(glob.glob(os.path.join(lang_dir, "*.mp3")), lang_dir)


def get_callfriends_files(path_to_callfriend, langage="jpn"):
    """Pair CallFriend ``<lang>/audio/*.mp3`` files with ``<lang>/cha/<stem>.cha``."""
    lang_dir = os.path.join(path_to_callfriend, "callfriend", langage)
    audio_paths = glob.glob(os.path.join(lang_dir, "audio", "*.mp3"))
    return _pair_audio_with_cha(audio_paths, os.path.join(lang_dir, "cha"))


def get_sakura_files(path_to_sakura, convert_mp4_to_mp3=False):
    """Pair Sakura ``audio/*.mp3`` files with ``cha/<stem>.cha`` transcripts.

    Args:
        path_to_sakura: directory containing the ``sakura`` folder.
        convert_mp4_to_mp3: if True, first write an ``.mp3`` next to every
            ``audio/*.mp4`` file.
    """
    sakura_dir = os.path.join(path_to_sakura, "sakura")
    audio_dir = os.path.join(sakura_dir, "audio")
    if convert_mp4_to_mp3:
        for mp4_path in glob.glob(os.path.join(audio_dir, "*.mp4")):
            # splitext (not split(".")) so dots in directory names are preserved.
            MP4ToMP3(mp4_path, os.path.splitext(mp4_path)[0] + ".mp3")
    audio_paths = glob.glob(os.path.join(audio_dir, "*.mp3"))
    return _pair_audio_with_cha(audio_paths, os.path.join(sakura_dir, "cha"))


def get_simsamu_files(path_to_simsamu):
    """Convert SimSAMU ``*/*.m4a`` recordings to WAV and collect WAV and RTTM files.

    Returns:
        ``(audio_files, rttm_files)``: dicts with a single ``"data"`` key holding
        sorted path lists.
    """
    simsamu_dir = os.path.join(path_to_simsamu, "simsamu")
    for m4a_path in glob.glob(os.path.join(simsamu_dir, "*", "*.m4a")):
        sound = AudioSegment.from_file(m4a_path, format="m4a")
        # export() returns the opened output file; close it so it is not leaked.
        sound.export(os.path.splitext(m4a_path)[0] + ".wav", format="wav").close()
    # glob() order is arbitrary; sort both lists so their order is deterministic.
    # NOTE(review): assumes each recording's WAV and RTTM sort to the same position — confirm.
    audio_files = {"data": sorted(glob.glob(os.path.join(simsamu_dir, "*", "*.wav")))}
    rttm_files = {"data": sorted(glob.glob(os.path.join(simsamu_dir, "*", "*.rttm")))}
    return audio_files, rttm_files


def get_voxconverse_files(path_to_voxconverse):
    """Collect VoxConverse dev/test WAV and RTTM files.

    Returns:
        ``(audio_files, rttm_files)``: dicts keyed by ``"dev"`` and ``"test"``
        holding sorted path lists.
    """
    root = os.path.join(path_to_voxconverse, "voxconverse")
    # glob() order is arbitrary; sort so audio and RTTM lists have a deterministic
    # order (same-named files then line up by position).
    rttm_files = {
        "dev": sorted(glob.glob(os.path.join(root, "dev", "*.rttm"))),
        "test": sorted(glob.glob(os.path.join(root, "test", "*.rttm"))),
    }
    audio_files = {
        "dev": sorted(glob.glob(os.path.join(root, "audio", "*.wav"))),
        "test": sorted(glob.glob(os.path.join(root, "voxconverse_test_wav", "*.wav"))),
    }
    return audio_files, rttm_files


def _build_cha_dataset(audio_files, cha_files):
    """Construct a dataset from CHA annotations, cropping unannotated regions."""
    return SpeakerDiarizationDataset(
        audio_files, cha_files, annotations_type="cha", crop_unannotated_regions=True
    ).construct_dataset()


def main():
    """Build the dataset selected on the command line and optionally push it to the Hub."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", required=True, choices=DATASETS)
    parser.add_argument("--path_to_dataset", required=True)
    parser.add_argument("--setup", required=False, default="only_words")
    parser.add_argument("--push_to_hub", required=False, default=False)
    parser.add_argument("--hub_repository", required=False)
    args = parser.parse_args()

    # --push_to_hub takes a string value; only the literal "True" enables pushing.
    push = args.push_to_hub == "True"
    if push and not args.hub_repository:
        # Fail before the (slow) dataset construction rather than at push time.
        parser.error("--hub_repository is required when --push_to_hub is True")

    if args.dataset == "ami":
        for hm_type in ("ihm", "sdm"):
            audio_files, rttm_files = get_ami_files(
                path_to_ami=args.path_to_dataset, setup=args.setup, hm_type=hm_type
            )
            dataset = SpeakerDiarizationDataset(audio_files, rttm_files).construct_dataset()
            if push:
                dataset.push_to_hub(args.hub_repository, hm_type)

    elif args.dataset == "callhome":
        for langage in CALLHOME_LANGUAGES:
            dataset = _build_cha_dataset(*get_callhome_files(args.path_to_dataset, langage=langage))
            if push:
                dataset.push_to_hub(args.hub_repository, str(langage))

    elif args.dataset == "simsamu":
        audio_files, rttm_files = get_simsamu_files(args.path_to_dataset)
        dataset = SpeakerDiarizationDataset(audio_files, rttm_files).construct_dataset()
        if push:
            dataset.push_to_hub(args.hub_repository)

    elif args.dataset == "voxconverse":
        audio_files, rttm_files = get_voxconverse_files(args.path_to_dataset)
        dataset = SpeakerDiarizationDataset(audio_files, rttm_files).construct_dataset()
        if push:
            dataset.push_to_hub(args.hub_repository)

    elif args.dataset == "sakura":
        # Previously built `sakura_dataset` but pushed an undefined `dataset`.
        dataset = _build_cha_dataset(*get_sakura_files(args.path_to_dataset))
        if push:
            dataset.push_to_hub(args.hub_repository)

    elif args.dataset == "callfriend":
        for langage in CALLFRIEND_LANGUAGES:
            dataset = _build_cha_dataset(*get_callfriends_files(args.path_to_dataset, langage=langage))
            if push:
                dataset.push_to_hub(args.hub_repository, str(langage))


if __name__ == "__main__":
    main()