def broadcast_input_deps()

in smallpond/logical/planner.py [0:0]


    def broadcast_input_deps(self, node: Node, depth: int):
        # if no input deps, return a single partition
        if not node.input_deps:
            yield [], (PartitionInfo(),)
            return

        input_deps_taskgroups = [self.visit(dep, depth + 1) for dep in node.input_deps]
        input_deps_most_ndims = max(
            input_deps_taskgroups,
            key=lambda taskgroup: (
                len(taskgroup[0].partition_dims),
                max(t.partition_infos for t in taskgroup),
            ),
        )
        input_deps_maps = [
            (
                taskgroup[0].partition_dims,
                dict((t.partition_infos, t) for t in taskgroup),
            )
            for taskgroup in input_deps_taskgroups
        ]

        for main_input in input_deps_most_ndims:
            input_deps = []
            for input_deps_dims, input_deps_map in input_deps_maps:
                partition_infos = tuple(info for info in main_input.partition_infos if info.dimension in input_deps_dims)
                input_dep = input_deps_map.get(partition_infos, None)
                assert (
                    input_dep is not None
                ), f"""the partition dimensions or npartitions of inputs {node.input_deps} of {repr(node)} are not compatible
  cannot match {main_input.partition_infos} against any of {input_deps_map.keys()}"""
                input_deps.append(input_dep)
            yield input_deps, main_input.partition_infos