def split()

in gym3/types_th.py [0:0]


def split(x: Any, sections: Sequence[int]) -> Any:
    """
    Cut every (leaf) tensor in the tree ``x`` into consecutive pieces.

    ``sections`` holds *end offsets*, not piece sizes: piece ``i`` is the
    slice ``leaf[sections[i - 1]:sections[i]]``, with an implicit start of 0
    for the first piece. Data past the last offset is not returned.

    Examples (leaves shown as 1-D tensors):

        split(torch.tensor([1,2,3,4]), [1,2,3,4])
            => [tensor([1]), tensor([2]), tensor([3]), tensor([4])]
        split(torch.tensor([1,2,3,4]), [1,3,4])
            => [tensor([1]), tensor([2, 3]), tensor([4])]

    :param x: a tree where the leaf values are torch tensors
    :param sections: end offsets to split at (not sizes of each split)

    :returns: list of ``len(sections)`` trees, each with the same structure as
        ``x``, where each leaf is the corresponding slice of the leaf in ``x``
    """
    bounds = list(sections)
    # Each piece starts where the previous one ended; the first starts at 0.
    starts = [0] + bounds[:-1]

    def _piece(lo, hi):
        # Slice every leaf of the tree with this piece's bounds.
        return multimap(lambda leaf: leaf[lo:hi], x)

    return [_piece(lo, hi) for lo, hi in zip(starts, bounds)]