def walk_pytree()

in ppo_ewma/tree_util.py [0:0]


def walk_pytree(f_node, f_leaf, tree):
    """Recursively process ``tree`` bottom-up and return ``(result, treedef)``.

    The handler for a node is looked up in the module-level ``node_types``
    registry using the exact ``type(tree)`` as the key. Any value whose type
    has no registered handler is treated as a leaf.

    Children are processed before their parent (post-order), left to right in
    the order produced by the handler's ``to_iterable``. Traversal is
    recursive, so nesting depth is bounded by Python's recursion limit.

    Args:
        f_node: Called once per registered node with the processed results of
            its children, in the container form returned by ``unzip2``.
        f_leaf: Called once per leaf with the leaf value itself.
        tree: The structure to traverse.

    Returns:
        A pair. For a leaf it is ``(f_leaf(tree), PyLeaf())``. For a node it
        is ``(f_node(processed_children), PyTreeDef(handler, spec,
        child_treedefs))``.
    """
    handler = node_types.get(type(tree))
    if not handler:
        # No registered handler: the whole subtree is a single leaf.
        return f_leaf(tree), PyLeaf()

    subtrees, aux = handler.to_iterable(tree)
    # Each recursive call yields a (processed_value, treedef) pair; split them
    # into two parallel sequences.
    pairs = [walk_pytree(f_node, f_leaf, sub) for sub in subtrees]
    results, sub_defs = unzip2(pairs)
    structure = PyTreeDef(handler, aux, sub_defs)
    return f_node(results), structure