def generate_splits()

in privacy_lint/dataset/masks.py [0:0]


def generate_splits(n_data: int, split_config: dict, *, rng=None):
    """
    Randomly partition ``n_data`` sample indices into disjoint splits.

    ``split_config`` is flattened with ``flatten`` and converted into per-split
    sample counts with ``multiply_round(n_data, ...)``. A single random
    permutation of ``range(n_data)`` is then cut into consecutive chunks, one
    per split in the iteration order of the flattened config. Because the
    chunks are non-overlapping slices of one permutation, no index is assigned
    to more than one split.

    Args:
        n_data: total number of samples in the dataset.
        split_config: description of how to divide the samples; its exact
            schema is whatever ``flatten`` / ``multiply_round`` accept
            (presumably a possibly-nested dict of fractions — TODO confirm).
        rng: optional keyword-only object exposing a ``permutation`` method,
            e.g. ``numpy.random.default_rng(seed)`` or
            ``numpy.random.RandomState(seed)``, used to shuffle the indices so
            splits can be reproduced. When omitted, the global
            ``numpy.random`` state is used, exactly as before.

    Returns:
        dict mapping each key of the flattened config to
        ``idx_to_mask(n_data, indices)`` for that split's indices.
    """
    flat_config = multiply_round(n_data, flatten(split_config))

    # Fall back to the module-level numpy RNG so existing callers see no change.
    shuffler = np.random if rng is None else rng
    permutation = shuffler.permutation(n_data)

    masks = {}
    offset = 0
    for split, n_split in flat_config.items():
        # NOTE(review): if the per-split counts sum to more than n_data,
        # slicing past the end silently gives the last split(s) fewer
        # (possibly zero) indices — confirm multiply_round never over-allocates.
        masks[split] = idx_to_mask(n_data, permutation[offset : offset + n_split])
        offset += n_split

    return masks