def walk_pytree()

in lib/tree_util.py [0:0]


def walk_pytree(f_node, f_leaf, tree, treat_as_leaves: Optional[List] = None):
    """Recursively transform ``tree`` and record its structure.

    A value is handled as a container node when ``node_types`` has an entry
    for its exact type (looked up with ``type(tree)``, so subclasses are not
    matched) and that type is not listed in ``treat_as_leaves``. Every other
    value is handled as a leaf.

    Args:
      f_node: Called once per container node with its already-processed
        children (the first sequence produced by ``unzip2``). Its result
        stands in for that node.
      f_leaf: Called on every leaf value. Its result stands in for the leaf.
      tree: The value to traverse.
      treat_as_leaves: Optional collection of types to treat as leaves even if
        they are registered in ``node_types``. ``None`` means "no overrides".

    Returns:
      A pair ``(result, structure)``, where ``structure`` is one of:
      - a ``PyLeaf()`` for a leaf;
      - for a node, a ``PyTreeDef`` built from the node handler, the node spec
        returned by its ``to_iterable``, and the children's structures.
    """
    handler = node_types.get(type(tree))
    leaf_types = [] if treat_as_leaves is None else treat_as_leaves

    # Leaf path: either no handler is registered for this exact type, or the
    # caller asked for the type to be kept opaque.
    if not handler or type(tree) in leaf_types:
        return f_leaf(tree), PyLeaf()

    children, node_spec = handler.to_iterable(tree)
    # Each recursive call yields a (processed_child, child_structure) pair;
    # unzip2 splits these into two parallel sequences.
    per_child = [
        walk_pytree(f_node, f_leaf, child, leaf_types) for child in children
    ]
    processed, child_structures = unzip2(per_child)

    # The structure is built before f_node runs, matching the original order.
    structure = PyTreeDef(handler, node_spec, child_structures)
    return f_node(processed), structure