def tree_multimap()

in lib/tree_util.py [0:0]


def tree_multimap(f, tree, *rest, treat_as_leaves: Optional[List] = None):
    """Map a multi-input function over pytree args to produce a new pytree.

    Args:
      f: function that takes `1 + len(rest)` arguments, to be applied at the
        corresponding leaves of the pytrees.
      tree: a pytree to be mapped over, with each leaf providing the first
        positional argument to `f`.
      *rest: a tuple of pytrees, each with the same structure as `tree`.
      treat_as_leaves: optional list of types. A subtree of `tree` whose type
        is in this list is not descended into; it is passed to `f` whole,
        together with the corresponding subtrees of `rest`. Defaults to an
        empty list (only unregistered types are leaves).

    Returns:
      A new pytree with the same structure as `tree` but with the value at each
      leaf given by `f(x, *xs)` where `x` is the value at the corresponding leaf
      in `tree` and `xs` is the tuple of values at corresponding leaves in `rest`.

    Raises:
      TypeError: if some tree in `rest` does not match the structure of
        `tree`: a different registered node type, different node metadata,
        or a different number of children at some node.
    """
    if treat_as_leaves is None:
        treat_as_leaves = []

    node_type = node_types.get(type(tree))
    if not node_type or type(tree) in treat_as_leaves:
        # `tree` is a leaf (unregistered type, or forced to be one): apply `f`
        # to it and the corresponding subtrees of `rest` directly.
        return f(tree, *rest)

    children, node_spec = node_type.to_iterable(tree)
    # Materialize so the child count can be compared across trees below.
    children = tuple(children)
    all_children = [children]
    for other_tree in rest:
        # Every tree must be the same kind of node; otherwise `to_iterable`
        # would be applied to a value it was not registered for.
        other_node_type = node_types.get(type(other_tree))
        if other_node_type != node_type:
            raise TypeError("Mismatch: {} != {}".format(other_node_type, node_type))
        other_children, other_node_data = node_type.to_iterable(other_tree)
        if other_node_data != node_spec:
            raise TypeError("Mismatch: {} != {}".format(other_node_data, node_spec))
        other_children = tuple(other_children)
        # `zip` below would silently truncate to the shortest child list, so
        # reject differing arities explicitly.
        if len(other_children) != len(children):
            raise TypeError(
                "Mismatch: {} children != {} children".format(
                    len(other_children), len(children)
                )
            )
        all_children.append(other_children)

    new_children = [
        tree_multimap(f, *xs, treat_as_leaves=treat_as_leaves)
        for xs in zip(*all_children)
    ]
    return node_type.from_iterable(node_spec, new_children)