in sing/nsynth/__init__.py [0:0]
def make_datasets(dataset, valid_ratio=0.1, test_ratio=0.1, random_seed=42):
    """
    Take the original NSynth training dataset and split it into
    a train, valid and test set making sure that for a given instrument,
    a pitch is present in only one dataset (each pair of instrument and pitch
    has multiple occurrences, one for each velocity).

    Every (instrument, pitch) group is assigned as a whole. A uniform random
    draw below `valid_ratio` sends the group to valid, a draw below
    `valid_ratio + test_ratio` sends it to test, and any other draw sends it
    to train. The resulting proportions therefore only approximate the
    requested ratios.

    Arguments:
        dataset (NSynthDataset or NSynthMetadata): dataset to split. For an
            NSynthDataset, its `metadata` attribute is used for grouping.
        valid_ratio (float): probability for a group to go to the valid set.
        test_ratio (float): probability for a group to go to the test set.
        random_seed (int): seed given to `utils.random_seed_manager` so that
            the split is reproducible.

    Returns:
        tuple: `(train, valid, test)`, three `DatasetSubset` over `dataset`.

    Raises:
        ValueError: if `dataset` is neither an NSynthDataset nor an
            NSynthMetadata, or if a ratio is negative or the two ratios
            sum to more than 1.
    """
    per_pitch_instrument = defaultdict(list)
    if isinstance(dataset, NSynthDataset):
        metadata = dataset.metadata
    elif isinstance(dataset, NSynthMetadata):
        metadata = dataset
    else:
        raise ValueError(
            "Invalid dataset {}, should be an instance of "
            "either NSynthDataset or NSynthMetadata.".format(dataset))
    # Out-of-range ratios would otherwise silently yield a degenerate split
    # (e.g. an empty train set when the ratios sum to more than 1).
    if valid_ratio < 0 or test_ratio < 0 or valid_ratio + test_ratio > 1:
        raise ValueError(
            "Invalid ratios valid_ratio={} and test_ratio={}, both should be "
            "non-negative and sum to at most 1.".format(
                valid_ratio, test_ratio))

    # Group indexes by (instrument, pitch). The indexes are later used to
    # build DatasetSubset over `dataset`; this presumably relies on
    # `dataset.metadata` sharing the indexing of `dataset` (verify).
    for index in range(len(metadata)):
        item = metadata[index]
        per_pitch_instrument[(item.metadata['instrument'],
                              item.metadata['pitch'])].append(index)

    with utils.random_seed_manager(random_seed):
        train = []
        valid = []
        test = []
        # One draw per group, in insertion order, keeps the split
        # deterministic for a given seed.
        for indexes in per_pitch_instrument.values():
            score = random.random()
            if score < valid_ratio:
                valid.extend(indexes)
            elif score < valid_ratio + test_ratio:
                test.extend(indexes)
            else:
                train.extend(indexes)
    return DatasetSubset(dataset, train), DatasetSubset(
        dataset, valid), DatasetSubset(dataset, test)