def visit_partition_node()

in smallpond/logical/planner.py
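
Lowers a `PartitionNode` into concrete tasks. The method first plans all upstream
nodes and checks that their tasks share one set of partition dimensions; a nested
partition then fans each input task out along a new dimension, while a flat
partition merges and splits the inputs across a bounded number of producers and
shuffles their output into `node.npartitions` consumers.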


    def visit_partition_node(self, node: PartitionNode, depth: int) -> TaskGroup:
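        # Plan all upstream nodes recursively and flatten their tasks into one list.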
        all_input_deps = [task for dep in node.input_deps for task in self.visit(dep, depth + 1)]
        unique_partition_dims = set(task.partition_dims for task in all_input_deps)
        assert len(unique_partition_dims) == 1, f"cannot partition input_deps with different dimensions: {unique_partition_dims}"

        if node.nested:
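            # Nested partitioning stacks a new dimension on top of the existing ones,
            # so the new dimension must not already be present.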
            assert (
                node.dimension not in unique_partition_dims
            ), f"found duplicate partition dimension '{node.dimension}', existing dimensions: {unique_partition_dims}"
            assert (
                len(all_input_deps) * node.npartitions <= node.max_card_of_producers_x_consumers
            ), f"{len(all_input_deps)=} * {node.npartitions=} > {node.max_card_of_producers_x_consumers=}"
            producer_tasks = [node.create_producer_task(self.runtime_ctx, [task], task.partition_infos) for task in all_input_deps]
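            # Fan out: each producer feeds `npartitions` consumer tasks, each tagged
            # with its index along the new partition dimension.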
            return [
                node.create_consumer_task(
                    self.runtime_ctx,
                    [producer],
                    list(producer.partition_infos) + [PartitionInfo(partition_idx, node.npartitions, node.dimension)],
                )
                for producer in producer_tasks
                for partition_idx in range(node.npartitions)
            ]
        else:
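            # Flat repartitioning: cap the producer count by the configured maximum
            # and by the producers-x-consumers cardinality budget.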
            max_num_producer_tasks = min(
                node.max_num_producer_tasks,
                math.ceil(node.max_card_of_producers_x_consumers / node.npartitions),
            )
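            # Heuristic: allow up to 2x the CPU-limited task slots across all executors.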
            num_parallel_tasks = 2 * self.runtime_ctx.num_executors * math.ceil(self.runtime_ctx.usable_cpu_count / node.cpu_limit)
            num_producer_tasks = max(1, min(max_num_producer_tasks, num_parallel_tasks))
            if len(all_input_deps) < num_producer_tasks:
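                # Fewer inputs than desired producers: merge everything into one
                # dataset, then split it into `num_producer_tasks` pieces.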
                merge_datasets_task = node.create_merge_task(self.runtime_ctx, all_input_deps, [PartitionInfo()])
                split_dataset_tasks = [
                    node.create_split_task(
                        self.runtime_ctx,
                        [merge_datasets_task],
                        [PartitionInfo(partition_idx, num_producer_tasks)],
                    )
                    for partition_idx in range(num_producer_tasks)
                ]
            else:
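                # Enough inputs: group them across producers, merging each group of
                # input tasks into one dataset per producer.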
                split_dataset_tasks = [
                    node.create_merge_task(
                        self.runtime_ctx,
                        tasks,
                        [PartitionInfo(partition_idx, num_producer_tasks)],
                    )
                    for partition_idx, tasks in enumerate(split_into_rows(all_input_deps, num_producer_tasks))
                ]
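            # One producer task per split, inheriting the split's partition infos.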
            producer_tasks = [
                node.create_producer_task(self.runtime_ctx, [split_dataset], split_dataset.partition_infos) for split_dataset in split_dataset_tasks
            ]
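            # Shuffle step: every consumer reads from all producers and picks out
            # its own partition index along the target dimension.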
            return [
                node.create_consumer_task(
                    self.runtime_ctx,
                    producer_tasks,
                    [
                        PartitionInfo(),
                        PartitionInfo(partition_idx, node.npartitions, node.dimension),
                    ],
                )
                for partition_idx in range(node.npartitions)
            ]
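
The helper `split_into_rows` comes from elsewhere in the codebase; the planner only
relies on it to divide `all_input_deps` into exactly `num_producer_tasks` groups.
A minimal sketch with that contract, assuming a simple round-robin distribution
(the real implementation may differ), could look like:

    def split_into_rows(items, num_rows):
        # Illustrative only: deal `items` round-robin into `num_rows` groups.
        groups = [[] for _ in range(num_rows)]
        for i, item in enumerate(items):
            groups[i % num_rows].append(item)
        return groups

Since this branch is only taken when len(all_input_deps) >= num_producer_tasks,
every group is non-empty, so each merge task receives at least one input task.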