in lib/tree_util.py [0:0]
def prefix_multimap(f, treedef, tree, *rest):
    """Like tree_multimap but only maps down through a tree prefix.

    Recursion follows ``treedef`` rather than the full structure of ``tree``:
    wherever ``treedef`` is a ``PyLeaf``, ``f`` is applied to whatever sits at
    that position in ``tree`` and in each tree of ``rest``, without descending
    any further.

    Args:
      f: callable invoked as ``f(subtree, *other_subtrees)`` at every position
        where ``treedef`` is a ``PyLeaf``.
      treedef: either a ``PyLeaf`` or a node exposing ``node_type``,
        ``node_data`` and ``children``.
      tree: the primary tree; its registered node type and node data are
        checked against ``treedef``.
      *rest: additional trees, unpacked with the node type of ``tree`` and
        required to yield the same node data.

    Returns:
      ``f(tree, *rest)`` when ``treedef`` is a ``PyLeaf``; otherwise the value
      of ``node_type.from_iterable(node_data, new_children)`` where
      ``new_children`` is the list of recursively mapped children.

    Raises:
      TypeError: if the node type or node data of ``tree`` differs from
        ``treedef``, if a tree in ``rest`` yields different node data, or if
        any of the trees has a different number of children than ``treedef``.
    """
    if isinstance(treedef, PyLeaf):
        return f(tree, *rest)

    # `node_types.get` yields None for an unregistered type, which then fails
    # the comparison below instead of raising a lookup error.
    node_type = node_types.get(type(tree))
    if node_type != treedef.node_type:
        raise TypeError("Mismatch: {} != {}".format(treedef.node_type, node_type))
    children, node_data = node_type.to_iterable(tree)
    if node_data != treedef.node_data:
        raise TypeError("Mismatch: {} != {}".format(treedef.node_data, node_data))

    # Materialize the child sequences so their lengths can be compared; zip()
    # would otherwise silently drop the surplus children of the longer inputs.
    child_defs = tuple(treedef.children)
    children = tuple(children)
    if len(children) != len(child_defs):
        raise TypeError(
            "Mismatch in number of children: {} != {}".format(
                len(child_defs), len(children)))

    all_children = [children]
    for other_tree in rest:
        # NOTE(review): the type of `other_tree` is not validated here; it is
        # unpacked with the node type registered for `tree`.
        other_children, other_node_data = node_type.to_iterable(other_tree)
        if other_node_data != node_data:
            raise TypeError("Mismatch: {} != {}".format(other_node_data, node_data))
        other_children = tuple(other_children)
        if len(other_children) != len(child_defs):
            raise TypeError(
                "Mismatch in number of children: {} != {}".format(
                    len(other_children), len(child_defs)))
        all_children.append(other_children)

    # Transpose from one child sequence per tree to one tuple of corresponding
    # children per position, then recurse alongside the matching sub-treedef.
    new_children = [
        prefix_multimap(f, child_def, *subtrees)
        for child_def, subtrees in zip(child_defs, zip(*all_children))
    ]
    return node_type.from_iterable(node_data, new_children)