in lib/tree_util.py [0:0]
def walk_pytree(f_node, f_leaf, tree, treat_as_leaves: Optional[List] = None):
    """Recursively process ``tree`` and build a matching tree definition.

    The type of ``tree`` is looked up in the module-level ``node_types``
    mapping. If a handler is found (and the type is not listed in
    ``treat_as_leaves``), the handler's ``to_iterable`` splits the object
    into children plus a node spec, each child is walked recursively, and
    ``f_node`` is applied to the processed children. Otherwise the object
    is a leaf and ``f_leaf`` is applied to it directly.

    Args:
        f_node: Callable applied to the collection of processed children
            of an internal node (the first output of ``unzip2``).
        f_leaf: Callable applied to every leaf object.
        tree: The object to walk.
        treat_as_leaves: Optional list of types that are handled as leaves
            even when ``node_types`` has an entry for them. ``None`` means
            no such types.

    Returns:
        A 2-tuple ``(processed, tree_def)`` where ``processed`` is the
        result of ``f_node`` or ``f_leaf`` and ``tree_def`` is a
        ``PyTreeDef`` for internal nodes or a ``PyLeaf`` for leaves.
    """
    if treat_as_leaves is None:
        treat_as_leaves = []

    tree_type = type(tree)
    handler = node_types.get(tree_type)

    # Leaf case: no registered handler, or the caller forced this type
    # to be handled as a leaf.
    if not handler or tree_type in treat_as_leaves:
        return f_leaf(tree), PyLeaf()

    children, node_spec = handler.to_iterable(tree)
    # Each recursive call yields a (processed_child, child_spec) pair.
    walked = [
        walk_pytree(f_node, f_leaf, child, treat_as_leaves)
        for child in children
    ]
    processed_children, child_specs = unzip2(walked)
    return f_node(processed_children), PyTreeDef(handler, node_spec, child_specs)