in smallpond/logical/planner.py [0:0]
def broadcast_input_deps(self, node: Node, depth: int):
    """Pair up the tasks of a node's input dependencies, partition by partition.

    Each entry of ``node.input_deps`` is visited via ``self.visit(dep, depth + 1)``
    to obtain its task group. The task group whose first task has the most
    ``partition_dims`` (ties broken by the largest ``partition_infos`` among its
    tasks) drives the iteration. For each driving task, every dependency
    contributes the task whose ``partition_infos`` equal the driving task's
    ``partition_infos`` restricted to that dependency's partition dimensions.
    A dependency partitioned on fewer dimensions therefore has the same task
    reused ("broadcast") across several driving partitions.

    Args:
        node: the node whose ``input_deps`` are matched against each other.
        depth: current visit depth; dependencies are visited at ``depth + 1``.

    Yields:
        ``(input_deps, partition_infos)`` pairs, one per driving task.
        ``input_deps`` holds one matched task per entry of ``node.input_deps``,
        in the same order. ``partition_infos`` is the driving task's
        ``partition_infos``. A node without input deps yields exactly once:
        ``([], (PartitionInfo(),))``.

    Raises:
        AssertionError: if some dependency has no task matching a driving
            partition, i.e. the inputs' partition dimensions or npartitions
            are not compatible. Raised lazily, when that partition is reached
            during iteration.
    """
    # A node with no inputs still gets one (trivial) partition.
    if not node.input_deps:
        yield [], (PartitionInfo(),)
        return

    input_deps_taskgroups = [self.visit(dep, depth + 1) for dep in node.input_deps]

    # NOTE(review): only taskgroup[0] is consulted for partition_dims, so this
    # assumes every task group is non-empty and all of its tasks share the same
    # partition_dims — confirm against visit().
    def _rank(taskgroup):
        # Prefer more partition dimensions; break ties by the largest partition_infos.
        return (
            len(taskgroup[0].partition_dims),
            max(t.partition_infos for t in taskgroup),
        )

    input_deps_most_ndims = max(input_deps_taskgroups, key=_rank)

    # Per dependency: its partition dims and a lookup from partition_infos to task.
    input_deps_maps = [
        (taskgroup[0].partition_dims, {t.partition_infos: t for t in taskgroup})
        for taskgroup in input_deps_taskgroups
    ]

    for main_input in input_deps_most_ndims:
        input_deps = []
        for input_deps_dims, input_deps_map in input_deps_maps:
            # Keep only the dimensions this dependency is partitioned on, in the
            # driving task's order, to form the lookup key.
            partition_infos = tuple(info for info in main_input.partition_infos if info.dimension in input_deps_dims)
            input_dep = input_deps_map.get(partition_infos)
            if input_dep is None:
                # Raised explicitly rather than via `assert` so the check is not
                # stripped under `python -O` (which would silently append None).
                # AssertionError is kept for compatibility with existing callers.
                raise AssertionError(
                    f"""the partition dimensions or npartitions of inputs {node.input_deps} of {repr(node)} are not compatible
cannot match {main_input.partition_infos} against any of {input_deps_map.keys()}"""
                )
            input_deps.append(input_dep)
        yield input_deps, main_input.partition_infos