# datasets/spd_datasets.py
import argparse
import copy
import glob
import os

from moviepy.editor import AudioFileClip
from pydub import AudioSegment

from diarizers import SpeakerDiarizationDataset


def MP4ToMP3(mp4, mp3):
    """Extract the audio track of an .mp4 video and write it as .mp3."""
    audio_clip = AudioFileClip(mp4)
    audio_clip.write_audiofile(mp3)
    audio_clip.close()
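
# Example (sketch): convert a single recording; "meeting.mp4" is a
# hypothetical path used only to illustrate the call:
#
#   MP4ToMP3("meeting.mp4", "meeting.mp3")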


def get_ami_files(path_to_ami, setup="only_words", hm_type="ihm"):
    """Collect matching (audio, RTTM) files from a local AMI download.

    Returns:
        (audio_files, rttm_files): dicts of file lists keyed by split
        ("train", "validation", "test").
    """
    assert setup in ["only_words", "mini"]
    assert hm_type in ["ihm", "sdm"]
    rttm_files = {
        "train": glob.glob(path_to_ami + "/AMI-diarization-setup/{}/rttms/{}/*.rttm".format(setup, "train")),
        "validation": glob.glob(path_to_ami + "/AMI-diarization-setup/{}/rttms/{}/*.rttm".format(setup, "dev")),
        "test": glob.glob(path_to_ami + "/AMI-diarization-setup/{}/rttms/{}/*.rttm".format(setup, "test")),
    }
    audio_files = {
        "train": [],
        "validation": [],
        "test": [],
    }
    for subset in rttm_files:
        # Iterate over a copy so that RTTMs with no matching audio file can be
        # dropped from rttm_files while looping.
        rttm_list = copy.deepcopy(rttm_files[subset])
        for rttm in rttm_list:
            meeting = rttm.split("/")[-1].split(".")[0]
            if hm_type == "ihm":
                path = path_to_ami + "/AMI-diarization-setup/pyannote/amicorpus/{}/audio/{}.Mix-Headset.wav".format(
                    meeting, meeting
                )
            else:  # hm_type == "sdm"
                path = path_to_ami + "/AMI-diarization-setup/pyannote/amicorpus/{}/audio/{}.Array1-01.wav".format(
                    meeting, meeting
                )
            if os.path.exists(path):
                audio_files[subset].append(path)
            else:
                rttm_files[subset].remove(rttm)
    return audio_files, rttm_files
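
# Example (sketch): build the IHM file lists, assuming AMI was downloaded to a
# hypothetical "/data/ami" folder containing AMI-diarization-setup:
#
#   audio_files, rttm_files = get_ami_files("/data/ami", setup="only_words", hm_type="ihm")
#   print(len(audio_files["train"]), "train meetings with matching RTTMs")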


def get_callhome_files(path_to_callhome, language="jpn"):
    """Pair each CallHome .mp3 recording with its .cha transcript."""
    audio_files = {
        "data": glob.glob(path_to_callhome + "/callhome/{}/*.mp3".format(language)),
    }
    cha_files = {
        "data": [],
    }
    for subset in audio_files:
        for audio_path in audio_files[subset]:
            file = audio_path.split("/")[-1].split(".")[0]
            cha_files[subset].append(path_to_callhome + "/callhome/{}/{}.cha".format(language, file))
    return audio_files, cha_files
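
# Example (sketch): list the Japanese CallHome pairs, assuming a hypothetical
# "/data/callhome" root containing a "callhome/jpn" folder:
#
#   audio_files, cha_files = get_callhome_files("/data/callhome", language="jpn")
#   assert len(audio_files["data"]) == len(cha_files["data"])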


def get_callfriends_files(path_to_callfriend, language="jpn"):
    """Pair each CallFriend .mp3 recording with its .cha transcript."""
    audio_files = {
        "data": glob.glob(path_to_callfriend + "/callfriend/{}/audio/*.mp3".format(language)),
    }
    cha_files = {
        "data": [],
    }
    for subset in audio_files:
        for audio_path in audio_files[subset]:
            file = audio_path.split("/")[-1].split(".")[0]
            cha_files[subset].append(path_to_callfriend + "/callfriend/{}/cha/{}.cha".format(language, file))
    return audio_files, cha_files
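
# Example (sketch): CallFriend is laid out the same way, except transcripts sit
# in a "cha/" subfolder; "/data/callfriend" is again a hypothetical root:
#
#   audio_files, cha_files = get_callfriends_files("/data/callfriend", language="spa")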


def get_sakura_files(path_to_sakura, convert_mp4_to_mp3=False):
    """Pair each Sakura .mp3 recording with its .cha transcript.

    If convert_mp4_to_mp3 is True, the .mp4 videos are first converted to
    .mp3 files alongside the originals.
    """
    if convert_mp4_to_mp3:
        mp4_files = glob.glob(path_to_sakura + "/sakura/audio/*.mp4")
        for mp4_path in mp4_files:
            mp3_path = mp4_path.split(".")[0] + ".mp3"
            MP4ToMP3(mp4_path, mp3_path)
    audio_files = {
        "data": glob.glob(path_to_sakura + "/sakura/audio/*.mp3"),
    }
    cha_files = {
        "data": [],
    }
    for subset in audio_files:
        for audio_path in audio_files[subset]:
            file = audio_path.split("/")[-1].split(".")[0]
            cha_files[subset].append(path_to_sakura + "/sakura/cha/{}.cha".format(file))
    return audio_files, cha_files
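
# Example (sketch): on first use, convert the .mp4 sources before pairing,
# assuming a hypothetical "/data/sakura" root:
#
#   audio_files, cha_files = get_sakura_files("/data/sakura", convert_mp4_to_mp3=True)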


def get_simsamu_files(path_to_simsamu):
    """Convert Simsamu .m4a recordings to .wav and pair them with RTTMs."""
    rttm_files = glob.glob(path_to_simsamu + "/simsamu/*/*.rttm")
    m4a_files = glob.glob(path_to_simsamu + "/simsamu/*/*.m4a")
    for file in m4a_files:
        # pydub decodes the .m4a and writes a .wav next to it.
        sound = AudioSegment.from_file(file, format="m4a")
        sound.export(file.split(".")[0] + ".wav", format="wav")
    audio_files = {"data": glob.glob(path_to_simsamu + "/simsamu/*/*.wav")}
    rttm_files = {"data": rttm_files}
    return audio_files, rttm_files
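
# Example (sketch): assuming a hypothetical "/data/simsamu" root with one
# subfolder per recording:
#
#   audio_files, rttm_files = get_simsamu_files("/data/simsamu")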


def get_voxconverse_files(path_to_voxconverse):
    """Collect VoxConverse dev/test audio and RTTM file lists."""
    rttm_files = {
        "dev": glob.glob(path_to_voxconverse + "/voxconverse/dev/*.rttm"),
        "test": glob.glob(path_to_voxconverse + "/voxconverse/test/*.rttm"),
    }
    audio_files = {
        "dev": glob.glob(path_to_voxconverse + "/voxconverse/audio/*.wav"),
        "test": glob.glob(path_to_voxconverse + "/voxconverse/voxconverse_test_wav/*.wav"),
    }
    return audio_files, rttm_files
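
# Example (sketch): build a Hugging Face dataset from the dev/test splits,
# assuming a hypothetical "/data/voxconverse" root:
#
#   audio_files, rttm_files = get_voxconverse_files("/data/voxconverse")
#   dataset = SpeakerDiarizationDataset(audio_files, rttm_files).construct_dataset()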


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", required=True)
    parser.add_argument("--path_to_dataset", required=True)
    parser.add_argument("--setup", required=False, default="only_words")
    parser.add_argument("--push_to_hub", required=False, default="False")
    parser.add_argument("--hub_repository", required=False)
    args = parser.parse_args()

    if args.dataset == "ami":
        # AMI is pushed as two configs: individual headset microphone (ihm)
        # and single distant microphone (sdm).
        audio_files, rttm_files = get_ami_files(path_to_ami=args.path_to_dataset, setup=args.setup, hm_type="ihm")
        ami_dataset_ihm = SpeakerDiarizationDataset(audio_files, rttm_files).construct_dataset()
        if args.push_to_hub == "True":
            ami_dataset_ihm.push_to_hub(args.hub_repository, "ihm")

        audio_files, rttm_files = get_ami_files(path_to_ami=args.path_to_dataset, setup=args.setup, hm_type="sdm")
        ami_dataset_sdm = SpeakerDiarizationDataset(audio_files, rttm_files).construct_dataset()
        if args.push_to_hub == "True":
            ami_dataset_sdm.push_to_hub(args.hub_repository, "sdm")

    if args.dataset == "callhome":
        languages = ["eng", "jpn", "spa", "zho", "deu"]
        for language in languages:
            audio_files, cha_files = get_callhome_files(args.path_to_dataset, language=language)
            dataset = SpeakerDiarizationDataset(
                audio_files, cha_files, annotations_type="cha", crop_unannotated_regions=True
            ).construct_dataset()
            if args.push_to_hub == "True":
                dataset.push_to_hub(args.hub_repository, language)

    if args.dataset == "simsamu":
        audio_files, rttm_files = get_simsamu_files(args.path_to_dataset)
        dataset = SpeakerDiarizationDataset(audio_files, rttm_files).construct_dataset()
        if args.push_to_hub == "True":
            dataset.push_to_hub(args.hub_repository)

    if args.dataset == "voxconverse":
        audio_files, rttm_files = get_voxconverse_files(args.path_to_dataset)
        dataset = SpeakerDiarizationDataset(audio_files, rttm_files).construct_dataset()
        if args.push_to_hub == "True":
            dataset.push_to_hub(args.hub_repository)

    if args.dataset == "sakura":
        audio_files, cha_files = get_sakura_files(args.path_to_dataset)
        sakura_dataset = SpeakerDiarizationDataset(
            audio_files, cha_files, annotations_type="cha", crop_unannotated_regions=True
        ).construct_dataset()
        if args.push_to_hub == "True":
            sakura_dataset.push_to_hub(args.hub_repository)

    if args.dataset == "callfriend":
        # CallFriend English comes in Northern ("eng-n") and Southern ("eng-s")
        # dialect subsets.
        languages = ["eng-n", "eng-s", "fra-q", "jpn", "spa", "spa-c", "zho-m"]
        for language in languages:
            audio_files, cha_files = get_callfriends_files(args.path_to_dataset, language=language)
            dataset = SpeakerDiarizationDataset(
                audio_files, cha_files, annotations_type="cha", crop_unannotated_regions=True
            ).construct_dataset()
            if args.push_to_hub == "True":
                dataset.push_to_hub(args.hub_repository, language)
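
# Example invocations (sketch); the dataset paths and the Hub repository id
# are placeholders:
#
#   python datasets/spd_datasets.py --dataset ami --path_to_dataset /data/ami \
#       --setup only_words --push_to_hub True --hub_repository your-username/ami
#
#   python datasets/spd_datasets.py --dataset voxconverse --path_to_dataset /data/voxconverse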