in gym3/types_th.py [0:0]
def split(x: Any, sections: Sequence[int]) -> Any:
    """
    Cut every leaf tensor of the tree `x` into consecutive pieces along its
    leading axis.

    Each entry of `sections` is the exclusive end index of one piece. The
    first piece starts at 0, and every later piece starts where the previous
    one ended. Elements beyond the last entry of `sections` land in no piece.

    Examples:
        split([1,2,3,4], [1,2,3,4]) => [[1], [2], [3], [4]]
        split([1,2,3,4], [1,3,4]) => [[1], [2, 3], [4]]

    :param x: a tree whose leaf values are torch tensors
    :param sections: end indices of the pieces (cut points, not piece sizes)
    :returns: a list of `len(sections)` trees shaped like `x`; in the i-th
        tree, each leaf is the i-th slice of the matching leaf of `x`
    """
    ends = list(sections)
    # Each piece begins where the previous one stopped; the first begins at 0.
    starts = [0] + ends[:-1]
    # Bind lo/hi as lambda defaults so each slicer carries its own bounds.
    return [
        multimap(lambda leaf, lo=lo, hi=hi: leaf[lo:hi], x)
        for lo, hi in zip(starts, ends)
    ]