in lib/tree_util.py [0:0]
def tree_map(f, tree, treat_as_leaves: Optional[List] = None):
    """Map a function over a pytree to produce a new pytree.

    A value is treated as an internal node when its exact type (``type(tree)``,
    not subclasses) is registered in the module-level ``node_types`` registry;
    every other value is a leaf and is passed to ``f``.

    Args:
      f: function to be applied at each leaf.
      tree: a pytree to be mapped over.
      treat_as_leaves: optional list of types to treat as leaves even when
        they are registered node types. A value whose exact type appears in
        this list is passed to ``f`` whole instead of being recursed into.
        The list is applied at every depth of the tree. Defaults to ``None``,
        meaning no registered type is forced to be a leaf.

    Returns:
      A new pytree with the same structure as `tree` but with the value at each
      leaf given by `f(x)` where `x` is the value at the corresponding leaf in
      `tree`.
    """
    if treat_as_leaves is None:
        treat_as_leaves = []
    tree_type = type(tree)
    node_type = node_types.get(tree_type)
    # Leaf: either the type is not a registered node, or the caller asked for
    # this registered type to be treated as an opaque leaf.
    if node_type is None or tree_type in treat_as_leaves:
        return f(tree)
    children, node_spec = node_type.to_iterable(tree)
    # Pass treat_as_leaves down so the override holds at every nesting level.
    new_children = [tree_map(f, child, treat_as_leaves) for child in children]
    return node_type.from_iterable(node_spec, new_children)