in lib/tree_util.py [0:0]
def tree_multimap(f, tree, *rest, treat_as_leaves: Optional[List] = None):
    """Map a multi-input function over pytree args to produce a new pytree.

    Args:
      f: function that takes `1 + len(rest)` arguments, to be applied at the
        corresponding leaves of the pytrees.
      tree: a pytree to be mapped over, with each leaf providing the first
        positional argument to `f`.
      *rest: a tuple of pytrees, each with the same structure as `tree`.
      treat_as_leaves: optional list of types whose instances are handed to `f`
        as leaves even if a node type is registered for them in `node_types`.
        Matching is on the exact type of each node of `tree` (subclasses are
        not matched). Defaults to no such types.

    Returns:
      A new pytree with the same structure as `tree` but with the value at each
      leaf given by `f(x, *xs)` where `x` is the value at the corresponding leaf
      in `tree` and `xs` is the tuple of values at corresponding leaves in `rest`.

    Raises:
      TypeError: if, at some node, a tree in `rest` has node data different
        from `tree`'s, or a different number of children.
    """
    if treat_as_leaves is None:
        # Avoid a shared mutable default; the same list is threaded through
        # every recursive call.
        treat_as_leaves = []
    # Only the type of `tree` selects the node handler; trees in `rest` are
    # flattened with that same handler.
    node_type = node_types.get(type(tree))
    if node_type and type(tree) not in treat_as_leaves:
        children, node_spec = node_type.to_iterable(tree)
        # Materialize so the child count can be compared across trees.
        children = list(children)
        all_children = [children]
        for other_tree in rest:
            other_children, other_node_data = node_type.to_iterable(other_tree)
            if other_node_data != node_spec:
                raise TypeError("Mismatch: {} != {}".format(other_node_data, node_spec))
            other_children = list(other_children)
            # zip() below would silently drop trailing children on a length
            # mismatch, and node data does not necessarily encode the child
            # count, so check it explicitly.
            if len(other_children) != len(children):
                raise TypeError(
                    "Mismatch: {} children != {} children".format(
                        len(other_children), len(children)))
            all_children.append(other_children)
        # Recurse position-wise: the i-th child of every tree forms one call.
        new_children = [
            tree_multimap(f, *xs, treat_as_leaves=treat_as_leaves)
            for xs in zip(*all_children)
        ]
        return node_type.from_iterable(node_spec, new_children)
    else:
        # Unregistered type (or explicitly treated as a leaf): apply `f`.
        return f(tree, *rest)