def tree_multimap()

in ppo_ewma/tree_util.py [0:0]


def tree_multimap(f, tree, *rest):
    """Map a multi-input function over pytree args to produce a new pytree.

    Args:
      f: function that takes `1 + len(rest)` arguments, to be applied at the
        corresponding leaves of the pytrees.
      tree: a pytree to be mapped over, with each leaf providing the first
        positional argument to `f`.
      *rest: a tuple of pytrees, each with the same structure as `tree`.

    Returns:
      A new pytree with the same structure as `tree` but with the value at each
      leaf given by `f(x, *xs)` where `x` is the value at the corresponding leaf
      in `tree` and `xs` is the tuple of values at corresponding leaves in `rest`.

    Raises:
      TypeError: if a tree in `rest` yields node data, or a number of children,
        that differs from the corresponding node of `tree`.
    """
    # equivalent to prefix_multimap(f, tree_structure(tree), tree, *rest)
    node_type = node_types.get(type(tree))
    if not node_type:
        # No registered node type for `tree`: it is a leaf, so apply `f` to it
        # together with the corresponding values from `rest`.
        return f(tree, *rest)

    children, node_spec = node_type.to_iterable(tree)
    # Materialise so the number of children can be compared across trees.
    children = tuple(children)
    all_children = [children]
    for other_tree in rest:
        # NOTE(review): every other tree is unpacked with `tree`'s node type; an
        # explicit check that `type(other_tree)` maps to the same node type was
        # disabled in the original code -- confirm that leniency is intended.
        other_children, other_node_data = node_type.to_iterable(other_tree)
        if other_node_data != node_spec:
            raise TypeError("Mismatch: {} != {}".format(other_node_data, node_spec))
        other_children = tuple(other_children)
        # `zip` below stops at the shortest input, so without this check a tree
        # with a different number of children would silently lose leaves.
        if len(other_children) != len(children):
            raise TypeError(
                "Mismatch: {} children != {} children".format(
                    len(other_children), len(children)
                )
            )
        all_children.append(other_children)

    # Recurse child-by-child, grouping the i-th child of every tree together.
    new_children = [tree_multimap(f, *xs) for xs in zip(*all_children)]
    return node_type.from_iterable(node_spec, new_children)