in datafusion/optimizer/src/push_down_filter.rs [766:1221]
fn rewrite(
&self,
plan: LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
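// Joins are handled up front: push_down_join can also push predicates
// derived from the join's own ON-clause filter, even when there is no
// parent Filter (hence the `None` passed below).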
if let LogicalPlan::Join(join) = plan {
return push_down_join(join, None);
};
let plan_schema = Arc::clone(plan.schema());
let LogicalPlan::Filter(mut filter) = plan else {
return Ok(Transformed::no(plan));
};
match Arc::unwrap_or_clone(filter.input) {
LogicalPlan::Filter(child_filter) => {
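// Merge the two adjacent filters into one, removing duplicate predicates.
// Illustrative example (hypothetical plan):
//   Filter: a > 1
//     Filter: a > 1 AND b < 2
//       TableScan: t
// is rewritten to
//   Filter: a > 1 AND b < 2
//     TableScan: t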
let parents_predicates = split_conjunction_owned(filter.predicate);
// remove duplicated filters
let child_predicates = split_conjunction_owned(child_filter.predicate);
let new_predicates = parents_predicates
.into_iter()
.chain(child_predicates)
// use IndexSet to remove dupes while preserving predicate order
.collect::<IndexSet<_>>()
.into_iter()
.collect::<Vec<_>>();
let Some(new_predicate) = conjunction(new_predicates) else {
return plan_err!("at least one expression exists");
};
let new_filter = LogicalPlan::Filter(Filter::try_new(
new_predicate,
child_filter.input,
)?);
#[allow(clippy::used_underscore_binding)]
self.rewrite(new_filter, _config)
}
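// Repartition, Distinct, and Sort only affect row distribution, duplication,
// and order; evaluating the predicate before or after them yields the same
// rows, so the Filter is simply moved below each of these operators.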
LogicalPlan::Repartition(repartition) => {
let new_filter =
Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
.map(LogicalPlan::Filter)?;
insert_below(LogicalPlan::Repartition(repartition), new_filter)
}
LogicalPlan::Distinct(distinct) => {
let new_filter =
Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
.map(LogicalPlan::Filter)?;
insert_below(LogicalPlan::Distinct(distinct), new_filter)
}
LogicalPlan::Sort(sort) => {
let new_filter =
Filter::try_new(filter.predicate, Arc::clone(&sort.input))
.map(LogicalPlan::Filter)?;
insert_below(LogicalPlan::Sort(sort), new_filter)
}
LogicalPlan::SubqueryAlias(subquery_alias) => {
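// Rewrite the predicate's column references from the alias' schema back to
// the input's schema, then push the filter below the alias.
// Illustrative example (hypothetical plan):
//   Filter: t2.a > 1
//     SubqueryAlias: t2
//       TableScan: t1
// is rewritten to
//   SubqueryAlias: t2
//     Filter: t1.a > 1
//       TableScan: t1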
let mut replace_map = HashMap::new();
for (i, (qualifier, field)) in
subquery_alias.input.schema().iter().enumerate()
{
let (sub_qualifier, sub_field) =
subquery_alias.schema.qualified_field(i);
replace_map.insert(
qualified_name(sub_qualifier, sub_field.name()),
Expr::Column(Column::new(qualifier.cloned(), field.name())),
);
}
let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;
let new_filter = LogicalPlan::Filter(Filter::try_new(
new_predicate,
Arc::clone(&subquery_alias.input),
)?);
insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
}
LogicalPlan::Projection(projection) => {
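// A filter over a projection can be pushed below it by inlining the
// projected expressions into the predicate.
// Illustrative example (hypothetical plan):
//   Filter: b > 1
//     Projection: a * 2 AS b
//       TableScan: t
// is rewritten to
//   Projection: a * 2 AS b
//     Filter: a * 2 > 1
//       TableScan: t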
let predicates = split_conjunction_owned(filter.predicate.clone());
let (new_projection, keep_predicate) =
rewrite_projection(predicates, projection)?;
if new_projection.transformed {
match keep_predicate {
None => Ok(new_projection),
Some(keep_predicate) => new_projection.map_data(|child_plan| {
Filter::try_new(keep_predicate, Arc::new(child_plan))
.map(LogicalPlan::Filter)
}),
}
} else {
filter.input = Arc::new(new_projection.data);
Ok(Transformed::no(LogicalPlan::Filter(filter)))
}
}
LogicalPlan::Unnest(mut unnest) => {
let predicates = split_conjunction_owned(filter.predicate.clone());
let mut non_unnest_predicates = vec![];
let mut unnest_predicates = vec![];
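// A predicate that references a column produced by the unnest (e.g., a
// hypothetical `unnest(tags) = 'x'`) filters individual list elements that
// only exist after unnesting, so it must stay above the Unnest, while
// predicates on other columns (e.g. `id > 5`) can be pushed below it.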
for predicate in predicates {
// recursively collect all columns referenced by the predicate
let mut accum: HashSet<Column> = HashSet::new();
expr_to_columns(&predicate, &mut accum)?;
if unnest.list_type_columns.iter().any(|(_, unnest_list)| {
accum.contains(&unnest_list.output_column)
}) {
unnest_predicates.push(predicate);
} else {
non_unnest_predicates.push(predicate);
}
}
// Predicates on unnested columns must not be pushed down.
// If no non-unnest predicates exist, return early
if non_unnest_predicates.is_empty() {
filter.input = Arc::new(LogicalPlan::Unnest(unnest));
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
}
// Push down non-unnest filter predicate
// Unnest
//   Unnest Input (Projection)
// -> rewritten to
// Unnest
//   Filter
//     Unnest Input (Projection)
let unnest_input = std::mem::take(&mut unnest.input);
let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
conjunction(non_unnest_predicates).unwrap(), // Safe to unwrap since non_unnest_predicates is not empty.
unnest_input,
)?);
// Directly assign the new filter plan as the new unnest's input.
// The new filter plan will go through another rewrite pass since the rule
// is applied recursively to all children from top to bottom
let unnest_plan =
insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;
match conjunction(unnest_predicates) {
None => Ok(unnest_plan),
Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
))),
}
}
LogicalPlan::Union(ref union) => {
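// Push a copy of the filter into every union input, rewriting column
// references from the union schema to each input's own schema.
// Illustrative example (hypothetical plan):
//   Filter: x > 1
//     Union
//       TableScan: t1
//       TableScan: t2
// is rewritten to
//   Union
//     Filter: t1.x > 1
//       TableScan: t1
//     Filter: t2.x > 1
//       TableScan: t2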
let mut inputs = Vec::with_capacity(union.inputs.len());
for input in &union.inputs {
let mut replace_map = HashMap::new();
for (i, (qualifier, field)) in input.schema().iter().enumerate() {
let (union_qualifier, union_field) =
union.schema.qualified_field(i);
replace_map.insert(
qualified_name(union_qualifier, union_field.name()),
Expr::Column(Column::new(qualifier.cloned(), field.name())),
);
}
let push_predicate =
replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
push_predicate,
Arc::clone(input),
)?)))
}
Ok(Transformed::yes(LogicalPlan::Union(Union {
inputs,
schema: Arc::clone(&plan_schema),
})))
}
LogicalPlan::Aggregate(agg) => {
// Predicates that reference only group-by expressions can be pushed down.
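// Illustrative example (hypothetical plan):
//   Filter: a > 10 AND sum(b) < 5
//     Aggregate: groupBy=[[a]], aggr=[[sum(b)]]
// is rewritten to
//   Filter: sum(b) < 5
//     Aggregate: groupBy=[[a]], aggr=[[sum(b)]]
//       Filter: a > 10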
let group_expr_columns = agg
.group_expr
.iter()
.map(|e| Ok(Column::from_qualified_name(e.schema_name().to_string())))
.collect::<Result<HashSet<_>>>()?;
let predicates = split_conjunction_owned(filter.predicate);
let mut keep_predicates = vec![];
let mut push_predicates = vec![];
for expr in predicates {
let cols = expr.column_refs();
if cols.iter().all(|c| group_expr_columns.contains(c)) {
push_predicates.push(expr);
} else {
keep_predicates.push(expr);
}
}
// For a plan like `Filter: a + b > 0` above `Aggregate: groupBy=[[Column(a) + Column(b)]]`,
// the filter references the aggregate's output column `a + b`. After pushing it down,
// it must reference the underlying expression instead, so build a replace_map
// mapping {`a + b` --> Expr(Column(a) + Column(b))}
let mut replace_map = HashMap::new();
for expr in &agg.group_expr {
replace_map.insert(expr.schema_name().to_string(), expr.clone());
}
let replaced_push_predicates = push_predicates
.into_iter()
.map(|expr| replace_cols_by_name(expr, &replace_map))
.collect::<Result<Vec<_>>>()?;
let agg_input = Arc::clone(&agg.input);
Transformed::yes(LogicalPlan::Aggregate(agg))
.transform_data(|new_plan| {
// If we have a filter to push, we push it down to the input of the aggregate
if let Some(predicate) = conjunction(replaced_push_predicates) {
let new_filter = make_filter(predicate, agg_input)?;
insert_below(new_plan, new_filter)
} else {
Ok(Transformed::no(new_plan))
}
})?
.map_data(|child_plan| {
// if there are any remaining predicates we can't push, add them
// back as a filter
if let Some(predicate) = conjunction(keep_predicates) {
make_filter(predicate, Arc::new(child_plan))
} else {
Ok(child_plan)
}
})
}
// Tries to push filters based on the partition key(s) of the window function(s) used.
// Example:
// Before:
//   Filter: (a > 1) and (b > 1) and (c > 1)
//     Window: func() PARTITION BY [a] ...
// ---
// After:
//   Filter: (b > 1) and (c > 1)
//     Window: func() PARTITION BY [a] ...
//       Filter: (a > 1)
LogicalPlan::Window(window) => {
// Retrieve the set of potential partition keys that filters can be pushed down by.
// Unlike aggregations, where there is only one grouping per SELECT, there can be
// multiple window functions, each with potentially different partition keys.
// Therefore, we need to ensure that any potential partition key returned is used in
// ALL window functions. Otherwise, filters cannot be pushed through that column.
let extract_partition_keys = |func: &WindowFunction| {
func.params
.partition_by
.iter()
.map(|c| Column::from_qualified_name(c.schema_name().to_string()))
.collect::<HashSet<_>>()
};
let potential_partition_keys = window
.window_expr
.iter()
.map(|e| {
match e {
Expr::WindowFunction(window_func) => {
extract_partition_keys(window_func)
}
Expr::Alias(alias) => {
if let Expr::WindowFunction(window_func) =
alias.expr.as_ref()
{
extract_partition_keys(window_func)
} else {
// window expressions can only be Expr::WindowFunction
unreachable!()
}
}
_ => {
// window expressions can only be Expr::WindowFunction
unreachable!()
}
}
})
// performs the set intersection of the partition keys of all window functions,
// returning only the common ones
.reduce(|a, b| &a & &b)
.unwrap_or_default();
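// e.g. (hypothetical): given window functions with PARTITION BY [a, b] and
// PARTITION BY [a], the intersection is {a}, so only predicates on `a` are
// candidates for pushdown.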
let predicates = split_conjunction_owned(filter.predicate);
let mut keep_predicates = vec![];
let mut push_predicates = vec![];
for expr in predicates {
let cols = expr.column_refs();
if cols.iter().all(|c| potential_partition_keys.contains(c)) {
push_predicates.push(expr);
} else {
keep_predicates.push(expr);
}
}
// Unlike with aggregations, there are no cases where we have to replace, e.g.,
// `a+b` with Column(a)+Column(b). This is because partition expressions are not
// available as standalone columns to the user. For example, while an aggregation on
// `a+b` becomes Column(a + b), in a window partition it becomes
// `func() PARTITION BY [a + b] ...`. Thus, filters on expressions always remain in
// place, so we can use `push_predicates` directly. This is consistent with other
// optimizers, such as the one used by Postgres.
let window_input = Arc::clone(&window.input);
Transformed::yes(LogicalPlan::Window(window))
.transform_data(|new_plan| {
// If we have a filter to push, we push it down to the input of the window
if let Some(predicate) = conjunction(push_predicates) {
let new_filter = make_filter(predicate, window_input)?;
insert_below(new_plan, new_filter)
} else {
Ok(Transformed::no(new_plan))
}
})?
.map_data(|child_plan| {
// if there are any remaining predicates we can't push, add them
// back as a filter
if let Some(predicate) = conjunction(keep_predicates) {
make_filter(predicate, Arc::new(child_plan))
} else {
Ok(child_plan)
}
})
}
LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
LogicalPlan::TableScan(scan) => {
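// Push predicates into the scan according to the provider's support:
// `Exact` filters are handled entirely by the source and removed from the
// plan, `Inexact` filters are pushed to the scan but also re-applied in a
// Filter above it, and `Unsupported` (as well as volatile) filters remain
// above the scan only.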
let filter_predicates = split_conjunction(&filter.predicate);
let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
filter_predicates
.into_iter()
.partition(|pred| pred.is_volatile());
// Check which non-volatile filters are supported by source
let supported_filters = scan
.source
.supports_filters_pushdown(non_volatile_filters.as_slice())?;
if non_volatile_filters.len() != supported_filters.len() {
return internal_err!(
"supports_filters_pushdown returned {} results, but {} filters were passed",
supported_filters.len(),
non_volatile_filters.len());
}
// Compose scan filters from non-volatile filters of `Exact` or `Inexact` pushdown type
let zip = non_volatile_filters.into_iter().zip(supported_filters);
let new_scan_filters = zip
.clone()
.filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
.map(|(pred, _)| pred);
// Add new scan filters
let new_scan_filters: Vec<Expr> = scan
.filters
.iter()
.chain(new_scan_filters)
.unique()
.cloned()
.collect();
// Keep predicates whose pushdown is `Unsupported` or `Inexact` (they must still be
// evaluated above the scan), along with all volatile filters
let new_predicate: Vec<Expr> = zip
.filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
.map(|(pred, _)| pred)
.chain(volatile_filters)
.cloned()
.collect();
let new_scan = LogicalPlan::TableScan(TableScan {
filters: new_scan_filters,
..scan
});
Transformed::yes(new_scan).transform_data(|new_scan| {
if let Some(predicate) = conjunction(new_predicate) {
make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
} else {
Ok(Transformed::no(new_scan))
}
})
}
LogicalPlan::Extension(extension_plan) => {
// This check prevents the Filter from being removed when the extension node has no children,
// so we return the original Filter unchanged.
if extension_plan.node.inputs().is_empty() {
filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
}
let prevent_cols =
extension_plan.node.prevent_predicate_push_down_columns();
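// e.g. (hypothetical): if the node reports ["b"] as prevented, `a > 1` may
// be pushed below the extension node, while `b > 1` must be kept above it.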
// determine if we can push any predicates down past the extension node
// each element is true for push, false to keep
let predicate_push_or_keep = split_conjunction(&filter.predicate)
.iter()
.map(|expr| {
let cols = expr.column_refs();
if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
Ok(false) // No push (keep)
} else {
Ok(true) // push
}
})
.collect::<Result<Vec<_>>>()?;
// all predicates are kept, no changes needed
if predicate_push_or_keep.iter().all(|&x| !x) {
filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
}
// going to push some predicates down, so split the predicates
let mut keep_predicates = vec![];
let mut push_predicates = vec![];
for (push, expr) in predicate_push_or_keep
.into_iter()
.zip(split_conjunction_owned(filter.predicate).into_iter())
{
if !push {
keep_predicates.push(expr);
} else {
push_predicates.push(expr);
}
}
let new_children = match conjunction(push_predicates) {
Some(predicate) => extension_plan
.node
.inputs()
.into_iter()
.map(|child| {
Ok(LogicalPlan::Filter(Filter::try_new(
predicate.clone(),
Arc::new(child.clone()),
)?))
})
.collect::<Result<Vec<_>>>()?,
None => extension_plan.node.inputs().into_iter().cloned().collect(),
};
// extension with new inputs.
let child_plan = LogicalPlan::Extension(extension_plan);
let new_extension =
child_plan.with_new_exprs(child_plan.expressions(), new_children)?;
let new_plan = match conjunction(keep_predicates) {
Some(predicate) => LogicalPlan::Filter(Filter::try_new(
predicate,
Arc::new(new_extension),
)?),
None => new_extension,
};
Ok(Transformed::yes(new_plan))
}
child => {
filter.input = Arc::new(child);
Ok(Transformed::no(LogicalPlan::Filter(filter)))
}
}
}