fn stages()

in src/dataframe.rs [90:222]


    /// Splits this DataFrame's physical plan into Ray stages.
    ///
    /// The physical plan is walked bottom-up with `transform_up`. Each `DFRayStageExec`
    /// reached becomes a `PyDFRayStage` whose plan is the stage node's single child, and the
    /// stage node is replaced by a `DFRayStageReaderExec` for the same stage id.
    /// `RepartitionExec`, `SortExec` and `NestedLoopJoinExec` nodes are rewritten and set the
    /// partition groups (and the full-partitions flag) that get attached to the next stage
    /// visited.
    ///
    /// The last stage collected must have at most one output partition. It is wrapped in
    /// `CoalesceBatchesExec` + `MaxRowsExec`, given the single partition group `[[0]]` with
    /// full partitions, and `self.final_plan` is set to a `DFRayStageReaderExec` over it.
    ///
    /// # Arguments
    /// * `batch_size` - passed to the coalesce / max-rows wrappers and to `build_replacement`.
    /// * `prefetch_buffer_size` - buffer size for `PrefetchExec`; in the nested-loop-join case
    ///   a value of `0` skips adding the prefetch node. Also forwarded to `build_replacement`.
    /// * `partitions_per_worker` - forwarded unchanged to `build_replacement`.
    ///
    /// # Errors
    /// Returns an error if physical planning or the plan rewrite fails, if a
    /// `DFRayStageExec` does not have exactly one child, if no stage is found, or if the
    /// last stage has more than one output partition.
    fn stages(
        &mut self,
        py: Python,
        batch_size: usize,
        prefetch_buffer_size: usize,
        partitions_per_worker: Option<usize>,
    ) -> PyResult<Vec<PyDFRayStage>> {
        /// Wraps `plan` in a `CoalesceBatchesExec` and then a `MaxRowsExec`, both configured
        /// with `batch_size`. Shared by the nested-loop-join branch and the final stage.
        fn coalesce_with_max_rows(
            plan: Arc<dyn ExecutionPlan>,
            batch_size: usize,
        ) -> Arc<dyn ExecutionPlan> {
            Arc::new(MaxRowsExec::new(
                Arc::new(CoalesceBatchesExec::new(plan, batch_size)) as Arc<dyn ExecutionPlan>,
                batch_size,
            ))
        }

        let mut stages = vec![];

        // Partitioning info gathered from the nodes visited since the last stage boundary.
        // It is handed to the next `DFRayStageExec` reached and then reset.
        let mut partition_groups = vec![];
        let mut full_partitions = false;

        // We walk up the tree from the leaves to find the stages, record ray stages, and replace
        // each ray stage with a corresponding ray reader stage.
        let up = |plan: Arc<dyn ExecutionPlan>| {
            trace!(
                "Examining plan up: {}",
                displayable(plan.as_ref()).one_line()
            );

            if let Some(stage_exec) = plan.as_any().downcast_ref::<DFRayStageExec>() {
                trace!("ray stage exec");
                let children = plan.children();
                // Report a malformed plan as an error instead of panicking inside a method
                // that is called from Python.
                if children.len() != 1 {
                    return internal_err!("RayStageExec must have exactly one child");
                }
                let input = Arc::clone(children[0]);

                let replacement = Arc::new(DFRayStageReaderExec::try_new(
                    plan.output_partitioning().clone(),
                    input.schema(),
                    stage_exec.stage_id,
                )?) as Arc<dyn ExecutionPlan>;

                // Move the accumulated partitioning info into this stage; `take` leaves the
                // accumulators at their defaults (empty / false) for the next stage.
                let stage = PyDFRayStage::new(
                    stage_exec.stage_id,
                    input,
                    std::mem::take(&mut partition_groups),
                    std::mem::take(&mut full_partitions),
                );

                stages.push(stage);
                Ok(Transformed::yes(replacement))
            } else if plan.as_any().is::<RepartitionExec>() {
                trace!("repartition exec");
                let (calculated_partition_groups, replacement) = build_replacement(
                    plan,
                    prefetch_buffer_size,
                    partitions_per_worker,
                    true,
                    batch_size,
                    batch_size,
                )?;
                partition_groups = calculated_partition_groups;

                Ok(Transformed::yes(replacement))
            } else if plan.as_any().is::<SortExec>() {
                trace!("sort exec");
                let (calculated_partition_groups, replacement) = build_replacement(
                    plan,
                    prefetch_buffer_size,
                    partitions_per_worker,
                    false,
                    batch_size,
                    batch_size,
                )?;
                partition_groups = calculated_partition_groups;
                full_partitions = true;

                Ok(Transformed::yes(replacement))
            } else if plan.as_any().is::<NestedLoopJoinExec>() {
                trace!("nested loop join exec");
                // NestedLoopJoinExec must be on a stage by itself as it materializes the entire left
                // side of the join and is not suitable to be executed in a partitioned manner.
                let partition_count = plan.output_partitioning().partition_count();
                trace!("nested join output partitioning {}", partition_count);

                let mut replacement = coalesce_with_max_rows(plan, batch_size);
                if prefetch_buffer_size > 0 {
                    replacement = Arc::new(PrefetchExec::new(replacement, prefetch_buffer_size))
                        as Arc<dyn ExecutionPlan>;
                }
                // One group containing every output partition of the join.
                partition_groups = vec![(0..partition_count).collect()];
                full_partitions = true;
                Ok(Transformed::yes(replacement))
            } else {
                trace!("not special case");
                Ok(Transformed::no(plan))
            }
        };

        let physical_plan = wait_for_future(py, self.df.clone().create_physical_plan())?;

        // Only the side effects of the walk (the collected `stages`) are needed; the rewritten
        // plan returned by `transform_up` is intentionally discarded.
        physical_plan.transform_up(up)?;

        // `transform_up` visits children before their parents, so the stage pushed last is the
        // last one reached on the way up; it is treated as the final stage.
        // `ok_or_else` only constructs the error value when there really are no stages.
        let last_stage = stages
            .pop()
            .ok_or_else(|| internal_datafusion_err!("No stages found"))?;

        if last_stage.num_output_partitions() > 1 {
            return internal_err!("Last stage expected to have one partition").to_py_err();
        }

        // Add coalesce and max rows to the last stage; it runs as a single partition group
        // containing partition 0, with full partitions.
        let last_stage = PyDFRayStage::new(
            last_stage.stage_id,
            coalesce_with_max_rows(last_stage.plan, batch_size),
            vec![vec![0]],
            true,
        );

        let reader_plan = Arc::new(DFRayStageReaderExec::try_new_from_input(
            last_stage.plan.clone(),
            last_stage.stage_id,
        )?) as Arc<dyn ExecutionPlan>;

        stages.push(last_stage);

        self.final_plan = Some(reader_plan);

        Ok(stages)
    }