fn try_optimize()

in datafusion/optimizer/src/push_down_projection.rs [62:372]


    fn try_optimize(
        &self,
        plan: &LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Option<LogicalPlan>> {
        let projection = match plan {
            LogicalPlan::Projection(projection) => projection,
            LogicalPlan::Aggregate(agg) => {
                let mut required_columns = HashSet::new();
                for e in agg.aggr_expr.iter().chain(agg.group_expr.iter()) {
                    expr_to_columns(e, &mut required_columns)?
                }
                let new_expr = get_expr(&required_columns, agg.input.schema())?;
                let projection = LogicalPlan::Projection(Projection::try_new(
                    new_expr,
                    agg.input.clone(),
                )?);
                let optimized_child = self
                    .try_optimize(&projection, _config)?
                    .unwrap_or(projection);
                return Ok(Some(plan.with_new_inputs(&[optimized_child])?));
            }
            LogicalPlan::TableScan(scan) if scan.projection.is_none() => {
                return Ok(Some(push_down_scan(&HashSet::new(), scan, false)?));
            }
            _ => return Ok(None),
        };

        let child_plan = &*projection.input;
        let projection_is_empty = projection.expr.is_empty();

        let new_plan = match child_plan {
            LogicalPlan::Projection(child_projection) => {
                let new_plan = merge_projection(projection, child_projection)?;
                self.try_optimize(&new_plan, _config)?.unwrap_or(new_plan)
            }
            LogicalPlan::Join(join) => {
                // collect column in on/filter in join and projection.
                let mut push_columns: HashSet<Column> = HashSet::new();
                for e in projection.expr.iter() {
                    expr_to_columns(e, &mut push_columns)?;
                }
                for (l, r) in join.on.iter() {
                    expr_to_columns(l, &mut push_columns)?;
                    expr_to_columns(r, &mut push_columns)?;
                }
                if let Some(expr) = &join.filter {
                    expr_to_columns(expr, &mut push_columns)?;
                }

                let new_left = generate_projection(
                    &push_columns,
                    join.left.schema(),
                    join.left.clone(),
                )?;
                let new_right = generate_projection(
                    &push_columns,
                    join.right.schema(),
                    join.right.clone(),
                )?;
                let new_join = child_plan.with_new_inputs(&[new_left, new_right])?;

                generate_plan!(projection_is_empty, plan, new_join)
            }
            LogicalPlan::CrossJoin(join) => {
                // collect column in on/filter in join and projection.
                let mut push_columns: HashSet<Column> = HashSet::new();
                for e in projection.expr.iter() {
                    expr_to_columns(e, &mut push_columns)?;
                }
                let new_left = generate_projection(
                    &push_columns,
                    join.left.schema(),
                    join.left.clone(),
                )?;
                let new_right = generate_projection(
                    &push_columns,
                    join.right.schema(),
                    join.right.clone(),
                )?;
                let new_join = child_plan.with_new_inputs(&[new_left, new_right])?;

                generate_plan!(projection_is_empty, plan, new_join)
            }
            LogicalPlan::TableScan(scan)
                if !scan.projected_schema.fields().is_empty() =>
            {
                let mut used_columns: HashSet<Column> = HashSet::new();
                // filter expr may not exist in expr in projection.
                // like: TableScan: t1 projection=[bool_col, int_col], full_filters=[t1.id = Int32(1)]
                // projection=[bool_col, int_col] don't contain `ti.id`.
                exprlist_to_columns(&scan.filters, &mut used_columns)?;
                if projection_is_empty {
                    used_columns
                        .insert(scan.projected_schema.fields()[0].qualified_column());
                    push_down_scan(&used_columns, scan, true)?
                } else {
                    for expr in projection.expr.iter() {
                        expr_to_columns(expr, &mut used_columns)?;
                    }
                    let new_scan = push_down_scan(&used_columns, scan, true)?;

                    plan.with_new_inputs(&[new_scan])?
                }
            }
            LogicalPlan::Values(values) if projection_is_empty => {
                let first_col =
                    Expr::Column(values.schema.fields()[0].qualified_column());
                LogicalPlan::Projection(Projection::try_new(
                    vec![first_col],
                    Arc::new(child_plan.clone()),
                )?)
            }
            LogicalPlan::Union(union) => {
                let mut required_columns = HashSet::new();
                exprlist_to_columns(&projection.expr, &mut required_columns)?;
                // When there is no projection, we need to add the first column to the projection
                // Because if push empty down, children may output different columns.
                if required_columns.is_empty() {
                    required_columns.insert(union.schema.fields()[0].qualified_column());
                }
                // we don't push down projection expr, we just prune columns, so we just push column
                // because push expr may cause more cost.
                let projection_column_exprs = get_expr(&required_columns, &union.schema)?;
                let mut inputs = Vec::with_capacity(union.inputs.len());
                for input in &union.inputs {
                    let mut replace_map = HashMap::new();
                    for (i, field) in input.schema().fields().iter().enumerate() {
                        replace_map.insert(
                            union.schema.fields()[i].qualified_name(),
                            Expr::Column(field.qualified_column()),
                        );
                    }

                    let exprs = projection_column_exprs
                        .iter()
                        .map(|expr| replace_cols_by_name(expr.clone(), &replace_map))
                        .collect::<Result<Vec<_>>>()?;

                    inputs.push(Arc::new(LogicalPlan::Projection(Projection::try_new(
                        exprs,
                        input.clone(),
                    )?)))
                }
                // create schema of all used columns
                let schema = DFSchema::new_with_metadata(
                    exprlist_to_fields(&projection_column_exprs, child_plan)?,
                    union.schema.metadata().clone(),
                )?;
                let new_union = LogicalPlan::Union(Union {
                    inputs,
                    schema: Arc::new(schema),
                });

                generate_plan!(projection_is_empty, plan, new_union)
            }
            LogicalPlan::SubqueryAlias(subquery_alias) => {
                let replace_map = generate_column_replace_map(subquery_alias);
                let mut required_columns = HashSet::new();
                exprlist_to_columns(&projection.expr, &mut required_columns)?;

                let new_required_columns = required_columns
                    .iter()
                    .map(|c| {
                        replace_map.get(c).cloned().ok_or_else(|| {
                            DataFusionError::Internal("replace column failed".to_string())
                        })
                    })
                    .collect::<Result<HashSet<_>>>()?;

                let new_expr =
                    get_expr(&new_required_columns, subquery_alias.input.schema())?;
                let new_projection = LogicalPlan::Projection(Projection::try_new(
                    new_expr,
                    subquery_alias.input.clone(),
                )?);
                let new_alias = child_plan.with_new_inputs(&[new_projection])?;

                generate_plan!(projection_is_empty, plan, new_alias)
            }
            LogicalPlan::Aggregate(agg) => {
                let mut required_columns = HashSet::new();
                exprlist_to_columns(&projection.expr, &mut required_columns)?;
                // Gather all columns needed for expressions in this Aggregate
                let mut new_aggr_expr = vec![];
                for e in agg.aggr_expr.iter() {
                    let column = Column::from_name(e.display_name()?);
                    if required_columns.contains(&column) {
                        new_aggr_expr.push(e.clone());
                    }
                }

                // if new_aggr_expr emtpy and aggr is COUNT(UInt8(1)), push it
                if new_aggr_expr.is_empty() && agg.aggr_expr.len() == 1 {
                    if let Expr::AggregateFunction(AggregateFunction {
                        fun, args, ..
                    }) = &agg.aggr_expr[0]
                    {
                        if matches!(fun, datafusion_expr::AggregateFunction::Count)
                            && args.len() == 1
                            && args[0] == Expr::Literal(UInt8(Some(1)))
                        {
                            new_aggr_expr.push(agg.aggr_expr[0].clone());
                        }
                    }
                }

                let new_agg = LogicalPlan::Aggregate(Aggregate::try_new(
                    agg.input.clone(),
                    agg.group_expr.clone(),
                    new_aggr_expr,
                )?);

                generate_plan!(projection_is_empty, plan, new_agg)
            }
            LogicalPlan::Window(window) => {
                let mut required_columns = HashSet::new();
                exprlist_to_columns(&projection.expr, &mut required_columns)?;
                // Gather all columns needed for expressions in this Window
                let mut new_window_expr = vec![];
                for e in window.window_expr.iter() {
                    let column = Column::from_name(e.display_name()?);
                    if required_columns.contains(&column) {
                        new_window_expr.push(e.clone());
                    }
                }

                if new_window_expr.is_empty() {
                    // none columns in window expr are needed, remove the window expr
                    let input = window.input.clone();
                    let new_window = restrict_outputs(input.clone(), &required_columns)?
                        .unwrap_or((*input).clone());

                    generate_plan!(projection_is_empty, plan, new_window)
                } else {
                    let mut referenced_inputs = HashSet::new();
                    exprlist_to_columns(&new_window_expr, &mut referenced_inputs)?;
                    window
                        .input
                        .schema()
                        .fields()
                        .iter()
                        .filter(|f| required_columns.contains(&f.qualified_column()))
                        .for_each(|f| {
                            referenced_inputs.insert(f.qualified_column());
                        });

                    let input = window.input.clone();
                    let new_input = restrict_outputs(input.clone(), &referenced_inputs)?
                        .unwrap_or((*input).clone());
                    let new_window = LogicalPlanBuilder::from(new_input)
                        .window(new_window_expr)?
                        .build()?;

                    generate_plan!(projection_is_empty, plan, new_window)
                }
            }
            LogicalPlan::Filter(filter) => {
                if can_eliminate(projection, child_plan.schema()) {
                    // when projection schema == filter schema, we can commute directly.
                    let new_proj =
                        plan.with_new_inputs(&[filter.input.as_ref().clone()])?;
                    child_plan.with_new_inputs(&[new_proj])?
                } else {
                    let mut required_columns = HashSet::new();
                    exprlist_to_columns(&projection.expr, &mut required_columns)?;
                    exprlist_to_columns(
                        &[filter.predicate.clone()],
                        &mut required_columns,
                    )?;

                    let new_expr = get_expr(&required_columns, filter.input.schema())?;
                    let new_projection = LogicalPlan::Projection(Projection::try_new(
                        new_expr,
                        filter.input.clone(),
                    )?);
                    let new_filter = child_plan.with_new_inputs(&[new_projection])?;

                    generate_plan!(projection_is_empty, plan, new_filter)
                }
            }
            LogicalPlan::Sort(sort) => {
                if can_eliminate(projection, child_plan.schema()) {
                    // can commute
                    let new_proj = plan.with_new_inputs(&[(*sort.input).clone()])?;
                    child_plan.with_new_inputs(&[new_proj])?
                } else {
                    let mut required_columns = HashSet::new();
                    exprlist_to_columns(&projection.expr, &mut required_columns)?;
                    exprlist_to_columns(&sort.expr, &mut required_columns)?;

                    let new_expr = get_expr(&required_columns, sort.input.schema())?;
                    let new_projection = LogicalPlan::Projection(Projection::try_new(
                        new_expr,
                        sort.input.clone(),
                    )?);
                    let new_sort = child_plan.with_new_inputs(&[new_projection])?;

                    generate_plan!(projection_is_empty, plan, new_sort)
                }
            }
            LogicalPlan::Limit(limit) => {
                // can commute
                let new_proj = plan.with_new_inputs(&[limit.input.as_ref().clone()])?;
                child_plan.with_new_inputs(&[new_proj])?
            }
            _ => return Ok(None),
        };

        Ok(Some(new_plan))
    }