fn rewrite()

in datafusion/optimizer/src/push_down_filter.rs [766:1221]


    /// Tries to move the predicate of a `Filter` node below its input.
    ///
    /// A `Join` at the root is handed to `push_down_join` with no parent
    /// predicate. Any other plan that is not a `Filter` is returned unchanged.
    /// For a `Filter`, the rewrite depends on the kind of its input:
    ///
    /// * `Filter`: both predicates are merged (duplicates dropped, order kept)
    ///   into one filter, which is then rewritten again.
    /// * `Repartition`, `Distinct`, `Sort`: the whole filter is re-created on
    ///   the node's input and placed below the node via `insert_below`.
    /// * `SubqueryAlias`, `Union`: column references are renamed from the
    ///   node's output schema to the input schema, then the filter is placed
    ///   on the input (once per input for a union).
    /// * `Projection`: delegated to `rewrite_projection`.
    /// * `Unnest`: conjuncts that do not reference an unnested list output
    ///   column are pushed below; the rest stay above.
    /// * `Aggregate`: conjuncts whose columns are all group-by columns are
    ///   pushed below; the rest stay above.
    /// * `Window`: conjuncts whose columns are all partition keys shared by
    ///   every window expression are pushed below; the rest stay above.
    /// * `Join`: delegated to `push_down_join` with the filter's predicate.
    /// * `TableScan`: non-volatile conjuncts the source reports as supported
    ///   are added to the scan's filters; everything not reported `Exact`
    ///   (plus volatile conjuncts) stays in a filter above the scan.
    /// * `Extension`: conjuncts that avoid the node's
    ///   `prevent_predicate_push_down_columns` are copied onto every input.
    /// * anything else: the filter is returned unchanged.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the helpers used to build or rewrite
    /// plan nodes, and returns an internal error when a table source answers
    /// `supports_filters_pushdown` with a different number of entries than it
    /// was given.
    ///
    /// # Panics
    ///
    /// Panics if a window node holds an expression that is neither
    /// `Expr::WindowFunction` nor an alias of one.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        // A join that is not directly under a filter is still handed to
        // `push_down_join`, just without a parent predicate.
        if let LogicalPlan::Join(join) = plan {
            return push_down_join(join, None);
        }

        // Captured before `plan` is taken apart; only the `Union` arm reads it.
        let plan_schema = Arc::clone(plan.schema());

        let LogicalPlan::Filter(mut filter) = plan else {
            return Ok(Transformed::no(plan));
        };

        match Arc::unwrap_or_clone(filter.input) {
            LogicalPlan::Filter(child_filter) => {
                let parents_predicates = split_conjunction_owned(filter.predicate);

                // remove duplicated filters
                let child_predicates = split_conjunction_owned(child_filter.predicate);
                let new_predicates = parents_predicates
                    .into_iter()
                    .chain(child_predicates)
                    // use IndexSet to remove dupes while preserving predicate order
                    .collect::<IndexSet<_>>()
                    .into_iter()
                    .collect::<Vec<_>>();

                let Some(new_predicate) = conjunction(new_predicates) else {
                    return plan_err!("at least one expression exists");
                };
                let new_filter = LogicalPlan::Filter(Filter::try_new(
                    new_predicate,
                    child_filter.input,
                )?);
                // `_config` carries an underscore because no arm reads it; it is
                // only forwarded on this recursive call, which trips the lint.
                #[allow(clippy::used_underscore_binding)]
                self.rewrite(new_filter, _config)
            }
            LogicalPlan::Repartition(repartition) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Repartition(repartition), new_filter)
            }
            LogicalPlan::Distinct(distinct) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Distinct(distinct), new_filter)
            }
            LogicalPlan::Sort(sort) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(&sort.input))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Sort(sort), new_filter)
            }
            LogicalPlan::SubqueryAlias(subquery_alias) => {
                // Map each aliased output column name to the column at the same
                // position in the alias' input, so the predicate can be
                // expressed in terms of the input schema.
                let mut replace_map = HashMap::new();
                for (i, (qualifier, field)) in
                    subquery_alias.input.schema().iter().enumerate()
                {
                    let (sub_qualifier, sub_field) =
                        subquery_alias.schema.qualified_field(i);
                    replace_map.insert(
                        qualified_name(sub_qualifier, sub_field.name()),
                        Expr::Column(Column::new(qualifier.cloned(), field.name())),
                    );
                }
                let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;

                let new_filter = LogicalPlan::Filter(Filter::try_new(
                    new_predicate,
                    Arc::clone(&subquery_alias.input),
                )?);
                insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
            }
            LogicalPlan::Projection(projection) => {
                // The predicate is cloned because `filter` must stay intact for
                // the branch where the projection could not be rewritten.
                let predicates = split_conjunction_owned(filter.predicate.clone());
                let (new_projection, keep_predicate) =
                    rewrite_projection(predicates, projection)?;
                if new_projection.transformed {
                    match keep_predicate {
                        None => Ok(new_projection),
                        // Whatever could not be pushed is re-applied on top of
                        // the rewritten projection.
                        Some(keep_predicate) => new_projection.map_data(|child_plan| {
                            Filter::try_new(keep_predicate, Arc::new(child_plan))
                                .map(LogicalPlan::Filter)
                        }),
                    }
                } else {
                    filter.input = Arc::new(new_projection.data);
                    Ok(Transformed::no(LogicalPlan::Filter(filter)))
                }
            }
            LogicalPlan::Unnest(mut unnest) => {
                // The predicate is cloned because `filter` is returned as-is by
                // the early return below.
                let predicates = split_conjunction_owned(filter.predicate.clone());
                let mut non_unnest_predicates = vec![];
                let mut unnest_predicates = vec![];
                for predicate in predicates {
                    // collect all the Expr::Column in predicate recursively
                    let mut accum: HashSet<Column> = HashSet::new();
                    expr_to_columns(&predicate, &mut accum)?;

                    if unnest.list_type_columns.iter().any(|(_, unnest_list)| {
                        accum.contains(&unnest_list.output_column)
                    }) {
                        unnest_predicates.push(predicate);
                    } else {
                        non_unnest_predicates.push(predicate);
                    }
                }

                // Unnest predicates should not be pushed down.
                // If no non-unnest predicates exist, early return
                if non_unnest_predicates.is_empty() {
                    filter.input = Arc::new(LogicalPlan::Unnest(unnest));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }

                // Push down non-unnest filter predicate
                // Unnest
                //   Unnest Input (Projection)
                // -> rewritten to
                // Unnest
                //   Filter
                //     Unnest Input (Projection)

                let unnest_input = std::mem::take(&mut unnest.input);

                let pushable_predicate = conjunction(non_unnest_predicates)
                    .expect("conjunction of a non-empty predicate list is always Some");
                let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
                    pushable_predicate,
                    unnest_input,
                )?);

                // Directly assign new filter plan as the new unnest's input.
                // The new filter plan will go through another rewrite pass since the rule itself
                // is applied recursively to all the child from top to down
                let unnest_plan =
                    insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;

                match conjunction(unnest_predicates) {
                    None => Ok(unnest_plan),
                    // Predicates on unnested columns stay above the unnest node.
                    Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
                        Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
                    ))),
                }
            }
            LogicalPlan::Union(ref union) => {
                let mut inputs = Vec::with_capacity(union.inputs.len());
                for input in &union.inputs {
                    // Map each union output column name to the column at the
                    // same position in this particular input.
                    let mut replace_map = HashMap::new();
                    for (i, (qualifier, field)) in input.schema().iter().enumerate() {
                        let (union_qualifier, union_field) =
                            union.schema.qualified_field(i);
                        replace_map.insert(
                            qualified_name(union_qualifier, union_field.name()),
                            Expr::Column(Column::new(qualifier.cloned(), field.name())),
                        );
                    }

                    // Every input gets its own renamed copy of the predicate.
                    let push_predicate =
                        replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
                    inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
                        push_predicate,
                        Arc::clone(input),
                    )?)))
                }
                Ok(Transformed::yes(LogicalPlan::Union(Union {
                    inputs,
                    schema: Arc::clone(&plan_schema),
                })))
            }
            LogicalPlan::Aggregate(agg) => {
                // We can push down Predicate which in groupby_expr.
                let group_expr_columns = agg
                    .group_expr
                    .iter()
                    .map(|e| Column::from_qualified_name(e.schema_name().to_string()))
                    .collect::<HashSet<_>>();

                let predicates = split_conjunction_owned(filter.predicate);

                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for expr in predicates {
                    let cols = expr.column_refs();
                    // NOTE(review): a conjunct with no column references passes
                    // this `all` check vacuously and is pushed too — confirm
                    // that is intended.
                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
                        push_predicates.push(expr);
                    } else {
                        keep_predicates.push(expr);
                    }
                }

                // As for plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)]
                // After push, we need to replace `a+b` with Column(a)+Column(b)
                // So we need create a replace_map, add {`a+b` --> Expr(Column(a)+Column(b))}
                let mut replace_map = HashMap::new();
                for expr in &agg.group_expr {
                    replace_map.insert(expr.schema_name().to_string(), expr.clone());
                }
                let replaced_push_predicates = push_predicates
                    .into_iter()
                    .map(|expr| replace_cols_by_name(expr, &replace_map))
                    .collect::<Result<Vec<_>>>()?;

                let agg_input = Arc::clone(&agg.input);
                Transformed::yes(LogicalPlan::Aggregate(agg))
                    .transform_data(|new_plan| {
                        // If we have a filter to push, we push it down to the input of the aggregate
                        if let Some(predicate) = conjunction(replaced_push_predicates) {
                            let new_filter = make_filter(predicate, agg_input)?;
                            insert_below(new_plan, new_filter)
                        } else {
                            Ok(Transformed::no(new_plan))
                        }
                    })?
                    .map_data(|child_plan| {
                        // if there are any remaining predicates we can't push, add them
                        // back as a filter
                        if let Some(predicate) = conjunction(keep_predicates) {
                            make_filter(predicate, Arc::new(child_plan))
                        } else {
                            Ok(child_plan)
                        }
                    })
            }
            // Tries to push filters based on the partition key(s) of the window function(s) used.
            // Example:
            //   Before:
            //     Filter: (a > 1) and (b > 1) and (c > 1)
            //      Window: func() PARTITION BY [a] ...
            //   ---
            //   After:
            //     Filter: (b > 1) and (c > 1)
            //      Window: func() PARTITION BY [a] ...
            //        Filter: (a > 1)
            LogicalPlan::Window(window) => {
                // Retrieve the set of potential partition keys where we can push filters by.
                // Unlike aggregations, where there is only one statement per SELECT, there can be
                // multiple window functions, each with potentially different partition keys.
                // Therefore, we need to ensure that any potential partition key returned is used in
                // ALL window functions. Otherwise, filters cannot be pushed by through that column.
                let extract_partition_keys = |func: &WindowFunction| {
                    func.params
                        .partition_by
                        .iter()
                        .map(|c| Column::from_qualified_name(c.schema_name().to_string()))
                        .collect::<HashSet<_>>()
                };
                let potential_partition_keys = window
                    .window_expr
                    .iter()
                    .map(|e| {
                        match e {
                            Expr::WindowFunction(window_func) => {
                                extract_partition_keys(window_func)
                            }
                            Expr::Alias(alias) => {
                                if let Expr::WindowFunction(window_func) =
                                    alias.expr.as_ref()
                                {
                                    extract_partition_keys(window_func)
                                } else {
                                    // window functions expressions are only Expr::WindowFunction
                                    unreachable!()
                                }
                            }
                            _ => {
                                // window functions expressions are only Expr::WindowFunction
                                unreachable!()
                            }
                        }
                    })
                    // performs the set intersection of the partition keys of all window functions,
                    // returning only the common ones
                    .reduce(|a, b| &a & &b)
                    .unwrap_or_default();

                let predicates = split_conjunction_owned(filter.predicate);
                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for expr in predicates {
                    let cols = expr.column_refs();
                    if cols.iter().all(|c| potential_partition_keys.contains(c)) {
                        push_predicates.push(expr);
                    } else {
                        keep_predicates.push(expr);
                    }
                }

                // Unlike with aggregations, there are no cases where we have to replace, e.g.,
                // `a+b` with Column(a)+Column(b). This is because partition expressions are not
                // available as standalone columns to the user. For example, while an aggregation on
                // `a+b` becomes Column(a + b), in a window partition it becomes
                // `func() PARTITION BY [a + b] ...`. Thus, filters on expressions always remain in
                // place, so we can use `push_predicates` directly. This is consistent with other
                // optimizers, such as the one used by Postgres.

                let window_input = Arc::clone(&window.input);
                Transformed::yes(LogicalPlan::Window(window))
                    .transform_data(|new_plan| {
                        // If we have a filter to push, we push it down to the input of the window
                        if let Some(predicate) = conjunction(push_predicates) {
                            let new_filter = make_filter(predicate, window_input)?;
                            insert_below(new_plan, new_filter)
                        } else {
                            Ok(Transformed::no(new_plan))
                        }
                    })?
                    .map_data(|child_plan| {
                        // if there are any remaining predicates we can't push, add them
                        // back as a filter
                        if let Some(predicate) = conjunction(keep_predicates) {
                            make_filter(predicate, Arc::new(child_plan))
                        } else {
                            Ok(child_plan)
                        }
                    })
            }
            LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
            LogicalPlan::TableScan(scan) => {
                let filter_predicates = split_conjunction(&filter.predicate);

                // Volatile conjuncts are never offered to the source; they are
                // re-attached to the filter that stays above the scan.
                let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
                    filter_predicates
                        .into_iter()
                        .partition(|pred| pred.is_volatile());

                // Check which non-volatile filters are supported by source
                let supported_filters = scan
                    .source
                    .supports_filters_pushdown(non_volatile_filters.as_slice())?;
                if non_volatile_filters.len() != supported_filters.len() {
                    return internal_err!(
                        "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
                        supported_filters.len(),
                        non_volatile_filters.len());
                }

                // Compose scan filters from non-volatile filters of `Exact` or `Inexact` pushdown type
                let zip = non_volatile_filters.into_iter().zip(supported_filters);

                // `zip` is cloned here because it is consumed a second time
                // below to build the predicate that stays above the scan.
                let new_scan_filters = zip
                    .clone()
                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
                    .map(|(pred, _)| pred);

                // Add new scan filters
                let new_scan_filters: Vec<Expr> = scan
                    .filters
                    .iter()
                    .chain(new_scan_filters)
                    .unique()
                    .cloned()
                    .collect();

                // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters
                let new_predicate: Vec<Expr> = zip
                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
                    .map(|(pred, _)| pred)
                    .chain(volatile_filters)
                    .cloned()
                    .collect();

                let new_scan = LogicalPlan::TableScan(TableScan {
                    filters: new_scan_filters,
                    ..scan
                });

                Transformed::yes(new_scan).transform_data(|new_scan| {
                    if let Some(predicate) = conjunction(new_predicate) {
                        make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
                    } else {
                        Ok(Transformed::no(new_scan))
                    }
                })
            }
            LogicalPlan::Extension(extension_plan) => {
                // This check prevents the Filter from being removed when the extension node has no children,
                // so we return the original Filter unchanged.
                if extension_plan.node.inputs().is_empty() {
                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }
                let prevent_cols =
                    extension_plan.node.prevent_predicate_push_down_columns();

                // determine if we can push any predicates down past the extension node

                // each element is true for push, false to keep
                // This pass only borrows the predicate so that `filter` can
                // still be returned untouched when nothing is pushable.
                let predicate_push_or_keep = split_conjunction(&filter.predicate)
                    .iter()
                    .map(|expr| {
                        let cols = expr.column_refs();
                        !cols.iter().any(|c| prevent_cols.contains(&c.name))
                    })
                    .collect::<Vec<_>>();

                // all predicates are kept, no changes needed
                if predicate_push_or_keep.iter().all(|&x| !x) {
                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }

                // going to push some predicates down, so split the predicates
                // Pairing by position relies on `split_conjunction` and
                // `split_conjunction_owned` yielding conjuncts in the same order.
                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for (push, expr) in predicate_push_or_keep
                    .into_iter()
                    .zip(split_conjunction_owned(filter.predicate))
                {
                    if push {
                        push_predicates.push(expr);
                    } else {
                        keep_predicates.push(expr);
                    }
                }

                // The pushable predicate is copied onto every input of the
                // extension node.
                let new_children = match conjunction(push_predicates) {
                    Some(predicate) => extension_plan
                        .node
                        .inputs()
                        .into_iter()
                        .map(|child| {
                            Ok(LogicalPlan::Filter(Filter::try_new(
                                predicate.clone(),
                                Arc::new(child.clone()),
                            )?))
                        })
                        .collect::<Result<Vec<_>>>()?,
                    None => extension_plan.node.inputs().into_iter().cloned().collect(),
                };
                // extension with new inputs.
                let child_plan = LogicalPlan::Extension(extension_plan);
                let new_extension =
                    child_plan.with_new_exprs(child_plan.expressions(), new_children)?;

                let new_plan = match conjunction(keep_predicates) {
                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
                        predicate,
                        Arc::new(new_extension),
                    )?),
                    None => new_extension,
                };
                Ok(Transformed::yes(new_plan))
            }
            // Any other input kind: put the input back and report no change.
            child => {
                filter.input = Arc::new(child);
                Ok(Transformed::no(LogicalPlan::Filter(filter)))
            }
        }
    }