in datasets.py [0:0]
def split_command(dataset: str, path: str) -> None:
    """
    Split a dataset into train / validation / test index lists and write
    them as JSON to ``path``.

    Args:
        dataset: One of ``"ljspeech"``, ``"libritts"``, or ``"vctk"``.
        path: Root directory of the already-downloaded dataset; the split
            JSON file is also written here.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset == "ljspeech":
        # LJSpeech has no official split: first N samples -> test, next
        # M -> validation, remainder -> train (constants defined at module
        # level).
        dset = torchaudio.datasets.LJSPEECH(root=path, download=False)
        num_samples = len(dset)
        split = {
            "dataset": "ljspeech",
            "test": list(range(0, NUM_LJSPEECH_TEST_SAMPLES)),
            "validation": list(
                range(
                    NUM_LJSPEECH_TEST_SAMPLES,
                    NUM_LJSPEECH_TEST_SAMPLES + NUM_LJSPEECH_VALIDATION_SAMPLES,
                )
            ),
            "train": list(
                range(
                    NUM_LJSPEECH_TEST_SAMPLES + NUM_LJSPEECH_VALIDATION_SAMPLES,
                    num_samples,
                )
            ),
        }
    elif dataset == "libritts":
        # LibriTTS ships with named subsets; each of our splits is the
        # concatenation of its subsets, indexed 0..len-1.
        dset_train = torch.utils.data.ConcatDataset(
            [
                torchaudio.datasets.LIBRITTS(root=path, url=url, download=False)
                for url in LIBRITTS_TRAIN_SPLITS
            ]
        )
        dset_validation = torch.utils.data.ConcatDataset(
            [
                torchaudio.datasets.LIBRITTS(root=path, url=url, download=False)
                for url in LIBRITTS_VALIDATION_SPLITS
            ]
        )
        dset_test = torch.utils.data.ConcatDataset(
            [
                torchaudio.datasets.LIBRITTS(root=path, url=url, download=False)
                for url in LIBRITTS_TEST_SPLITS
            ]
        )
        num_train_samples = len(dset_train)
        num_validation_samples = len(dset_validation)
        num_test_samples = len(dset_test)
        split = {
            "dataset": "libritts",
            "train": list(range(num_train_samples)),
            "validation": list(range(num_validation_samples)),
            "test": list(range(num_test_samples)),
        }
    elif dataset == "vctk":
        # VCTK has no official split: shuffle with a fixed seed (42) so the
        # split is reproducible, then carve percentage-based slices.
        dset = torchaudio.datasets.VCTK_092(root=path, download=False)
        num_samples = len(dset)
        num_train_samples = int(num_samples * VCTK_TRAIN_SPLIT_PRC)
        num_validation_samples = int(num_samples * VCTK_VALIDATION_SPLIT_PRC)
        indices = list(range(num_samples))
        random.Random(42).shuffle(indices)
        split = {
            "dataset": "vctk",
            "train": indices[0:num_train_samples],
            "validation": indices[
                num_train_samples : num_train_samples + num_validation_samples
            ],
            "test": indices[num_train_samples + num_validation_samples :],
        }
    else:
        # Previously an unknown dataset fell through, leaving `split`
        # unbound and crashing with an opaque NameError at json.dump.
        # Fail fast with a clear message instead.
        raise ValueError(f"Unknown dataset: {dataset!r}")
    with open(os.path.join(path, SPLIT_JSON), "w") as handle:
        json.dump(split, handle)  # pyre-ignore