in ppo_ewma/tree_util.py [0:0]
def walk_pytree(f_node, f_leaf, tree):
    """Recursively transform ``tree`` and describe its structure.

    The type of ``tree`` is looked up (by exact ``type(tree)``) in the
    module-level ``node_types`` registry.  When a truthy handler is found,
    ``tree`` is treated as an interior node: it is split into children via
    the handler's ``to_iterable``, every child is walked recursively, and
    the processed children are combined with ``f_node``.  Otherwise
    ``tree`` is treated as a leaf and passed to ``f_leaf``.

    Args:
        f_node: Callable applied to the processed children of an interior
            node.  It receives the first component produced by ``unzip2``
            (presumably a sequence of child results -- confirm against
            ``unzip2``).
        f_leaf: Callable applied to every value whose type has no handler
            in ``node_types``.
        tree: The value to walk.

    Returns:
        A 2-tuple ``(result, tree_def)``.  ``result`` is the return value
        of ``f_node`` or ``f_leaf``; ``tree_def`` is a ``PyTreeDef`` for an
        interior node or a ``PyLeaf`` for a leaf.
    """
    handler = node_types.get(type(tree))
    if not handler:
        # No (truthy) handler registered for this exact type: it is a leaf.
        return f_leaf(tree), PyLeaf()

    children, node_spec = handler.to_iterable(tree)
    # Each recursive call yields a (processed_child, child_spec) pair.
    walked = [walk_pytree(f_node, f_leaf, child) for child in children]
    proc_children, child_specs = unzip2(walked)
    # Build the structure descriptor before invoking f_node, so the order
    # of side effects stays: PyTreeDef construction first, then f_node.
    tree_def = PyTreeDef(handler, node_spec, child_specs)
    return f_node(proc_children), tree_def