in smallpond/logical/planner.py [0:0]
def visit_partition_node(self, node: PartitionNode, depth: int) -> TaskGroup:
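    """
    Expand a PartitionNode into concrete producer and consumer tasks.

    Input dependencies are planned first (depth-first) and must all share the
    same partition dimensions. A nested partition stacks a new dimension on
    top of each producer's existing partitioning; a non-nested partition first
    rebalances the inputs across a bounded number of producer tasks.
    """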
    all_input_deps = [task for dep in node.input_deps for task in self.visit(dep, depth + 1)]
    unique_partition_dims = set(task.partition_dims for task in all_input_deps)
    assert len(unique_partition_dims) == 1, f"cannot partition input_deps with different dimensions: {unique_partition_dims}"
    (partition_dims,) = unique_partition_dims
    if node.nested:
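        # Nested partition: the new dimension is stacked on top of each input's
        # existing partitioning, so it must not collide with a dimension already in use.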
        assert (
            node.dimension not in partition_dims
        ), f"found duplicate partition dimension '{node.dimension}', existing dimensions: {partition_dims}"
        assert (
            len(all_input_deps) * node.npartitions <= node.max_card_of_producers_x_consumers
        ), f"{len(all_input_deps)=} * {node.npartitions=} > {node.max_card_of_producers_x_consumers=}"
        producer_tasks = [node.create_producer_task(self.runtime_ctx, [task], task.partition_infos) for task in all_input_deps]
        return [
            node.create_consumer_task(
                self.runtime_ctx,
                [producer],
                list(producer.partition_infos) + [PartitionInfo(partition_idx, node.npartitions, node.dimension)],
            )
            for producer in producer_tasks
            for partition_idx in range(node.npartitions)
        ]
    else:
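        # Non-nested partition: cap the number of producers by the cardinality
        # budget, i.e. producers x consumers must stay within the configured limit.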
        max_num_producer_tasks = min(
            node.max_num_producer_tasks,
            math.ceil(node.max_card_of_producers_x_consumers / node.npartitions),
        )
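        # Parallelism bound: twice the number of tasks that can run concurrently
        # across all executors, at usable_cpu_count / cpu_limit tasks per executor.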
        num_parallel_tasks = 2 * self.runtime_ctx.num_executors * math.ceil(self.runtime_ctx.usable_cpu_count / node.cpu_limit)
        num_producer_tasks = max(1, min(max_num_producer_tasks, num_parallel_tasks))
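        # Fewer inputs than producers: merge everything into one dataset, then
        # split it into num_producer_tasks shards so no producer sits idle.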
        if len(all_input_deps) < num_producer_tasks:
            merge_datasets_task = node.create_merge_task(self.runtime_ctx, all_input_deps, [PartitionInfo()])
            split_dataset_tasks = [
                node.create_split_task(
                    self.runtime_ctx,
                    [merge_datasets_task],
                    [PartitionInfo(partition_idx, num_producer_tasks)],
                )
                for partition_idx in range(num_producer_tasks)
            ]
        else:
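            # Enough inputs: group them into num_producer_tasks buckets
            # (split_into_rows) and merge each bucket into one producer input.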
            split_dataset_tasks = [
                node.create_merge_task(
                    self.runtime_ctx,
                    tasks,
                    [PartitionInfo(partition_idx, num_producer_tasks)],
                )
                for partition_idx, tasks in enumerate(split_into_rows(all_input_deps, num_producer_tasks))
            ]
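        # One producer per split dataset, inheriting that split's partition infos.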
        producer_tasks = [
            node.create_producer_task(self.runtime_ctx, [split_dataset], split_dataset.partition_infos) for split_dataset in split_dataset_tasks
        ]
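        # Every consumer reads from all producers; its partition infos carry a
        # default PartitionInfo() plus its index along the new dimension.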
        return [
            node.create_consumer_task(
                self.runtime_ctx,
                producer_tasks,
                [
                    PartitionInfo(),
                    PartitionInfo(partition_idx, node.npartitions, node.dimension),
                ],
            )
            for partition_idx in range(node.npartitions)
        ]