in smallpond/logical/planner.py [0:0]
def visit_consolidate_node(self, node: ConsolidateNode, depth: int) -> TaskGroup:
    """Plan the tasks for a ``ConsolidateNode``.

    The single upstream node is planned first (at ``depth + 1``). Its tasks are
    then grouped by their partition infos restricted to ``node.dimensions``, and
    one task is created per group via ``node.create_task``. Tasks that differ only
    in partition dimensions outside ``node.dimensions`` end up in the same group.

    Args:
        node: the consolidate node to plan. It must have exactly one input
            dependency, and every dimension in ``node.dimensions`` must be one of
            the upstream tasks' partition dimensions.
        depth: depth of ``node`` in the traversal; input deps are visited at
            ``depth + 1``.

    Returns:
        A list with one task per distinct combination of partition infos along
        ``node.dimensions``, in the order each combination is first encountered
        among the upstream tasks.

    Raises:
        AssertionError: if the node does not have exactly one input, the upstream
            produced no tasks, the upstream tasks disagree on their partition
            dimensions, or some of ``node.dimensions`` are absent upstream.
    """
    input_deps_taskgroups = [self.visit(dep, depth + 1) for dep in node.input_deps]
    assert len(input_deps_taskgroups) == 1, f"consolidate node only accepts one input node, but found: {input_deps_taskgroups}"
    input_tasks = input_deps_taskgroups[0]
    # Checked on its own so an empty upstream reports the real cause instead of
    # failing the "different dimensions" check below with an empty set.
    assert input_tasks, "consolidate node received no input tasks"
    unique_partition_dims = {task.partition_dims for task in input_tasks}
    assert len(unique_partition_dims) == 1, f"cannot consolidate partitions with different dimensions: {unique_partition_dims}"
    existing_dimensions = set(unique_partition_dims.pop())
    missing_dimensions = set(node.dimensions) - existing_dimensions
    assert not missing_dimensions, f"cannot find dimensions {missing_dimensions} of {node.dimensions} in {existing_dimensions}"
    # Group tasks by partition infos along the consolidated dimensions only. The key
    # keeps each task's own ``partition_infos`` order; since all tasks share the same
    # ``partition_dims`` this is presumably consistent across tasks — TODO confirm.
    input_deps_groupby_partitions: Dict[Tuple, List[Task]] = defaultdict(list)
    for task in input_tasks:
        partition_infos = tuple(info for info in task.partition_infos if info.dimension in node.dimensions)
        input_deps_groupby_partitions[partition_infos].append(task)
    # dict preserves insertion order, so groups come out in first-seen order.
    return [
        node.create_task(self.runtime_ctx, input_deps, partition_infos)
        for partition_infos, input_deps in input_deps_groupby_partitions.items()
    ]