def tree_map()

in lib/tree_util.py [0:0]


def tree_map(f, tree, treat_as_leaves: Optional[List] = None):
    """Map a function over a pytree to produce a new pytree.

    Args:
      f: function to be applied at each leaf.
      tree: a pytree to be mapped over.
      treat_as_leaves: optional list of types. A node whose exact type
        (`type(node)`; subclasses are not matched) appears in this list is
        not traversed: it is handed to `f` as a single leaf even if its type
        has an entry in `node_types`. `None` (the default) means no types
        are forced to be leaves.

    Returns:
      A new pytree with the same structure as `tree` but with the value at each
      leaf given by `f(x)` where `x` is the value at the corresponding leaf in
      `tree`.
    """
    if treat_as_leaves is None:
        treat_as_leaves = []
    node_type = node_types.get(type(tree))
    # Test for a missing entry explicitly rather than relying on truthiness,
    # so a registered handler object that happens to be falsy is still used.
    if node_type is None or type(tree) in treat_as_leaves:
        return f(tree)
    children, node_spec = node_type.to_iterable(tree)
    # Recurse with the same leaf overrides so they apply at every depth.
    new_children = [tree_map(f, child, treat_as_leaves) for child in children]
    return node_type.from_iterable(node_spec, new_children)