def split_to_subsets()

in ml3/mbrl_utils.py [0:0]


def split_to_subsets(X, Y, K):
    """Build K overlapping training subsets by leaving out one chunk per subset.

    The sample indices are shuffled once (using NumPy's global RNG) and
    partitioned into K disjoint chunks whose sizes differ by at most one.
    Subset ``i`` contains every sample except those in chunk ``i``, i.e. the
    training side of fold ``i`` in a K-fold split.

    Args:
        X: Indexable, sized sequence of inputs (``len(X)`` and ``X[i]`` are used).
        Y: Indexable sequence of targets. It is presumably aligned with ``X``
            and of the same length; this is not checked here (TODO confirm at
            call sites).
        K: Number of subsets to build. Must be >= 1.

    Returns:
        A list of K ``Dataset`` objects. For ``K == 1`` the single Dataset
        wraps the original ``X`` and ``Y`` unchanged and unshuffled. Otherwise
        each Dataset wraps Python lists of the selected samples, in shuffled
        order.

    Raises:
        ValueError: If ``K < 1``.
    """
    if K < 1:
        raise ValueError(f"K must be >= 1, got {K}")
    if K == 1:
        # With a single split, keep the whole dataset as-is (no reshuffle).
        return [Dataset(X, Y)]

    all_idx = np.random.permutation(len(X))
    # array_split yields chunk sizes differing by at most one, so no chunk is
    # empty unless K > len(X). The ceil-based slicing this replaces could
    # leave trailing chunks empty (e.g. 10 samples, K=6), which made that
    # "subset" identical to the full dataset.
    chunks = np.array_split(all_idx, K)

    datasets = []
    for i in range(K):
        # Keep every chunk except the i-th; concatenation preserves the
        # shuffled order, matching the original np.delete-based selection.
        keep_idx = np.concatenate(chunks[:i] + chunks[i + 1:])
        X_subset = [X[idx] for idx in keep_idx]
        Y_subset = [Y[idx] for idx in keep_idx]
        datasets.append(Dataset(X_subset, Y_subset))

    return datasets