def make_datasets()

in sing/nsynth/__init__.py [0:0]


def make_datasets(dataset, valid_ratio=0.1, test_ratio=0.1, random_seed=42):
    """
    Split the original NSynth training set into train, valid and test subsets.

    All indexes that share the same (instrument, pitch) pair are grouped
    together. These are the several occurrences of one note, one per
    velocity. Each group is assigned as a whole to exactly one subset, so a
    given pitch of a given instrument never appears in more than one subset.

    Args:
        dataset: either an ``NSynthDataset`` (its ``metadata`` attribute is
            used) or an ``NSynthMetadata`` instance.
        valid_ratio: probability that a group lands in the valid subset.
        test_ratio: probability that a group lands in the test subset.
            Groups that land in neither subset go to train.
        random_seed: value handed to ``utils.random_seed_manager`` around
            the random draws (presumably to make the split reproducible).

    Returns:
        A tuple ``(train, valid, test)`` of ``DatasetSubset`` objects, each
        wrapping ``dataset`` with its list of selected indexes.

    Raises:
        ValueError: if ``dataset`` is neither an ``NSynthDataset`` nor an
            ``NSynthMetadata``.
    """
    if isinstance(dataset, NSynthDataset):
        metadata = dataset.metadata
    elif isinstance(dataset, NSynthMetadata):
        metadata = dataset
    else:
        raise ValueError(
            "Invalid dataset {}, should be an instance of "
            "either NSynthDataset or NSynthMetadata.".format(dataset))

    # Group every record index by its (instrument, pitch) key. Groups keep
    # first-seen order, so the sequence of random draws below is fixed for
    # a given metadata ordering.
    groups = defaultdict(list)
    for idx in range(len(metadata)):
        info = metadata[idx].metadata
        groups[info['instrument'], info['pitch']].append(idx)

    with utils.random_seed_manager(random_seed):
        train, valid, test = [], [], []
        for members in groups.values():
            # Use one draw per group so that all of its velocities end up
            # in the same subset.
            draw = random.random()
            if draw < valid_ratio:
                target = valid
            elif draw < valid_ratio + test_ratio:
                target = test
            else:
                target = train
            target.extend(members)

        return (DatasetSubset(dataset, train),
                DatasetSubset(dataset, valid),
                DatasetSubset(dataset, test))