def prefix_multimap()

in lib/tree_util.py [0:0]


def prefix_multimap(f, treedef, tree, *rest):
    """Like tree_multimap but only maps down through a tree prefix.

    Walks ``tree`` (and every tree in ``rest``) in lockstep with ``treedef``.
    Where ``treedef`` is a ``PyLeaf``, ``f(subtree, *other_subtrees)`` is
    returned for that position without descending further. Everywhere else the
    node is flattened with its registered node type, each child is recursed
    into, and the node is rebuilt with ``node_type.from_iterable``.

    Args:
      f: function applied at each ``PyLeaf`` position of ``treedef``.
      treedef: tree structure expected to be a prefix of ``tree``.
      tree: tree whose node types / node data are checked against ``treedef``.
      *rest: additional trees that must match ``tree`` along the prefix in
        node type, node data and number of children.

    Returns:
      A tree rebuilt with the node types of ``treedef`` whose ``PyLeaf``
      positions hold the values returned by ``f``.

    Raises:
      TypeError: if a node type, node data, or child count differs between
        ``treedef``, ``tree`` and any tree in ``rest``.
    """
    if isinstance(treedef, PyLeaf):
        return f(tree, *rest)

    node_type = node_types.get(type(tree))
    if node_type != treedef.node_type:
        raise TypeError("Mismatch: {} != {}".format(treedef.node_type, node_type))
    children, node_data = node_type.to_iterable(tree)
    if node_data != treedef.node_data:
        raise TypeError("Mismatch: {} != {}".format(treedef.node_data, node_data))

    # Materialize so child counts can be compared: zip() below silently
    # truncates, and matching node_data does not by itself guarantee a
    # matching number of children (node_data may not encode arity).
    children = list(children)
    child_defs = list(treedef.children)
    if len(children) != len(child_defs):
        raise TypeError("Mismatch: {} children != {} children".format(
            len(child_defs), len(children)))

    all_children = [children]
    for other_tree in rest:
        # Check each extra tree's own node type rather than blindly applying
        # the first tree's flattener to it.
        other_node_type = node_types.get(type(other_tree))
        if other_node_type != node_type:
            raise TypeError("Mismatch: {} != {}".format(other_node_type, node_type))
        other_children, other_node_data = node_type.to_iterable(other_tree)
        if other_node_data != node_data:
            raise TypeError("Mismatch: {} != {}".format(other_node_data, node_data))
        other_children = list(other_children)
        if len(other_children) != len(children):
            raise TypeError("Mismatch: {} children != {} children".format(
                len(other_children), len(children)))
        all_children.append(other_children)

    # Each xs is (child_of_tree, child_of_rest_0, ...) for one child position.
    new_children = [prefix_multimap(f, td, *xs)
                    for td, xs in zip(child_defs, zip(*all_children))]
    return node_type.from_iterable(node_data, new_children)