in native/core/src/execution/planner.rs [1605:1916]
/// Converts a serialized Spark aggregate expression (`AggExpr`) into a DataFusion
/// [`AggregateFunctionExpr`] bound to `schema`.
///
/// Child expressions are converted with `create_expr`. For `MIN`, `MAX`, non-decimal
/// `SUM` and non-decimal `AVG`, the input is first cast to the result `datatype` carried
/// by the Spark expression. Decimal `SUM`/`AVG` use Comet's `SumDecimal`/`AvgDecimal`
/// instead. For `COVARIANCE`, `VARIANCE` and `STDDEV`, `stats_type` 0 selects the sample
/// statistic and 1 selects the population statistic.
///
/// # Errors
///
/// Returns an `ExecutionError` in any of these cases:
/// - a child expression cannot be converted;
/// - `COUNT` has no child expressions;
/// - a statistical aggregate carries an unknown `stats_type`;
/// - DataFusion fails to build the aggregate.
///
/// # Panics
///
/// Panics if a required field of the serialized expression (the expression kind, a
/// child expression, or a data type) is absent.
fn create_agg_expr(
    &self,
    spark_expr: &AggExpr,
    schema: SchemaRef,
) -> Result<AggregateFunctionExpr, ExecutionError> {
    match spark_expr.expr_struct.as_ref().unwrap() {
        AggExprStruct::Count(expr) => {
            // Using `count_udaf` from Comet is exceptionally slow for some reason, so
            // as a workaround we translate it to `SUM(IF(expr IS NOT NULL, 1, 0))`
            // https://github.com/apache/datafusion-comet/issues/744
            let children = expr
                .children
                .iter()
                .map(|child| self.create_expr(child, Arc::clone(&schema)))
                .collect::<Result<Vec<_>, _>>()?;
            // A COUNT without arguments is a malformed plan: report it as an error
            // rather than panicking inside the native executor.
            let (first, rest) = children.split_first().ok_or_else(|| {
                GeneralError("COUNT requires at least one child expression".to_string())
            })?;
            // create `IS NOT NULL expr` and join them with `AND` if there are multiple
            let not_null_expr: Arc<dyn PhysicalExpr> = rest.iter().fold(
                Arc::new(IsNotNullExpr::new(Arc::clone(first))) as Arc<dyn PhysicalExpr>,
                |acc, child| {
                    Arc::new(BinaryExpr::new(
                        acc,
                        DataFusionOperator::And,
                        Arc::new(IsNotNullExpr::new(Arc::clone(child))),
                    ))
                },
            );
            let child = Arc::new(IfExpr::new(
                not_null_expr,
                Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
                Arc::new(Literal::new(ScalarValue::Int64(Some(0)))),
            ));
            AggregateExprBuilder::new(sum_udaf(), vec![child])
                .schema(schema)
                .alias("count")
                .with_ignore_nulls(false)
                .with_distinct(false)
                .build()
                .map_err(|e| ExecutionError::DataFusionError(e.to_string()))
        }
        AggExprStruct::Min(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
            let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
            let child = Arc::new(CastExpr::new(child, datatype, None));
            AggregateExprBuilder::new(min_udaf(), vec![child])
                .schema(schema)
                .alias("min")
                .with_ignore_nulls(false)
                .with_distinct(false)
                .build()
                .map_err(|e| ExecutionError::DataFusionError(e.to_string()))
        }
        AggExprStruct::Max(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
            let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
            let child = Arc::new(CastExpr::new(child, datatype, None));
            AggregateExprBuilder::new(max_udaf(), vec![child])
                .schema(schema)
                .alias("max")
                .with_ignore_nulls(false)
                .with_distinct(false)
                .build()
                .map_err(|e| ExecutionError::DataFusionError(e.to_string()))
        }
        AggExprStruct::Sum(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
            let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
            let builder = match datatype {
                DataType::Decimal128(_, _) => {
                    let func = AggregateUDF::new_from_impl(SumDecimal::try_new(datatype)?);
                    AggregateExprBuilder::new(Arc::new(func), vec![child])
                }
                _ => {
                    // cast to the result data type of SUM if necessary, we should not expect
                    // a cast failure since it should have already been checked at Spark side
                    let child = Arc::new(CastExpr::new(child, datatype, None));
                    AggregateExprBuilder::new(sum_udaf(), vec![child])
                }
            };
            builder
                .schema(schema)
                .alias("sum")
                .with_ignore_nulls(false)
                .with_distinct(false)
                .build()
                .map_err(|e| e.into())
        }
        AggExprStruct::Avg(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
            let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
            let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap());
            let builder = match datatype {
                DataType::Decimal128(_, _) => {
                    let func =
                        AggregateUDF::new_from_impl(AvgDecimal::new(datatype, input_datatype));
                    AggregateExprBuilder::new(Arc::new(func), vec![child])
                }
                _ => {
                    // cast to the result data type of AVG if the result data type is different
                    // from the input type, e.g. AVG(Int32). We should not expect a cast
                    // failure since it should have already been checked at Spark side.
                    let child: Arc<dyn PhysicalExpr> =
                        Arc::new(CastExpr::new(Arc::clone(&child), datatype.clone(), None));
                    let func = AggregateUDF::new_from_impl(Avg::new("avg", datatype));
                    AggregateExprBuilder::new(Arc::new(func), vec![child])
                }
            };
            builder
                .schema(schema)
                .alias("avg")
                .with_ignore_nulls(false)
                .with_distinct(false)
                .build()
                .map_err(|e| e.into())
        }
        AggExprStruct::First(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
            let func = AggregateUDF::new_from_impl(FirstValue::new());
            AggregateExprBuilder::new(Arc::new(func), vec![child])
                .schema(schema)
                .alias("first")
                .with_ignore_nulls(expr.ignore_nulls)
                .with_distinct(false)
                .build()
                .map_err(|e| e.into())
        }
        AggExprStruct::Last(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
            let func = AggregateUDF::new_from_impl(LastValue::new());
            AggregateExprBuilder::new(Arc::new(func), vec![child])
                .schema(schema)
                .alias("last")
                .with_ignore_nulls(expr.ignore_nulls)
                .with_distinct(false)
                .build()
                .map_err(|e| e.into())
        }
        AggExprStruct::BitAndAgg(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
            AggregateExprBuilder::new(bit_and_udaf(), vec![child])
                .schema(schema)
                .alias("bit_and")
                .with_ignore_nulls(false)
                .with_distinct(false)
                .build()
                .map_err(|e| e.into())
        }
        AggExprStruct::BitOrAgg(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
            AggregateExprBuilder::new(bit_or_udaf(), vec![child])
                .schema(schema)
                .alias("bit_or")
                .with_ignore_nulls(false)
                .with_distinct(false)
                .build()
                .map_err(|e| e.into())
        }
        AggExprStruct::BitXorAgg(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
            AggregateExprBuilder::new(bit_xor_udaf(), vec![child])
                .schema(schema)
                .alias("bit_xor")
                .with_ignore_nulls(false)
                .with_distinct(false)
                .build()
                .map_err(|e| e.into())
        }
        AggExprStruct::Covariance(expr) => {
            let child1 =
                self.create_expr(expr.child1.as_ref().unwrap(), Arc::clone(&schema))?;
            let child2 =
                self.create_expr(expr.child2.as_ref().unwrap(), Arc::clone(&schema))?;
            let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
            // stats_type: 0 => sample covariance, 1 => population covariance
            match expr.stats_type {
                0 => {
                    let func = AggregateUDF::new_from_impl(Covariance::new(
                        "covariance",
                        datatype,
                        StatsType::Sample,
                        expr.null_on_divide_by_zero,
                    ));
                    Self::create_aggr_func_expr(
                        "covariance",
                        schema,
                        vec![child1, child2],
                        func,
                    )
                }
                1 => {
                    let func = AggregateUDF::new_from_impl(Covariance::new(
                        "covariance_pop",
                        datatype,
                        StatsType::Population,
                        expr.null_on_divide_by_zero,
                    ));
                    Self::create_aggr_func_expr(
                        "covariance_pop",
                        schema,
                        vec![child1, child2],
                        func,
                    )
                }
                // The message must name the aggregate actually being planned.
                stats_type => Err(GeneralError(format!(
                    "Unknown StatisticsType {:?} for Covariance",
                    stats_type
                ))),
            }
        }
        AggExprStruct::Variance(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
            let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
            // stats_type: 0 => sample variance, 1 => population variance
            match expr.stats_type {
                0 => {
                    let func = AggregateUDF::new_from_impl(Variance::new(
                        "variance",
                        datatype,
                        StatsType::Sample,
                        expr.null_on_divide_by_zero,
                    ));
                    Self::create_aggr_func_expr("variance", schema, vec![child], func)
                }
                1 => {
                    let func = AggregateUDF::new_from_impl(Variance::new(
                        "variance_pop",
                        datatype,
                        StatsType::Population,
                        expr.null_on_divide_by_zero,
                    ));
                    Self::create_aggr_func_expr("variance_pop", schema, vec![child], func)
                }
                stats_type => Err(GeneralError(format!(
                    "Unknown StatisticsType {:?} for Variance",
                    stats_type
                ))),
            }
        }
        AggExprStruct::Stddev(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
            let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
            // stats_type: 0 => sample stddev, 1 => population stddev
            match expr.stats_type {
                0 => {
                    let func = AggregateUDF::new_from_impl(Stddev::new(
                        "stddev",
                        datatype,
                        StatsType::Sample,
                        expr.null_on_divide_by_zero,
                    ));
                    Self::create_aggr_func_expr("stddev", schema, vec![child], func)
                }
                1 => {
                    let func = AggregateUDF::new_from_impl(Stddev::new(
                        "stddev_pop",
                        datatype,
                        StatsType::Population,
                        expr.null_on_divide_by_zero,
                    ));
                    Self::create_aggr_func_expr("stddev_pop", schema, vec![child], func)
                }
                stats_type => Err(GeneralError(format!(
                    "Unknown StatisticsType {:?} for stddev",
                    stats_type
                ))),
            }
        }
        AggExprStruct::Correlation(expr) => {
            let child1 =
                self.create_expr(expr.child1.as_ref().unwrap(), Arc::clone(&schema))?;
            let child2 =
                self.create_expr(expr.child2.as_ref().unwrap(), Arc::clone(&schema))?;
            let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
            let func = AggregateUDF::new_from_impl(Correlation::new(
                "correlation",
                datatype,
                expr.null_on_divide_by_zero,
            ));
            Self::create_aggr_func_expr("correlation", schema, vec![child1, child2], func)
        }
        AggExprStruct::BloomFilterAgg(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
            let num_items =
                self.create_expr(expr.num_items.as_ref().unwrap(), Arc::clone(&schema))?;
            let num_bits =
                self.create_expr(expr.num_bits.as_ref().unwrap(), Arc::clone(&schema))?;
            let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
            let func = AggregateUDF::new_from_impl(BloomFilterAgg::new(
                Arc::clone(&num_items),
                Arc::clone(&num_bits),
                datatype,
            ));
            Self::create_aggr_func_expr("bloom_filter_agg", schema, vec![child], func)
        }
    }
}