in ml3/mbrl_utils.py [0:0]
def split_to_subsets(X, Y, K):
    """Build K leave-one-chunk-out subsets of the paired data (X, Y).

    The sample positions are randomly permuted and partitioned into K
    disjoint, nearly equal-sized chunks (sizes differ by at most one).
    Subset i contains every sample except those in chunk i, so the K subsets
    overlap but each one omits a different part of the data.

    Args:
        X: Indexable sequence of inputs (e.g. list or array). ``len(X)`` is
            taken as the number of samples.
        Y: Indexable sequence of targets, assumed aligned with ``X``.
        K: Number of subsets to produce; must be >= 1.

    Returns:
        A list of K ``Dataset`` objects. For ``K > 1`` each is built from
        Python lists of the kept samples, in permuted order. For ``K == 1``
        the single ``Dataset`` is built directly from the original ``X`` and
        ``Y`` without shuffling.

    Raises:
        ValueError: If ``K < 1``.
    """
    if K < 1:
        raise ValueError("K must be a positive integer, got {}".format(K))
    if K == 1:
        # For a single split, keep the dataset as-is: do not reshuffle it.
        return [Dataset(X, Y)]

    n_data = len(X)
    all_idx = np.random.permutation(n_data)

    # Split positions 0..n_data-1 into K chunks whose sizes differ by at most
    # one. Fixed ceil(n_data / K)-sized chunks can overrun the data and leave
    # the trailing subsets holding out nothing (e.g. n_data=10, K=6); this
    # split guarantees every subset holds out data whenever K <= n_data.
    held_out_chunks = np.array_split(np.arange(n_data), K)

    datasets = []
    for held_out in held_out_chunks:
        # Drop this chunk's positions from the permutation; the remaining
        # indices (still in permuted order) form this subset.
        kept_idx = np.delete(all_idx, held_out)
        X_subset = [X[idx] for idx in kept_idx]
        Y_subset = [Y[idx] for idx in kept_idx]
        datasets.append(Dataset(X_subset, Y_subset))
    return datasets