// native/core/src/execution/planner.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! Converts Spark physical plan to DataFusion physical plan
use super::expressions::EvalMode;
use crate::execution::operators::CopyMode;
use crate::execution::operators::FilterExec as CometFilterExec;
use crate::{
errors::ExpressionError,
execution::{
expressions::{
bloom_filter_agg::BloomFilterAgg, bloom_filter_might_contain::BloomFilterMightContain,
subquery::Subquery,
},
operators::{CopyExec, ExecutionError, ExpandExec, ScanExec},
serde::to_arrow_datatype,
shuffle::ShuffleWriterExec,
},
};
use arrow::compute::CastOptions;
use arrow::datatypes::{DataType, Field, Schema, TimeUnit, DECIMAL128_MAX_PRECISION};
use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, bit_xor_udaf};
use datafusion::functions_aggregate::min_max::max_udaf;
use datafusion::functions_aggregate::min_max::min_udaf;
use datafusion::functions_aggregate::sum::sum_udaf;
use datafusion::physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
use datafusion::physical_plan::windows::BoundedWindowAggExec;
use datafusion::physical_plan::InputOrderMode;
use datafusion::{
arrow::{compute::SortOptions, datatypes::SchemaRef},
common::DataFusionError,
execution::FunctionRegistry,
functions_aggregate::first_last::{FirstValue, LastValue},
logical_expr::Operator as DataFusionOperator,
physical_expr::{
expressions::{
in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr,
Literal as DataFusionLiteral, NotExpr,
},
PhysicalExpr, PhysicalSortExpr, ScalarFunctionExpr,
},
physical_plan::{
aggregates::{AggregateMode as DFAggregateMode, PhysicalGroupBy},
joins::{utils::JoinFilter, HashJoinExec, PartitionMode, SortMergeJoinExec},
limit::LocalLimitExec,
projection::ProjectionExec,
sorts::sort::SortExec,
ExecutionPlan, Partitioning,
},
prelude::SessionContext,
};
use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr};
use crate::execution::operators::ExecutionError::GeneralError;
use crate::execution::shuffle::CompressionCodec;
use crate::execution::spark_plan::SparkPlan;
use crate::parquet::parquet_exec::init_datasource_exec;
use crate::parquet::parquet_support::prepare_object_store;
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::common::{
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter},
JoinType as DFJoinType, ScalarValue,
};
use datafusion::datasource::listing::PartitionedFile;
use datafusion::logical_expr::type_coercion::other::get_coerce_type_for_case_expression;
use datafusion::logical_expr::{
AggregateUDF, ReturnTypeArgs, WindowFrame, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition,
};
use datafusion::physical_expr::expressions::{Literal, StatsType};
use datafusion::physical_expr::window::WindowExpr;
use datafusion::physical_expr::LexOrdering;
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::filter::FilterExec as DataFusionFilterExec;
use datafusion_comet_proto::spark_operator::SparkFilePartition;
use datafusion_comet_proto::{
spark_expression::{
self, agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr,
Expr, ScalarFunc,
},
spark_operator::{
self, lower_window_frame_bound::LowerFrameBoundStruct, operator::OpStruct,
upper_window_frame_bound::UpperFrameBoundStruct, BuildSide,
CompressionCodec as SparkCompressionCodec, JoinType, Operator, WindowFrameType,
},
spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
};
use datafusion_comet_spark_expr::{
ArrayInsert, Avg, AvgDecimal, BitwiseNotExpr, Cast, CheckOverflow, Contains, Correlation,
Covariance, CreateNamedStruct, DateTruncExpr, EndsWith, GetArrayStructFields, GetStructField,
HourExpr, IfExpr, Like, ListExtract, MinuteExpr, NormalizeNaNAndZero, RLike, SecondExpr,
SparkCastOptions, StartsWith, Stddev, StringSpaceExpr, SubstringExpr, SumDecimal,
TimestampTruncExpr, ToJson, UnboundColumn, Variance,
};
use itertools::Itertools;
use jni::objects::GlobalRef;
use num::{BigInt, ToPrimitive};
use object_store::path::Path;
use std::cmp::max;
use std::{collections::HashMap, sync::Arc};
use url::Url;
// Type aliases to keep signatures readable (and to avoid clippy's
// `type_complexity` lint on the raw nested generics).
/// Aggregate expressions produced while planning a hash aggregate.
type PhyAggResult = Result<Vec<AggregateFunctionExpr>, ExecutionError>;
/// Named physical expressions, e.g. projection output columns as (expr, name).
type PhyExprResult = Result<Vec<(Arc<dyn PhysicalExpr>, String)>, ExecutionError>;
/// Physical expressions used to derive a shuffle partitioning.
type PartitionPhyExprResult = Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError>;
/// Pre-resolved join inputs shared by the join planning paths
/// (hash join and sort-merge join).
struct JoinParameters {
    pub left: Arc<SparkPlan>,
    pub right: Arc<SparkPlan>,
    /// Equi-join key pairs as (left expr, right expr).
    pub join_on: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>,
    /// Optional non-equi join condition.
    pub join_filter: Option<JoinFilter>,
    pub join_type: DFJoinType,
}
/// Extra options that influence how a binary expression is planned.
#[derive(Default)]
struct BinaryExprOptions {
    // When true, a Divide is planned as Spark's integral division
    // ("decimal_integral_div") instead of "decimal_div".
    pub is_integral_div: bool,
}
/// Sentinel execution context id used by the `Default`/`new` planner
/// constructors until a real id is bound via `with_exec_id`.
pub const TEST_EXEC_CONTEXT_ID: i64 = -1;
/// The query planner for converting Spark query plans to DataFusion query plans.
pub struct PhysicalPlanner {
    // The execution context id of this planner; starts as
    // `TEST_EXEC_CONTEXT_ID` until `with_exec_id` assigns a real one.
    exec_context_id: i64,
    // DataFusion session used when planning (e.g. for function registry state).
    session_ctx: Arc<SessionContext>,
}
impl Default for PhysicalPlanner {
    /// A planner backed by a fresh `SessionContext`, using the test
    /// execution context id.
    fn default() -> Self {
        Self {
            session_ctx: Arc::new(SessionContext::new()),
            exec_context_id: TEST_EXEC_CONTEXT_ID,
        }
    }
}
impl PhysicalPlanner {
pub fn new(session_ctx: Arc<SessionContext>) -> Self {
Self {
exec_context_id: TEST_EXEC_CONTEXT_ID,
session_ctx,
}
}
pub fn with_exec_id(self, exec_context_id: i64) -> Self {
Self {
exec_context_id,
session_ctx: Arc::clone(&self.session_ctx),
}
}
/// Return session context of this planner.
pub fn session_ctx(&self) -> &Arc<SessionContext> {
    &self.session_ctx
}
/// Convert a Spark `FilePartition` into DataFusion `PartitionedFile`s.
///
/// Each file's URL-encoded path is parsed and converted into an
/// `object_store` path, and its partition values are evaluated as
/// literal expressions.
fn get_partitioned_files(
    &self,
    partition: &SparkFilePartition,
) -> Result<Vec<PartitionedFile>, ExecutionError> {
    // Partition values are all literals, so they are planned against an
    // empty input schema.
    let empty_schema = Arc::new(Schema::empty());
    let mut files = Vec::with_capacity(partition.partitioned_file.len());
    for file in &partition.partitioned_file {
        assert!(file.start + file.length <= file.file_size);
        let mut partitioned_file = PartitionedFile::new_with_range(
            String::new(), // Dummy file path; the real location is set below.
            file.file_size as u64,
            file.start,
            file.start + file.length,
        );
        // Spark sends the path over as URL-encoded, parse that first.
        let url =
            Url::parse(file.file_path.as_ref()).map_err(|e| GeneralError(e.to_string()))?;
        // Convert that to a Path object to use in the PartitionedFile.
        partitioned_file.object_meta.location =
            Path::from_url_path(url.path()).map_err(|e| GeneralError(e.to_string()))?;
        // Evaluate each partition value down to a scalar.
        partitioned_file.partition_values = file
            .partition_values
            .iter()
            .map(|partition_value| {
                let literal =
                    self.create_expr(partition_value, Arc::<Schema>::clone(&empty_schema))?;
                literal
                    .as_any()
                    .downcast_ref::<DataFusionLiteral>()
                    .ok_or_else(|| {
                        GeneralError("Expected literal of partition value".to_string())
                    })
                    .map(|literal| literal.value().clone())
            })
            .collect::<Result<Vec<_>, _>>()?;
        files.push(partitioned_file);
    }
    Ok(files)
}
/// Create a DataFusion physical expression from Spark physical expression.
///
/// `Bound` column references are resolved against `input_schema`.
/// Expression kinds that are not yet supported return a `GeneralError`.
pub(crate) fn create_expr(
    &self,
    spark_expr: &Expr,
    input_schema: SchemaRef,
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
    match spark_expr.expr_struct.as_ref().unwrap() {
        // Arithmetic operators are routed through create_binary_expr so
        // the Decimal128 widening/division special cases can apply.
        ExprStruct::Add(expr) => self.create_binary_expr(
            expr.left.as_ref().unwrap(),
            expr.right.as_ref().unwrap(),
            expr.return_type.as_ref(),
            DataFusionOperator::Plus,
            input_schema,
        ),
        ExprStruct::Subtract(expr) => self.create_binary_expr(
            expr.left.as_ref().unwrap(),
            expr.right.as_ref().unwrap(),
            expr.return_type.as_ref(),
            DataFusionOperator::Minus,
            input_schema,
        ),
        ExprStruct::Multiply(expr) => self.create_binary_expr(
            expr.left.as_ref().unwrap(),
            expr.right.as_ref().unwrap(),
            expr.return_type.as_ref(),
            DataFusionOperator::Multiply,
            input_schema,
        ),
        ExprStruct::Divide(expr) => self.create_binary_expr(
            expr.left.as_ref().unwrap(),
            expr.right.as_ref().unwrap(),
            expr.return_type.as_ref(),
            DataFusionOperator::Divide,
            input_schema,
        ),
        // Spark's `div` — same DataFusion operator as Divide but flagged
        // so decimal planning picks the integral-division function.
        ExprStruct::IntegralDivide(expr) => self.create_binary_expr_with_options(
            expr.left.as_ref().unwrap(),
            expr.right.as_ref().unwrap(),
            expr.return_type.as_ref(),
            DataFusionOperator::Divide,
            input_schema,
            BinaryExprOptions {
                is_integral_div: true,
            },
        ),
        ExprStruct::Remainder(expr) => self.create_binary_expr(
            expr.left.as_ref().unwrap(),
            expr.right.as_ref().unwrap(),
            expr.return_type.as_ref(),
            DataFusionOperator::Modulo,
            input_schema,
        ),
        // Comparison operators map directly onto DataFusion BinaryExpr.
        ExprStruct::Eq(expr) => {
            let left =
                self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
            let op = DataFusionOperator::Eq;
            Ok(Arc::new(BinaryExpr::new(left, op, right)))
        }
        ExprStruct::Neq(expr) => {
            let left =
                self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
            let op = DataFusionOperator::NotEq;
            Ok(Arc::new(BinaryExpr::new(left, op, right)))
        }
        ExprStruct::Gt(expr) => {
            let left =
                self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
            let op = DataFusionOperator::Gt;
            Ok(Arc::new(BinaryExpr::new(left, op, right)))
        }
        ExprStruct::GtEq(expr) => {
            let left =
                self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
            let op = DataFusionOperator::GtEq;
            Ok(Arc::new(BinaryExpr::new(left, op, right)))
        }
        ExprStruct::Lt(expr) => {
            let left =
                self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
            let op = DataFusionOperator::Lt;
            Ok(Arc::new(BinaryExpr::new(left, op, right)))
        }
        ExprStruct::LtEq(expr) => {
            let left =
                self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
            let op = DataFusionOperator::LtEq;
            Ok(Arc::new(BinaryExpr::new(left, op, right)))
        }
        // A column reference bound to an ordinal in `input_schema`.
        ExprStruct::Bound(bound) => {
            let idx = bound.index as usize;
            if idx >= input_schema.fields().len() {
                return Err(GeneralError(format!(
                    "Column index {} is out of bound. Schema: {}",
                    idx, input_schema
                )));
            }
            let field = input_schema.field(idx);
            Ok(Arc::new(Column::new(field.name().as_str(), idx)))
        }
        // A column known only by name/type (not resolvable against the
        // input schema at planning time).
        ExprStruct::Unbound(unbound) => {
            let data_type = to_arrow_datatype(unbound.datatype.as_ref().unwrap());
            Ok(Arc::new(UnboundColumn::new(
                unbound.name.as_str(),
                data_type,
            )))
        }
        ExprStruct::IsNotNull(is_notnull) => {
            let child = self.create_expr(is_notnull.child.as_ref().unwrap(), input_schema)?;
            Ok(Arc::new(IsNotNullExpr::new(child)))
        }
        ExprStruct::IsNull(is_null) => {
            let child = self.create_expr(is_null.child.as_ref().unwrap(), input_schema)?;
            Ok(Arc::new(IsNullExpr::new(child)))
        }
        ExprStruct::And(and) => {
            let left =
                self.create_expr(and.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let right = self.create_expr(and.right.as_ref().unwrap(), input_schema)?;
            let op = DataFusionOperator::And;
            Ok(Arc::new(BinaryExpr::new(left, op, right)))
        }
        ExprStruct::Or(or) => {
            let left =
                self.create_expr(or.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let right = self.create_expr(or.right.as_ref().unwrap(), input_schema)?;
            let op = DataFusionOperator::Or;
            Ok(Arc::new(BinaryExpr::new(left, op, right)))
        }
        // Literals: null literals dispatch on the declared data type,
        // non-null literals on the protobuf value variant.
        ExprStruct::Literal(literal) => {
            let data_type = to_arrow_datatype(literal.datatype.as_ref().unwrap());
            let scalar_value = if literal.is_null {
                match data_type {
                    DataType::Boolean => ScalarValue::Boolean(None),
                    DataType::Int8 => ScalarValue::Int8(None),
                    DataType::Int16 => ScalarValue::Int16(None),
                    DataType::Int32 => ScalarValue::Int32(None),
                    DataType::Int64 => ScalarValue::Int64(None),
                    DataType::Float32 => ScalarValue::Float32(None),
                    DataType::Float64 => ScalarValue::Float64(None),
                    DataType::Utf8 => ScalarValue::Utf8(None),
                    DataType::Date32 => ScalarValue::Date32(None),
                    DataType::Timestamp(TimeUnit::Microsecond, timezone) => {
                        ScalarValue::TimestampMicrosecond(None, timezone)
                    }
                    DataType::Binary => ScalarValue::Binary(None),
                    DataType::Decimal128(p, s) => ScalarValue::Decimal128(None, p, s),
                    DataType::Struct(fields) => ScalarStructBuilder::new_null(fields),
                    DataType::Null => ScalarValue::Null,
                    dt => {
                        return Err(GeneralError(format!("{:?} is not supported in Comet", dt)))
                    }
                }
            } else {
                match literal.value.as_ref().unwrap() {
                    Value::BoolVal(value) => ScalarValue::Boolean(Some(*value)),
                    Value::ByteVal(value) => ScalarValue::Int8(Some(*value as i8)),
                    Value::ShortVal(value) => ScalarValue::Int16(Some(*value as i16)),
                    // IntVal carries both Int32 and Date32 payloads.
                    Value::IntVal(value) => match data_type {
                        DataType::Int32 => ScalarValue::Int32(Some(*value)),
                        DataType::Date32 => ScalarValue::Date32(Some(*value)),
                        dt => {
                            return Err(GeneralError(format!(
                                "Expected either 'Int32' or 'Date32' for IntVal, but found {:?}",
                                dt
                            )))
                        }
                    },
                    // LongVal carries both Int64 and microsecond timestamps.
                    Value::LongVal(value) => match data_type {
                        DataType::Int64 => ScalarValue::Int64(Some(*value)),
                        DataType::Timestamp(TimeUnit::Microsecond, None) => {
                            ScalarValue::TimestampMicrosecond(Some(*value), None)
                        }
                        DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => {
                            ScalarValue::TimestampMicrosecond(Some(*value), Some(tz))
                        }
                        dt => {
                            return Err(GeneralError(format!(
                                "Expected either 'Int64' or 'Timestamp' for LongVal, but found {:?}",
                                dt
                            )))
                        }
                    },
                    Value::FloatVal(value) => ScalarValue::Float32(Some(*value)),
                    Value::DoubleVal(value) => ScalarValue::Float64(Some(*value)),
                    Value::StringVal(value) => ScalarValue::Utf8(Some(value.clone())),
                    Value::BytesVal(value) => ScalarValue::Binary(Some(value.clone())),
                    // Decimals arrive as big-endian two's-complement bytes.
                    Value::DecimalVal(value) => {
                        let big_integer = BigInt::from_signed_bytes_be(value);
                        let integer = big_integer.to_i128().ok_or_else(|| {
                            GeneralError(format!(
                                "Cannot parse {:?} as i128 for Decimal literal",
                                big_integer
                            ))
                        })?;
                        match data_type {
                            DataType::Decimal128(p, s) => {
                                ScalarValue::Decimal128(Some(integer), p, s)
                            }
                            dt => {
                                return Err(GeneralError(format!(
                                    "Decimal literal's data type should be Decimal128 but got {:?}",
                                    dt
                                )))
                            }
                        }
                    }
                }
            };
            Ok(Arc::new(DataFusionLiteral::new(scalar_value)))
        }
        ExprStruct::Cast(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
            let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
            let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
            Ok(Arc::new(Cast::new(
                child,
                datatype,
                SparkCastOptions::new(eval_mode, &expr.timezone, expr.allow_incompat),
            )))
        }
        ExprStruct::Hour(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
            let timezone = expr.timezone.clone();
            Ok(Arc::new(HourExpr::new(child, timezone)))
        }
        ExprStruct::Minute(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
            let timezone = expr.timezone.clone();
            Ok(Arc::new(MinuteExpr::new(child, timezone)))
        }
        ExprStruct::Second(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
            let timezone = expr.timezone.clone();
            Ok(Arc::new(SecondExpr::new(child, timezone)))
        }
        ExprStruct::TruncDate(expr) => {
            let child =
                self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let format = self.create_expr(expr.format.as_ref().unwrap(), input_schema)?;
            Ok(Arc::new(DateTruncExpr::new(child, format)))
        }
        ExprStruct::TruncTimestamp(expr) => {
            let child =
                self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let format = self.create_expr(expr.format.as_ref().unwrap(), input_schema)?;
            let timezone = expr.timezone.clone();
            Ok(Arc::new(TimestampTruncExpr::new(child, format, timezone)))
        }
        ExprStruct::Substring(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
            // Spark Substring's start is 1-based when start > 0
            let start = expr.start - i32::from(expr.start > 0);
            // substring negative len is treated as 0 in Spark
            let len = max(expr.len, 0);
            Ok(Arc::new(SubstringExpr::new(
                child,
                start as i64,
                len as u64,
            )))
        }
        ExprStruct::StringSpace(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
            Ok(Arc::new(StringSpaceExpr::new(child)))
        }
        ExprStruct::Contains(expr) => {
            let left =
                self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
            Ok(Arc::new(Contains::new(left, right)))
        }
        ExprStruct::StartsWith(expr) => {
            let left =
                self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
            Ok(Arc::new(StartsWith::new(left, right)))
        }
        ExprStruct::EndsWith(expr) => {
            let left =
                self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
            Ok(Arc::new(EndsWith::new(left, right)))
        }
        ExprStruct::Like(expr) => {
            let left =
                self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
            Ok(Arc::new(Like::new(left, right)))
        }
        // RLike only supports a literal (scalar) pattern; the downcast
        // unwrap relies on the Spark side guaranteeing a literal here.
        ExprStruct::Rlike(expr) => {
            let left =
                self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
            match right.as_any().downcast_ref::<Literal>().unwrap().value() {
                ScalarValue::Utf8(Some(pattern)) => {
                    Ok(Arc::new(RLike::try_new(left, pattern)?))
                }
                _ => Err(GeneralError(
                    "RLike only supports scalar patterns".to_string(),
                )),
            }
        }
        ExprStruct::CheckOverflow(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
            let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
            let fail_on_error = expr.fail_on_error;
            Ok(Arc::new(CheckOverflow::new(
                child,
                data_type,
                fail_on_error,
            )))
        }
        ExprStruct::ScalarFunc(expr) => self.create_scalar_function_expr(expr, input_schema),
        // Spark's null-safe equality (`<=>`) maps to SQL's
        // IS NOT DISTINCT FROM; its negation to IS DISTINCT FROM.
        ExprStruct::EqNullSafe(expr) => {
            let left =
                self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
            let op = DataFusionOperator::IsNotDistinctFrom;
            Ok(Arc::new(BinaryExpr::new(left, op, right)))
        }
        ExprStruct::NeqNullSafe(expr) => {
            let left =
                self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
            let op = DataFusionOperator::IsDistinctFrom;
            Ok(Arc::new(BinaryExpr::new(left, op, right)))
        }
        ExprStruct::BitwiseAnd(expr) => {
            let left =
                self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
            let op = DataFusionOperator::BitwiseAnd;
            Ok(Arc::new(BinaryExpr::new(left, op, right)))
        }
        ExprStruct::BitwiseNot(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
            Ok(Arc::new(BitwiseNotExpr::new(child)))
        }
        ExprStruct::BitwiseOr(expr) => {
            let left =
                self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
            let op = DataFusionOperator::BitwiseOr;
            Ok(Arc::new(BinaryExpr::new(left, op, right)))
        }
        ExprStruct::BitwiseXor(expr) => {
            let left =
                self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
            let op = DataFusionOperator::BitwiseXor;
            Ok(Arc::new(BinaryExpr::new(left, op, right)))
        }
        ExprStruct::BitwiseShiftRight(expr) => {
            let left =
                self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
            let op = DataFusionOperator::BitwiseShiftRight;
            Ok(Arc::new(BinaryExpr::new(left, op, right)))
        }
        ExprStruct::BitwiseShiftLeft(expr) => {
            let left =
                self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
            let op = DataFusionOperator::BitwiseShiftLeft;
            Ok(Arc::new(BinaryExpr::new(left, op, right)))
        }
        // Abs support is disabled pending
        // https://github.com/apache/datafusion-comet/issues/666
        // ExprStruct::Abs(expr) => {
        //     let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
        //     let return_type = child.data_type(&input_schema)?;
        //     let args = vec![child];
        //     let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
        //     let comet_abs = Arc::new(ScalarUDF::new_from_impl(Abs::new(
        //         eval_mode,
        //         return_type.to_string(),
        //     )?));
        //     let expr = ScalarFunctionExpr::new("abs", comet_abs, args, return_type);
        //     Ok(Arc::new(expr))
        // }
        // CASE WHEN: `when` and `then` lists are zipped pairwise; any
        // planning error in either side short-circuits the fold.
        ExprStruct::CaseWhen(case_when) => {
            let when_then_pairs = case_when
                .when
                .iter()
                .map(|x| self.create_expr(x, Arc::clone(&input_schema)))
                .zip(
                    case_when
                        .then
                        .iter()
                        .map(|then| self.create_expr(then, Arc::clone(&input_schema))),
                )
                .try_fold(Vec::new(), |mut acc, (a, b)| {
                    acc.push((a?, b?));
                    Ok::<Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>, ExecutionError>(
                        acc,
                    )
                })?;
            let else_phy_expr = match &case_when.else_expr {
                None => None,
                Some(_) => Some(self.create_expr(
                    case_when.else_expr.as_ref().unwrap(),
                    Arc::clone(&input_schema),
                )?),
            };
            create_case_expr(when_then_pairs, else_phy_expr, &input_schema)
        }
        ExprStruct::In(expr) => {
            let value =
                self.create_expr(expr.in_value.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let list = expr
                .lists
                .iter()
                .map(|x| self.create_expr(x, Arc::clone(&input_schema)))
                .collect::<Result<Vec<_>, _>>()?;
            in_list(value, list, &expr.negated, input_schema.as_ref()).map_err(|e| e.into())
        }
        ExprStruct::If(expr) => {
            let if_expr =
                self.create_expr(expr.if_expr.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let true_expr =
                self.create_expr(expr.true_expr.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let false_expr =
                self.create_expr(expr.false_expr.as_ref().unwrap(), input_schema)?;
            Ok(Arc::new(IfExpr::new(if_expr, true_expr, false_expr)))
        }
        ExprStruct::Not(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
            Ok(Arc::new(NotExpr::new(child)))
        }
        ExprStruct::UnaryMinus(expr) => {
            let child: Arc<dyn PhysicalExpr> =
                self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let result = create_negate_expr(child, expr.fail_on_error);
            result.map_err(|e| GeneralError(e.to_string()))
        }
        ExprStruct::NormalizeNanAndZero(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
            let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
            Ok(Arc::new(NormalizeNaNAndZero::new(data_type, child)))
        }
        // Subqueries are resolved at runtime via the planner's execution
        // context id.
        ExprStruct::Subquery(expr) => {
            let id = expr.id;
            let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
            Ok(Arc::new(Subquery::new(self.exec_context_id, id, data_type)))
        }
        ExprStruct::BloomFilterMightContain(expr) => {
            let bloom_filter_expr = self.create_expr(
                expr.bloom_filter.as_ref().unwrap(),
                Arc::clone(&input_schema),
            )?;
            let value_expr = self.create_expr(expr.value.as_ref().unwrap(), input_schema)?;
            Ok(Arc::new(BloomFilterMightContain::try_new(
                bloom_filter_expr,
                value_expr,
            )?))
        }
        ExprStruct::CreateNamedStruct(expr) => {
            let values = expr
                .values
                .iter()
                .map(|expr| self.create_expr(expr, Arc::clone(&input_schema)))
                .collect::<Result<Vec<_>, _>>()?;
            let names = expr.names.clone();
            Ok(Arc::new(CreateNamedStruct::new(values, names)))
        }
        ExprStruct::GetStructField(expr) => {
            let child =
                self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
            Ok(Arc::new(GetStructField::new(child, expr.ordinal as usize)))
        }
        ExprStruct::ToJson(expr) => {
            let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
            Ok(Arc::new(ToJson::new(child, &expr.timezone)))
        }
        ExprStruct::ListExtract(expr) => {
            let child =
                self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let ordinal =
                self.create_expr(expr.ordinal.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let default_value = expr
                .default_value
                .as_ref()
                .map(|e| self.create_expr(e, Arc::clone(&input_schema)))
                .transpose()?;
            Ok(Arc::new(ListExtract::new(
                child,
                ordinal,
                default_value,
                expr.one_based,
                expr.fail_on_error,
            )))
        }
        ExprStruct::GetArrayStructFields(expr) => {
            let child =
                self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
            Ok(Arc::new(GetArrayStructFields::new(
                child,
                expr.ordinal as usize,
            )))
        }
        ExprStruct::ArrayInsert(expr) => {
            let src_array_expr = self.create_expr(
                expr.src_array_expr.as_ref().unwrap(),
                Arc::clone(&input_schema),
            )?;
            let pos_expr =
                self.create_expr(expr.pos_expr.as_ref().unwrap(), Arc::clone(&input_schema))?;
            let item_expr =
                self.create_expr(expr.item_expr.as_ref().unwrap(), Arc::clone(&input_schema))?;
            Ok(Arc::new(ArrayInsert::new(
                src_array_expr,
                pos_expr,
                item_expr,
                expr.legacy_negative_index,
            )))
        }
        expr => Err(GeneralError(format!("Not implemented: {:?}", expr))),
    }
}
/// Create a DataFusion physical sort expression from a Spark `SortOrder`
/// expression; any other expression kind is rejected.
fn create_sort_expr<'a>(
    &'a self,
    spark_expr: &'a Expr,
    input_schema: SchemaRef,
) -> Result<PhysicalSortExpr, ExecutionError> {
    let expr_struct = spark_expr.expr_struct.as_ref().unwrap();
    let ExprStruct::SortOrder(expr) = expr_struct else {
        return Err(GeneralError(format!("{:?} isn't a SortOrder", expr_struct)));
    };
    let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
    // Spark encodes direction 1 as descending and null_ordering 0 as
    // nulls-first.
    let options = SortOptions {
        descending: expr.direction == 1,
        nulls_first: expr.null_ordering == 0,
    };
    Ok(PhysicalSortExpr {
        expr: child,
        options,
    })
}
/// Plan a binary expression with the default options; thin wrapper over
/// `create_binary_expr_with_options`.
fn create_binary_expr(
    &self,
    left: &Expr,
    right: &Expr,
    return_type: Option<&spark_expression::DataType>,
    op: DataFusionOperator,
    input_schema: SchemaRef,
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
    let options = BinaryExprOptions::default();
    self.create_binary_expr_with_options(left, right, return_type, op, input_schema, options)
}
/// Plan a binary expression, applying Decimal128 special handling.
///
/// Two decimal cases deviate from a plain DataFusion `BinaryExpr`:
/// - Plus/Minus/Multiply/Modulo whose result precision would reach or
///   exceed `DECIMAL128_MAX_PRECISION` are computed in Decimal256 and the
///   result cast back to the declared return type.
/// - Divide on decimals always goes through Comet's `decimal_div` /
///   `decimal_integral_div` scalar functions (chosen by
///   `options.is_integral_div`).
fn create_binary_expr_with_options(
    &self,
    left: &Expr,
    right: &Expr,
    return_type: Option<&spark_expression::DataType>,
    op: DataFusionOperator,
    input_schema: SchemaRef,
    options: BinaryExprOptions,
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
    let left = self.create_expr(left, Arc::clone(&input_schema))?;
    let right = self.create_expr(right, Arc::clone(&input_schema))?;
    match (
        &op,
        left.data_type(&input_schema),
        right.data_type(&input_schema),
    ) {
        (
            DataFusionOperator::Plus
            | DataFusionOperator::Minus
            | DataFusionOperator::Multiply
            | DataFusionOperator::Modulo,
            Ok(DataType::Decimal128(p1, s1)),
            Ok(DataType::Decimal128(p2, s2)),
        ) if ((op == DataFusionOperator::Plus || op == DataFusionOperator::Minus)
            && max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8)
                >= DECIMAL128_MAX_PRECISION)
            || (op == DataFusionOperator::Multiply && p1 + p2 >= DECIMAL128_MAX_PRECISION)
            || (op == DataFusionOperator::Modulo
                && max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8)
                    > DECIMAL128_MAX_PRECISION) =>
        {
            let data_type = return_type.map(to_arrow_datatype).unwrap();
            // For some Decimal128 operations, we need wider internal digits.
            // Cast left and right to Decimal256 and cast the result back to Decimal128
            let left = Arc::new(Cast::new(
                left,
                DataType::Decimal256(p1, s1),
                SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
            ));
            let right = Arc::new(Cast::new(
                right,
                DataType::Decimal256(p2, s2),
                SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
            ));
            let child = Arc::new(BinaryExpr::new(left, op, right));
            Ok(Arc::new(Cast::new(
                child,
                data_type,
                SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
            )))
        }
        (
            DataFusionOperator::Divide,
            Ok(DataType::Decimal128(_p1, _s1)),
            Ok(DataType::Decimal128(_p2, _s2)),
        ) => {
            let data_type = return_type.map(to_arrow_datatype).unwrap();
            let func_name = if options.is_integral_div {
                // Decimal256 division in Arrow may overflow, so we still need this variant of decimal_div.
                // Otherwise, we may be able to reuse the previous case-match instead of here,
                // see more: https://github.com/apache/datafusion-comet/pull/1428#discussion_r1972648463
                "decimal_integral_div"
            } else {
                "decimal_div"
            };
            let fun_expr = create_comet_physical_fun(
                func_name,
                data_type.clone(),
                &self.session_ctx.state(),
            )?;
            Ok(Arc::new(ScalarFunctionExpr::new(
                func_name,
                fun_expr,
                vec![left, right],
                data_type,
            )))
        }
        // All other type/operator combinations map directly to a
        // DataFusion BinaryExpr.
        _ => Ok(Arc::new(BinaryExpr::new(left, op, right))),
    }
}
/// Create a DataFusion physical plan from Spark physical plan. There is a level of
/// abstraction where a tree of SparkPlan nodes is returned. There is a 1:1 mapping from a
/// protobuf Operator (that represents a Spark operator) to a native SparkPlan struct. We
/// need this 1:1 mapping so that we can report metrics back to Spark. The native execution
/// plan that is generated for each Operator is sometimes a single ExecutionPlan, but in some
/// cases we generate a tree of ExecutionPlans and we need to collect metrics for all of these
/// plans so we store references to them in the SparkPlan struct.
///
/// `inputs` is a vector of input source IDs. It is used to create `ScanExec`s. Each `ScanExec`
/// will be assigned a unique ID from `inputs` and the ID will be used to identify the input
/// source at JNI API.
///
    /// Converts a serialized Spark operator tree into a DataFusion physical plan,
    /// returning the native plan together with every `ScanExec` created underneath it.
    ///
    /// Note that `ScanExec` will pull initial input batch during initialization. It is because we
    /// need to know the exact schema (not only data type but also dictionary-encoding) at
    /// `ScanExec`s. It is because some DataFusion operators, e.g., `ProjectionExec`, gets child
    /// operator schema during initialization and uses it later for `RecordBatch`. We may be
    /// able to get rid of it once `RecordBatch` relaxes schema check.
    ///
    /// Note that we return created `Scan`s which will be kept at JNI API. JNI calls will use it to
    /// feed in new input batch from Spark JVM side.
    pub(crate) fn create_plan<'a>(
        &'a self,
        spark_plan: &'a Operator,
        inputs: &mut Vec<Arc<GlobalRef>>,
        partition_count: usize,
    ) -> Result<(Vec<ScanExec>, Arc<SparkPlan>), ExecutionError> {
        let children = &spark_plan.children;
        match spark_plan.op_struct.as_ref().unwrap() {
            OpStruct::Projection(project) => {
                assert!(children.len() == 1);
                let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?;
                // Projected expressions get synthetic output names `col_0`, `col_1`, ...
                let exprs: PhyExprResult = project
                    .project_list
                    .iter()
                    .enumerate()
                    .map(|(idx, expr)| {
                        self.create_expr(expr, child.schema())
                            .map(|r| (r, format!("col_{}", idx)))
                    })
                    .collect();
                let projection = Arc::new(ProjectionExec::try_new(
                    exprs?,
                    Arc::clone(&child.native_plan),
                )?);
                Ok((
                    scans,
                    Arc::new(SparkPlan::new(spark_plan.plan_id, projection, vec![child])),
                ))
            }
            OpStruct::Filter(filter) => {
                assert!(children.len() == 1);
                let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?;
                let predicate =
                    self.create_expr(filter.predicate.as_ref().unwrap(), child.schema())?;
                // The Spark side decides per-query whether to use DataFusion's FilterExec
                // or Comet's own implementation.
                let filter: Arc<dyn ExecutionPlan> = if filter.use_datafusion_filter {
                    Arc::new(DataFusionFilterExec::try_new(
                        predicate,
                        Arc::clone(&child.native_plan),
                    )?)
                } else {
                    Arc::new(CometFilterExec::try_new(
                        predicate,
                        Arc::clone(&child.native_plan),
                    )?)
                };
                Ok((
                    scans,
                    Arc::new(SparkPlan::new(spark_plan.plan_id, filter, vec![child])),
                ))
            }
            OpStruct::HashAgg(agg) => {
                assert!(children.len() == 1);
                let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?;
                // Grouping expressions get synthetic names `col_0`, `col_1`, ...
                let group_exprs: PhyExprResult = agg
                    .grouping_exprs
                    .iter()
                    .enumerate()
                    .map(|(idx, expr)| {
                        self.create_expr(expr, child.schema())
                            .map(|r| (r, format!("col_{}", idx)))
                    })
                    .collect();
                let group_by = PhysicalGroupBy::new_single(group_exprs?);
                let schema = child.schema();
                // mode 0 => partial aggregation; anything else => final aggregation.
                let mode = if agg.mode == 0 {
                    DFAggregateMode::Partial
                } else {
                    DFAggregateMode::Final
                };
                let agg_exprs: PhyAggResult = agg
                    .agg_exprs
                    .iter()
                    .map(|expr| self.create_agg_expr(expr, Arc::clone(&schema)))
                    .collect();
                let num_agg = agg.agg_exprs.len();
                let aggr_expr = agg_exprs?.into_iter().map(Arc::new).collect();
                let aggregate: Arc<dyn ExecutionPlan> = Arc::new(
                    datafusion::physical_plan::aggregates::AggregateExec::try_new(
                        mode,
                        group_by,
                        aggr_expr,
                        vec![None; num_agg], // no filter expressions
                        Arc::clone(&child.native_plan),
                        Arc::clone(&schema),
                    )?,
                );
                let result_exprs: PhyExprResult = agg
                    .result_exprs
                    .iter()
                    .enumerate()
                    .map(|(idx, expr)| {
                        self.create_expr(expr, aggregate.schema())
                            .map(|r| (r, format!("col_{}", idx)))
                    })
                    .collect();
                if agg.result_exprs.is_empty() {
                    Ok((
                        scans,
                        Arc::new(SparkPlan::new(spark_plan.plan_id, aggregate, vec![child])),
                    ))
                } else {
                    // For final aggregation, DF's hash aggregate exec doesn't support Spark's
                    // aggregate result expressions like `COUNT(col) + 1`, but instead relying
                    // on additional `ProjectionExec` to handle the case. Therefore, here we'll
                    // add a projection node on top of the aggregate node.
                    //
                    // Note that `result_exprs` should only be set for final aggregation on the
                    // Spark side.
                    let projection = Arc::new(ProjectionExec::try_new(
                        result_exprs?,
                        Arc::clone(&aggregate),
                    )?);
                    Ok((
                        scans,
                        Arc::new(SparkPlan::new_with_additional(
                            spark_plan.plan_id,
                            projection,
                            vec![child],
                            vec![aggregate],
                        )),
                    ))
                }
            }
            OpStruct::Limit(limit) => {
                assert!(children.len() == 1);
                let num = limit.limit;
                let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?;
                // Spark's limit maps to a per-partition local limit here.
                let limit = Arc::new(LocalLimitExec::new(
                    Arc::clone(&child.native_plan),
                    num as usize,
                ));
                Ok((
                    scans,
                    Arc::new(SparkPlan::new(spark_plan.plan_id, limit, vec![child])),
                ))
            }
            OpStruct::Sort(sort) => {
                assert!(children.len() == 1);
                let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?;
                let exprs: Result<Vec<PhysicalSortExpr>, ExecutionError> = sort
                    .sort_orders
                    .iter()
                    .map(|expr| self.create_sort_expr(expr, child.schema()))
                    .collect();
                let fetch = sort.fetch.map(|num| num as usize);
                // SortExec caches batches so we need to make a copy of incoming batches. Also,
                // SortExec fails in some cases if we do not unpack dictionary-encoded arrays, and
                // it would be more efficient if we could avoid that.
                // https://github.com/apache/datafusion-comet/issues/963
                let child_copied = Self::wrap_in_copy_exec(Arc::clone(&child.native_plan));
                let sort = Arc::new(
                    SortExec::new(LexOrdering::new(exprs?), Arc::clone(&child_copied))
                        .with_fetch(fetch),
                );
                Ok((
                    scans,
                    Arc::new(SparkPlan::new(
                        spark_plan.plan_id,
                        sort,
                        vec![Arc::clone(&child)],
                    )),
                ))
            }
            OpStruct::NativeScan(scan) => {
                let data_schema = convert_spark_types_to_arrow_schema(scan.data_schema.as_slice());
                let required_schema: SchemaRef =
                    convert_spark_types_to_arrow_schema(scan.required_schema.as_slice());
                let partition_schema: SchemaRef =
                    convert_spark_types_to_arrow_schema(scan.partition_schema.as_slice());
                let projection_vector: Vec<usize> = scan
                    .projection_vector
                    .iter()
                    .map(|offset| *offset as usize)
                    .collect();
                // Convert the Spark expressions to Physical expressions
                let data_filters: Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError> = scan
                    .data_filters
                    .iter()
                    .map(|expr| self.create_expr(expr, Arc::clone(&required_schema)))
                    .collect();
                // Get one file from the list of files
                // (used only to derive the object-store URL shared by all files)
                let one_file = scan
                    .file_partitions
                    .first()
                    .and_then(|f| f.partitioned_file.first())
                    .map(|f| f.file_path.clone())
                    .ok_or(GeneralError("Failed to locate file".to_string()))?;
                let (object_store_url, _) =
                    prepare_object_store(self.session_ctx.runtime_env(), one_file)?;
                // Generate file groups
                let mut file_groups: Vec<Vec<PartitionedFile>> =
                    Vec::with_capacity(partition_count);
                scan.file_partitions.iter().try_for_each(|partition| {
                    let files = self.get_partitioned_files(partition)?;
                    file_groups.push(files);
                    Ok::<(), ExecutionError>(())
                })?;
                // TODO: I think we can remove partition_count in the future, but leave for testing.
                assert_eq!(file_groups.len(), partition_count);
                let partition_fields: Vec<Field> = partition_schema
                    .fields()
                    .iter()
                    .map(|field| {
                        Field::new(field.name(), field.data_type().clone(), field.is_nullable())
                    })
                    .collect_vec();
                let scan = init_datasource_exec(
                    required_schema,
                    Some(data_schema),
                    Some(partition_schema),
                    Some(partition_fields),
                    object_store_url,
                    file_groups,
                    Some(projection_vector),
                    Some(data_filters?),
                    scan.session_timezone.as_str(),
                )?;
                // A native scan reads files directly, so it contributes no JVM-fed `ScanExec`s.
                Ok((
                    vec![],
                    Arc::new(SparkPlan::new(spark_plan.plan_id, scan, vec![])),
                ))
            }
            OpStruct::Scan(scan) => {
                let data_types = scan.fields.iter().map(to_arrow_datatype).collect_vec();
                // If it is not test execution context for unit test, we should have at least one
                // input source
                if self.exec_context_id != TEST_EXEC_CONTEXT_ID && inputs.is_empty() {
                    return Err(GeneralError("No input for scan".to_string()));
                }
                // Consumes the first input source for the scan
                let input_source =
                    if self.exec_context_id == TEST_EXEC_CONTEXT_ID && inputs.is_empty() {
                        // For unit test, we will set input batch to scan directly by `set_input_batch`.
                        None
                    } else {
                        Some(inputs.remove(0))
                    };
                // The `ScanExec` operator will take actual arrays from Spark during execution
                let scan =
                    ScanExec::new(self.exec_context_id, input_source, &scan.source, data_types)?;
                Ok((
                    vec![scan.clone()],
                    Arc::new(SparkPlan::new(spark_plan.plan_id, Arc::new(scan), vec![])),
                ))
            }
            OpStruct::ShuffleWriter(writer) => {
                assert!(children.len() == 1);
                let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?;
                let partitioning = self
                    .create_partitioning(writer.partitioning.as_ref().unwrap(), child.schema())?;
                // Map the Spark compression codec onto the native shuffle codec;
                // only Zstd carries a configurable compression level.
                let codec = match writer.codec.try_into() {
                    Ok(SparkCompressionCodec::None) => Ok(CompressionCodec::None),
                    Ok(SparkCompressionCodec::Snappy) => Ok(CompressionCodec::Snappy),
                    Ok(SparkCompressionCodec::Zstd) => {
                        Ok(CompressionCodec::Zstd(writer.compression_level))
                    }
                    Ok(SparkCompressionCodec::Lz4) => Ok(CompressionCodec::Lz4Frame),
                    _ => Err(GeneralError(format!(
                        "Unsupported shuffle compression codec: {:?}",
                        writer.codec
                    ))),
                }?;
                let shuffle_writer = Arc::new(ShuffleWriterExec::try_new(
                    Self::wrap_in_copy_exec(Arc::clone(&child.native_plan)),
                    partitioning,
                    codec,
                    writer.output_data_file.clone(),
                    writer.output_index_file.clone(),
                )?);
                Ok((
                    scans,
                    Arc::new(SparkPlan::new(
                        spark_plan.plan_id,
                        shuffle_writer,
                        vec![Arc::clone(&child)],
                    )),
                ))
            }
            OpStruct::Expand(expand) => {
                assert!(children.len() == 1);
                let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?;
                // The flat project_list is chunked into projections of
                // `num_expr_per_project` expressions each.
                let mut projections = vec![];
                let mut projection = vec![];
                expand.project_list.iter().try_for_each(|expr| {
                    let expr = self.create_expr(expr, child.schema())?;
                    projection.push(expr);
                    if projection.len() == expand.num_expr_per_project as usize {
                        projections.push(projection.clone());
                        projection = vec![];
                    }
                    Ok::<(), ExecutionError>(())
                })?;
                assert!(
                    !projections.is_empty(),
                    "Expand should have at least one projection"
                );
                // Output schema is derived from the first projection; all projections
                // are expected to produce the same column types.
                let datatypes = projections[0]
                    .iter()
                    .map(|expr| expr.data_type(&child.schema()))
                    .collect::<Result<Vec<DataType>, _>>()?;
                let fields: Vec<Field> = datatypes
                    .iter()
                    .enumerate()
                    .map(|(idx, dt)| Field::new(format!("col_{}", idx), dt.clone(), true))
                    .collect();
                let schema = Arc::new(Schema::new(fields));
                // `Expand` operator keeps the input batch and expands it to multiple output
                // batches. However, `ScanExec` will reuse input arrays for the next
                // input batch. Therefore, we need to copy the input batch to avoid
                // the data corruption. Note that we only need to copy the input batch
                // if the child operator is `ScanExec`, because other operators after `ScanExec`
                // will create new arrays for the output batch.
                let input = if can_reuse_input_batch(&child.native_plan) {
                    Arc::new(CopyExec::new(
                        Arc::clone(&child.native_plan),
                        CopyMode::UnpackOrDeepCopy,
                    ))
                } else {
                    Arc::clone(&child.native_plan)
                };
                let expand = Arc::new(ExpandExec::new(projections, input, schema));
                Ok((
                    scans,
                    Arc::new(SparkPlan::new(spark_plan.plan_id, expand, vec![child])),
                ))
            }
            OpStruct::SortMergeJoin(join) => {
                let (join_params, scans) = self.parse_join_parameters(
                    inputs,
                    children,
                    &join.left_join_keys,
                    &join.right_join_keys,
                    join.join_type,
                    &join.condition,
                    partition_count,
                )?;
                // Sort options for the join keys are derived from the left side's schema.
                // NOTE(review): `create_sort_expr` failure here panics via `unwrap`
                // instead of propagating the error — consider returning it.
                let sort_options = join
                    .sort_options
                    .iter()
                    .map(|sort_option| {
                        let sort_expr = self
                            .create_sort_expr(sort_option, join_params.left.schema())
                            .unwrap();
                        SortOptions {
                            descending: sort_expr.options.descending,
                            nulls_first: sort_expr.options.nulls_first,
                        }
                    })
                    .collect();
                let join = Arc::new(SortMergeJoinExec::try_new(
                    Arc::clone(&join_params.left.native_plan),
                    Arc::clone(&join_params.right.native_plan),
                    join_params.join_on,
                    join_params.join_filter,
                    join_params.join_type,
                    sort_options,
                    // null doesn't equal to null in Spark join key. If the join key is
                    // `EqualNullSafe`, Spark will rewrite it during planning.
                    false,
                )?);
                if join.filter.is_some() {
                    // SMJ with join filter produces lots of tiny batches
                    let coalesce_batches: Arc<dyn ExecutionPlan> =
                        Arc::new(CoalesceBatchesExec::new(
                            Arc::<SortMergeJoinExec>::clone(&join),
                            self.session_ctx
                                .state()
                                .config_options()
                                .execution
                                .batch_size,
                        ));
                    Ok((
                        scans,
                        Arc::new(SparkPlan::new_with_additional(
                            spark_plan.plan_id,
                            coalesce_batches,
                            vec![
                                Arc::clone(&join_params.left),
                                Arc::clone(&join_params.right),
                            ],
                            vec![join],
                        )),
                    ))
                } else {
                    Ok((
                        scans,
                        Arc::new(SparkPlan::new(
                            spark_plan.plan_id,
                            join,
                            vec![
                                Arc::clone(&join_params.left),
                                Arc::clone(&join_params.right),
                            ],
                        )),
                    ))
                }
            }
            OpStruct::HashJoin(join) => {
                let (join_params, scans) = self.parse_join_parameters(
                    inputs,
                    children,
                    &join.left_join_keys,
                    &join.right_join_keys,
                    join.join_type,
                    &join.condition,
                    partition_count,
                )?;
                // HashJoinExec may cache the input batch internally. We need
                // to copy the input batch to avoid the data corruption from reusing the input
                // batch. We also need to unpack dictionary arrays, because the join operators
                // do not support them.
                let left = Self::wrap_in_copy_exec(Arc::clone(&join_params.left.native_plan));
                let right = Self::wrap_in_copy_exec(Arc::clone(&join_params.right.native_plan));
                let hash_join = Arc::new(HashJoinExec::try_new(
                    left,
                    right,
                    join_params.join_on,
                    join_params.join_filter,
                    &join_params.join_type,
                    None,
                    PartitionMode::Partitioned,
                    // null doesn't equal to null in Spark join key. If the join key is
                    // `EqualNullSafe`, Spark will rewrite it during planning.
                    false,
                )?);
                // Spark may have chosen either side as the build side, but the native hash
                // join builds on the left, so if Spark chose build-right we swap inputs.
                if join.build_side == BuildSide::BuildLeft as i32 {
                    Ok((
                        scans,
                        Arc::new(SparkPlan::new(
                            spark_plan.plan_id,
                            hash_join,
                            vec![join_params.left, join_params.right],
                        )),
                    ))
                } else {
                    let swapped_hash_join =
                        hash_join.as_ref().swap_inputs(PartitionMode::Partitioned)?;
                    let mut additional_native_plans = vec![];
                    if swapped_hash_join.as_any().is::<ProjectionExec>() {
                        // a projection was added to the hash join
                        additional_native_plans.push(Arc::clone(swapped_hash_join.children()[0]));
                    }
                    Ok((
                        scans,
                        Arc::new(SparkPlan::new_with_additional(
                            spark_plan.plan_id,
                            swapped_hash_join,
                            vec![join_params.left, join_params.right],
                            additional_native_plans,
                        )),
                    ))
                }
            }
            OpStruct::Window(wnd) => {
                // NOTE(review): unlike the other single-child operators, this arm does not
                // assert `children.len() == 1` before indexing `children[0]`.
                let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?;
                let input_schema = child.schema();
                let sort_exprs: Result<Vec<PhysicalSortExpr>, ExecutionError> = wnd
                    .order_by_list
                    .iter()
                    .map(|expr| self.create_sort_expr(expr, Arc::clone(&input_schema)))
                    .collect();
                let partition_exprs: Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError> = wnd
                    .partition_by_list
                    .iter()
                    .map(|expr| self.create_expr(expr, Arc::clone(&input_schema)))
                    .collect();
                let sort_exprs = &sort_exprs?;
                let partition_exprs = &partition_exprs?;
                let window_expr: Result<Vec<Arc<dyn WindowExpr>>, ExecutionError> = wnd
                    .window_expr
                    .iter()
                    .map(|expr| {
                        self.create_window_expr(
                            expr,
                            Arc::clone(&input_schema),
                            partition_exprs,
                            sort_exprs,
                        )
                    })
                    .collect();
                let window_agg = Arc::new(BoundedWindowAggExec::try_new(
                    window_expr?,
                    Arc::clone(&child.native_plan),
                    InputOrderMode::Sorted,
                    !partition_exprs.is_empty(),
                )?);
                Ok((
                    scans,
                    Arc::new(SparkPlan::new(spark_plan.plan_id, window_agg, vec![child])),
                ))
            }
        }
    }
    /// Builds the native plans for both join children and assembles the parameters
    /// shared by the hash-join and sort-merge-join operators: the equi-join key
    /// pairs, the DataFusion join type, and an optional non-equi join filter.
    ///
    /// Returns the parsed `JoinParameters` plus all `ScanExec`s collected from both
    /// sides (left side's scans first, then the right side's appended).
    #[allow(clippy::too_many_arguments)]
    fn parse_join_parameters(
        &self,
        inputs: &mut Vec<Arc<GlobalRef>>,
        children: &[Operator],
        left_join_keys: &[Expr],
        right_join_keys: &[Expr],
        join_type: i32,
        condition: &Option<Expr>,
        partition_count: usize,
    ) -> Result<(JoinParameters, Vec<ScanExec>), ExecutionError> {
        assert!(children.len() == 2);
        // Left child is created first, so it consumes its input sources before the right.
        let (mut left_scans, left) = self.create_plan(&children[0], inputs, partition_count)?;
        let (mut right_scans, right) = self.create_plan(&children[1], inputs, partition_count)?;
        left_scans.append(&mut right_scans);
        // Join keys are bound against each side's own schema.
        let left_join_exprs: Vec<_> = left_join_keys
            .iter()
            .map(|expr| self.create_expr(expr, left.schema()))
            .collect::<Result<Vec<_>, _>>()?;
        let right_join_exprs: Vec<_> = right_join_keys
            .iter()
            .map(|expr| self.create_expr(expr, right.schema()))
            .collect::<Result<Vec<_>, _>>()?;
        let join_on = left_join_exprs
            .into_iter()
            .zip(right_join_exprs)
            .collect::<Vec<_>>();
        // Translate the protobuf join type into DataFusion's join type.
        let join_type = match join_type.try_into() {
            Ok(JoinType::Inner) => DFJoinType::Inner,
            Ok(JoinType::LeftOuter) => DFJoinType::Left,
            Ok(JoinType::RightOuter) => DFJoinType::Right,
            Ok(JoinType::FullOuter) => DFJoinType::Full,
            Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi,
            Ok(JoinType::RightSemi) => DFJoinType::RightSemi,
            Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti,
            Ok(JoinType::RightAnti) => DFJoinType::RightAnti,
            Err(_) => {
                return Err(GeneralError(format!(
                    "Unsupported join type: {:?}",
                    join_type
                )));
            }
        };
        // Handle join filter as DataFusion `JoinFilter` struct
        let join_filter = if let Some(expr) = condition {
            let left_schema = left.schema();
            let right_schema = right.schema();
            let left_fields = left_schema.fields();
            let right_fields = right_schema.fields();
            // Spark binds the filter expression to the concatenation of both schemas
            // (left fields first, then right fields).
            let all_fields: Vec<_> = left_fields
                .into_iter()
                .chain(right_fields)
                .cloned()
                .collect();
            let full_schema = Arc::new(Schema::new(all_fields));
            // Because we cast dictionary array to array in scan operator,
            // we need to change dictionary type to data type for join filter expression.
            let fields: Vec<_> = full_schema
                .fields()
                .iter()
                .map(|f| match f.data_type() {
                    DataType::Dictionary(_, val_type) => Arc::new(Field::new(
                        f.name(),
                        val_type.as_ref().clone(),
                        f.is_nullable(),
                    )),
                    _ => Arc::clone(f),
                })
                .collect();
            let full_schema = Arc::new(Schema::new(fields));
            let physical_expr = self.create_expr(expr, full_schema)?;
            // Determine which left/right columns the filter actually references.
            let (left_field_indices, right_field_indices) =
                expr_to_columns(&physical_expr, left_fields.len(), right_fields.len())?;
            let column_indices = JoinFilter::build_column_indices(
                left_field_indices.clone(),
                right_field_indices.clone(),
            );
            // The filter's intermediate schema contains only the referenced fields,
            // left side's first.
            let filter_fields: Vec<Field> = left_field_indices
                .clone()
                .into_iter()
                .map(|i| left.schema().field(i).clone())
                .chain(
                    right_field_indices
                        .clone()
                        .into_iter()
                        .map(|i| right.schema().field(i).clone()),
                )
                // Because we cast dictionary array to array in scan operator,
                // we need to change dictionary type to data type for join filter expression.
                .map(|f| match f.data_type() {
                    DataType::Dictionary(_, val_type) => {
                        Field::new(f.name(), val_type.as_ref().clone(), f.is_nullable())
                    }
                    _ => f.clone(),
                })
                .collect_vec();
            let filter_schema = Schema::new_with_metadata(filter_fields, HashMap::new());
            // Rewrite the physical expression to use the new column indices.
            // DataFusion's join filter is bound to intermediate schema which contains
            // only the fields used in the filter expression. But the Spark's join filter
            // expression is bound to the full schema. We need to rewrite the physical
            // expression to use the new column indices.
            let rewritten_physical_expr = rewrite_physical_expr(
                physical_expr,
                left_schema.fields.len(),
                right_schema.fields.len(),
                &left_field_indices,
                &right_field_indices,
            )?;
            Some(JoinFilter::new(
                rewritten_physical_expr,
                column_indices,
                filter_schema.into(),
            ))
        } else {
            None
        };
        Ok((
            JoinParameters {
                left: Arc::clone(&left),
                right: Arc::clone(&right),
                join_on,
                join_type,
                join_filter,
            },
            left_scans,
        ))
    }
/// Wrap an ExecutionPlan in a CopyExec, which will unpack any dictionary-encoded arrays
/// and make a deep copy of other arrays if the plan re-uses batches.
fn wrap_in_copy_exec(plan: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
if can_reuse_input_batch(&plan) {
Arc::new(CopyExec::new(plan, CopyMode::UnpackOrDeepCopy))
} else {
Arc::new(CopyExec::new(plan, CopyMode::UnpackOrClone))
}
}
/// Create a DataFusion physical aggregate expression from Spark physical aggregate expression
fn create_agg_expr(
&self,
spark_expr: &AggExpr,
schema: SchemaRef,
) -> Result<AggregateFunctionExpr, ExecutionError> {
match spark_expr.expr_struct.as_ref().unwrap() {
AggExprStruct::Count(expr) => {
assert!(!expr.children.is_empty());
// Using `count_udaf` from Comet is exceptionally slow for some reason, so
// as a workaround we translate it to `SUM(IF(expr IS NOT NULL, 1, 0))`
// https://github.com/apache/datafusion-comet/issues/744
let children = expr
.children
.iter()
.map(|child| self.create_expr(child, Arc::clone(&schema)))
.collect::<Result<Vec<_>, _>>()?;
// create `IS NOT NULL expr` and join them with `AND` if there are multiple
let not_null_expr: Arc<dyn PhysicalExpr> = children.iter().skip(1).fold(
Arc::new(IsNotNullExpr::new(Arc::clone(&children[0]))) as Arc<dyn PhysicalExpr>,
|acc, child| {
Arc::new(BinaryExpr::new(
acc,
DataFusionOperator::And,
Arc::new(IsNotNullExpr::new(Arc::clone(child))),
))
},
);
let child = Arc::new(IfExpr::new(
not_null_expr,
Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
Arc::new(Literal::new(ScalarValue::Int64(Some(0)))),
));
AggregateExprBuilder::new(sum_udaf(), vec![child])
.schema(schema)
.alias("count")
.with_ignore_nulls(false)
.with_distinct(false)
.build()
.map_err(|e| ExecutionError::DataFusionError(e.to_string()))
}
AggExprStruct::Min(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let child = Arc::new(CastExpr::new(child, datatype.clone(), None));
AggregateExprBuilder::new(min_udaf(), vec![child])
.schema(schema)
.alias("min")
.with_ignore_nulls(false)
.with_distinct(false)
.build()
.map_err(|e| ExecutionError::DataFusionError(e.to_string()))
}
AggExprStruct::Max(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let child = Arc::new(CastExpr::new(child, datatype.clone(), None));
AggregateExprBuilder::new(max_udaf(), vec![child])
.schema(schema)
.alias("max")
.with_ignore_nulls(false)
.with_distinct(false)
.build()
.map_err(|e| ExecutionError::DataFusionError(e.to_string()))
}
AggExprStruct::Sum(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let builder = match datatype {
DataType::Decimal128(_, _) => {
let func = AggregateUDF::new_from_impl(SumDecimal::try_new(datatype)?);
AggregateExprBuilder::new(Arc::new(func), vec![child])
}
_ => {
// cast to the result data type of SUM if necessary, we should not expect
// a cast failure since it should have already been checked at Spark side
let child =
Arc::new(CastExpr::new(Arc::clone(&child), datatype.clone(), None));
AggregateExprBuilder::new(sum_udaf(), vec![child])
}
};
builder
.schema(schema)
.alias("sum")
.with_ignore_nulls(false)
.with_distinct(false)
.build()
.map_err(|e| e.into())
}
AggExprStruct::Avg(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap());
let builder = match datatype {
DataType::Decimal128(_, _) => {
let func =
AggregateUDF::new_from_impl(AvgDecimal::new(datatype, input_datatype));
AggregateExprBuilder::new(Arc::new(func), vec![child])
}
_ => {
// cast to the result data type of AVG if the result data type is different
// from the input type, e.g. AVG(Int32). We should not expect a cast
// failure since it should have already been checked at Spark side.
let child: Arc<dyn PhysicalExpr> =
Arc::new(CastExpr::new(Arc::clone(&child), datatype.clone(), None));
let func = AggregateUDF::new_from_impl(Avg::new("avg", datatype));
AggregateExprBuilder::new(Arc::new(func), vec![child])
}
};
builder
.schema(schema)
.alias("avg")
.with_ignore_nulls(false)
.with_distinct(false)
.build()
.map_err(|e| e.into())
}
AggExprStruct::First(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
let func = AggregateUDF::new_from_impl(FirstValue::new());
AggregateExprBuilder::new(Arc::new(func), vec![child])
.schema(schema)
.alias("first")
.with_ignore_nulls(expr.ignore_nulls)
.with_distinct(false)
.build()
.map_err(|e| e.into())
}
AggExprStruct::Last(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
let func = AggregateUDF::new_from_impl(LastValue::new());
AggregateExprBuilder::new(Arc::new(func), vec![child])
.schema(schema)
.alias("last")
.with_ignore_nulls(expr.ignore_nulls)
.with_distinct(false)
.build()
.map_err(|e| e.into())
}
AggExprStruct::BitAndAgg(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
AggregateExprBuilder::new(bit_and_udaf(), vec![child])
.schema(schema)
.alias("bit_and")
.with_ignore_nulls(false)
.with_distinct(false)
.build()
.map_err(|e| e.into())
}
AggExprStruct::BitOrAgg(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
AggregateExprBuilder::new(bit_or_udaf(), vec![child])
.schema(schema)
.alias("bit_or")
.with_ignore_nulls(false)
.with_distinct(false)
.build()
.map_err(|e| e.into())
}
AggExprStruct::BitXorAgg(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
AggregateExprBuilder::new(bit_xor_udaf(), vec![child])
.schema(schema)
.alias("bit_xor")
.with_ignore_nulls(false)
.with_distinct(false)
.build()
.map_err(|e| e.into())
}
AggExprStruct::Covariance(expr) => {
let child1 =
self.create_expr(expr.child1.as_ref().unwrap(), Arc::clone(&schema))?;
let child2 =
self.create_expr(expr.child2.as_ref().unwrap(), Arc::clone(&schema))?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
match expr.stats_type {
0 => {
let func = AggregateUDF::new_from_impl(Covariance::new(
"covariance",
datatype,
StatsType::Sample,
expr.null_on_divide_by_zero,
));
Self::create_aggr_func_expr(
"covariance",
schema,
vec![child1, child2],
func,
)
}
1 => {
let func = AggregateUDF::new_from_impl(Covariance::new(
"covariance_pop",
datatype,
StatsType::Population,
expr.null_on_divide_by_zero,
));
Self::create_aggr_func_expr(
"covariance_pop",
schema,
vec![child1, child2],
func,
)
}
stats_type => Err(GeneralError(format!(
"Unknown StatisticsType {:?} for Variance",
stats_type
))),
}
}
AggExprStruct::Variance(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
match expr.stats_type {
0 => {
let func = AggregateUDF::new_from_impl(Variance::new(
"variance",
datatype,
StatsType::Sample,
expr.null_on_divide_by_zero,
));
Self::create_aggr_func_expr("variance", schema, vec![child], func)
}
1 => {
let func = AggregateUDF::new_from_impl(Variance::new(
"variance_pop",
datatype,
StatsType::Population,
expr.null_on_divide_by_zero,
));
Self::create_aggr_func_expr("variance_pop", schema, vec![child], func)
}
stats_type => Err(GeneralError(format!(
"Unknown StatisticsType {:?} for Variance",
stats_type
))),
}
}
AggExprStruct::Stddev(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
match expr.stats_type {
0 => {
let func = AggregateUDF::new_from_impl(Stddev::new(
"stddev",
datatype,
StatsType::Sample,
expr.null_on_divide_by_zero,
));
Self::create_aggr_func_expr("stddev", schema, vec![child], func)
}
1 => {
let func = AggregateUDF::new_from_impl(Stddev::new(
"stddev_pop",
datatype,
StatsType::Population,
expr.null_on_divide_by_zero,
));
Self::create_aggr_func_expr("stddev_pop", schema, vec![child], func)
}
stats_type => Err(GeneralError(format!(
"Unknown StatisticsType {:?} for stddev",
stats_type
))),
}
}
AggExprStruct::Correlation(expr) => {
let child1 =
self.create_expr(expr.child1.as_ref().unwrap(), Arc::clone(&schema))?;
let child2 =
self.create_expr(expr.child2.as_ref().unwrap(), Arc::clone(&schema))?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let func = AggregateUDF::new_from_impl(Correlation::new(
"correlation",
datatype,
expr.null_on_divide_by_zero,
));
Self::create_aggr_func_expr("correlation", schema, vec![child1, child2], func)
}
AggExprStruct::BloomFilterAgg(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
let num_items =
self.create_expr(expr.num_items.as_ref().unwrap(), Arc::clone(&schema))?;
let num_bits =
self.create_expr(expr.num_bits.as_ref().unwrap(), Arc::clone(&schema))?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let func = AggregateUDF::new_from_impl(BloomFilterAgg::new(
Arc::clone(&num_items),
Arc::clone(&num_bits),
datatype,
));
Self::create_aggr_func_expr("bloom_filter_agg", schema, vec![child], func)
}
}
}
    /// Create a DataFusion windows physical expression from Spark physical expression.
    ///
    /// The window function is either a built-in scalar function or an aggregate
    /// function; the window frame (units plus lower/upper bounds) is translated
    /// from the Spark frame specification. `WindowFrameUnits::Groups` is rejected
    /// everywhere since Spark does not produce it.
    fn create_window_expr<'a>(
        &'a self,
        spark_expr: &'a spark_operator::WindowExpr,
        input_schema: SchemaRef,
        partition_by: &[Arc<dyn PhysicalExpr>],
        sort_exprs: &[PhysicalSortExpr],
    ) -> Result<Arc<dyn WindowExpr>, ExecutionError> {
        let window_func_name: String;
        let window_args: Vec<Arc<dyn PhysicalExpr>>;
        // Resolve the function name and arguments from whichever of the two
        // mutually-exclusive protobuf fields is set.
        if let Some(func) = &spark_expr.built_in_window_function {
            match &func.expr_struct {
                Some(ExprStruct::ScalarFunc(f)) => {
                    window_func_name = f.func.clone();
                    window_args = f
                        .args
                        .iter()
                        .map(|expr| self.create_expr(expr, Arc::clone(&input_schema)))
                        .collect::<Result<Vec<_>, ExecutionError>>()?;
                }
                other => {
                    return Err(GeneralError(format!(
                        "{other:?} not supported for window function"
                    )))
                }
            };
        } else if let Some(agg_func) = &spark_expr.agg_func {
            let result = self.process_agg_func(agg_func, Arc::clone(&input_schema))?;
            window_func_name = result.0;
            window_args = result.1;
        } else {
            return Err(GeneralError(
                "Both func and agg_func are not set".to_string(),
            ));
        }
        // Look the function up among the session's registered UDAFs.
        let window_func = match self.find_df_window_function(&window_func_name) {
            Some(f) => f,
            _ => {
                return Err(GeneralError(format!(
                    "{window_func_name} not supported for window function"
                )))
            }
        };
        let spark_window_frame = match spark_expr
            .spec
            .as_ref()
            .and_then(|inner| inner.frame_specification.as_ref())
        {
            Some(frame) => frame,
            _ => {
                return Err(ExecutionError::DeserializeError(
                    "Cannot deserialize window frame".to_string(),
                ))
            }
        };
        let units = match spark_window_frame.frame_type() {
            WindowFrameType::Rows => WindowFrameUnits::Rows,
            WindowFrameType::Range => WindowFrameUnits::Range,
        };
        // ROWS offsets use unsigned scalars, RANGE offsets use signed scalars;
        // a `None` scalar means an unbounded frame edge. A missing bound
        // defaults to unbounded in the corresponding direction.
        let lower_bound: WindowFrameBound = match spark_window_frame
            .lower_bound
            .as_ref()
            .and_then(|inner| inner.lower_frame_bound_struct.as_ref())
        {
            Some(l) => match l {
                LowerFrameBoundStruct::UnboundedPreceding(_) => match units {
                    WindowFrameUnits::Rows => {
                        WindowFrameBound::Preceding(ScalarValue::UInt64(None))
                    }
                    WindowFrameUnits::Range => {
                        WindowFrameBound::Preceding(ScalarValue::Int64(None))
                    }
                    WindowFrameUnits::Groups => {
                        return Err(GeneralError(
                            "WindowFrameUnits::Groups is not supported.".to_string(),
                        ));
                    }
                },
                LowerFrameBoundStruct::Preceding(offset) => {
                    // Normalize the preceding offset to a non-negative magnitude.
                    let offset_value = offset.offset.abs();
                    match units {
                        WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(
                            Some(offset_value as u64),
                        )),
                        WindowFrameUnits::Range => {
                            WindowFrameBound::Preceding(ScalarValue::Int64(Some(offset_value)))
                        }
                        WindowFrameUnits::Groups => {
                            return Err(GeneralError(
                                "WindowFrameUnits::Groups is not supported.".to_string(),
                            ));
                        }
                    }
                }
                LowerFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow,
            },
            None => match units {
                WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
                WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::Int64(None)),
                WindowFrameUnits::Groups => {
                    return Err(GeneralError(
                        "WindowFrameUnits::Groups is not supported.".to_string(),
                    ));
                }
            },
        };
        let upper_bound: WindowFrameBound = match spark_window_frame
            .upper_bound
            .as_ref()
            .and_then(|inner| inner.upper_frame_bound_struct.as_ref())
        {
            Some(u) => match u {
                UpperFrameBoundStruct::UnboundedFollowing(_) => match units {
                    WindowFrameUnits::Rows => {
                        WindowFrameBound::Following(ScalarValue::UInt64(None))
                    }
                    WindowFrameUnits::Range => {
                        WindowFrameBound::Following(ScalarValue::Int64(None))
                    }
                    WindowFrameUnits::Groups => {
                        return Err(GeneralError(
                            "WindowFrameUnits::Groups is not supported.".to_string(),
                        ));
                    }
                },
                // NOTE(review): unlike the Preceding case, the Following offset is not
                // normalized with `abs()` — presumably Spark always sends it
                // non-negative; confirm against the serde layer.
                UpperFrameBoundStruct::Following(offset) => match units {
                    WindowFrameUnits::Rows => {
                        WindowFrameBound::Following(ScalarValue::UInt64(Some(offset.offset as u64)))
                    }
                    WindowFrameUnits::Range => {
                        WindowFrameBound::Following(ScalarValue::Int64(Some(offset.offset)))
                    }
                    WindowFrameUnits::Groups => {
                        return Err(GeneralError(
                            "WindowFrameUnits::Groups is not supported.".to_string(),
                        ));
                    }
                },
                UpperFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow,
            },
            None => match units {
                WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
                WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::Int64(None)),
                WindowFrameUnits::Groups => {
                    return Err(GeneralError(
                        "WindowFrameUnits::Groups is not supported.".to_string(),
                    ));
                }
            },
        };
        let window_frame = WindowFrame::new_bounds(units, lower_bound, upper_bound);
        datafusion::physical_plan::windows::create_window_expr(
            &window_func,
            window_func_name,
            &window_args,
            partition_by,
            &LexOrdering::new(sort_exprs.to_vec()),
            window_frame.into(),
            input_schema.as_ref(),
            false, // TODO: Ignore nulls
        )
        .map_err(|e| ExecutionError::DataFusionError(e.to_string()))
    }
fn process_agg_func(
&self,
agg_func: &AggExpr,
schema: SchemaRef,
) -> Result<(String, Vec<Arc<dyn PhysicalExpr>>), ExecutionError> {
match &agg_func.expr_struct {
Some(AggExprStruct::Count(expr)) => {
let children = expr
.children
.iter()
.map(|child| self.create_expr(child, Arc::clone(&schema)))
.collect::<Result<Vec<_>, _>>()?;
Ok(("count".to_string(), children))
}
Some(AggExprStruct::Min(expr)) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
Ok(("min".to_string(), vec![child]))
}
Some(AggExprStruct::Max(expr)) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
Ok(("max".to_string(), vec![child]))
}
Some(AggExprStruct::Sum(expr)) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
let arrow_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let datatype = child.data_type(&schema)?;
let child = if datatype != arrow_type {
Arc::new(CastExpr::new(child, arrow_type.clone(), None))
} else {
child
};
Ok(("sum".to_string(), vec![child]))
}
other => Err(GeneralError(format!(
"{other:?} not supported for window function"
))),
}
}
/// Find DataFusion's built-in window function by name.
fn find_df_window_function(&self, name: &str) -> Option<WindowFunctionDefinition> {
let registry = &self.session_ctx.state();
registry
.udaf(name)
.map(WindowFunctionDefinition::AggregateUDF)
.ok()
}
/// Create a DataFusion physical partitioning from Spark physical partitioning
fn create_partitioning(
&self,
spark_partitioning: &SparkPartitioning,
input_schema: SchemaRef,
) -> Result<Partitioning, ExecutionError> {
match spark_partitioning.partitioning_struct.as_ref().unwrap() {
PartitioningStruct::HashPartition(hash_partition) => {
let exprs: PartitionPhyExprResult = hash_partition
.hash_expression
.iter()
.map(|x| self.create_expr(x, Arc::clone(&input_schema)))
.collect();
Ok(Partitioning::Hash(
exprs?,
hash_partition.num_partitions as usize,
))
}
PartitioningStruct::SinglePartition(_) => Ok(Partitioning::UnknownPartitioning(1)),
}
}
    /// Build a physical expression for a Spark scalar function call.
    ///
    /// Converts the protobuf argument expressions, determines the result type
    /// (either declared by Spark, or inferred from the registered UDF when the
    /// proto carries no return type), inserts casts for arguments whose actual
    /// type differs from the coerced input type, and wraps everything in a
    /// `ScalarFunctionExpr`.
    fn create_scalar_function_expr(
        &self,
        expr: &ScalarFunc,
        input_schema: SchemaRef,
    ) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
        let args = expr
            .args
            .iter()
            .map(|x| self.create_expr(x, Arc::clone(&input_schema)))
            .collect::<Result<Vec<_>, _>>()?;
        let fun_name = &expr.func;
        // Actual (pre-coercion) types of the argument expressions.
        let input_expr_types = args
            .iter()
            .map(|x| x.data_type(input_schema.as_ref()))
            .collect::<Result<Vec<_>, _>>()?;
        let (data_type, coerced_input_types) =
            match expr.return_type.as_ref().map(to_arrow_datatype) {
                // Spark declared the return type explicitly: trust it and skip coercion.
                Some(t) => (t, input_expr_types.clone()),
                None => {
                    // No declared return type: infer it from the registered UDF.
                    let fun_name = match fun_name.as_ref() {
                        "read_side_padding" => "rpad", // use the same return type as rpad
                        other => other,
                    };
                    let func = self.session_ctx.udf(fun_name)?;
                    // Fall back to the raw input types if the UDF cannot coerce them.
                    let coerced_types = func
                        .coerce_types(&input_expr_types)
                        .unwrap_or_else(|_| input_expr_types.clone());
                    // TODO this should try and find scalar
                    // Capture literal argument values (non-literals become None)
                    // so the UDF can specialize its return type on constants.
                    let arguments = args
                        .iter()
                        .map(|e| {
                            e.as_ref()
                                .as_any()
                                .downcast_ref::<Literal>()
                                .map(|lit| lit.value())
                        })
                        .collect::<Vec<_>>();
                    // Conservatively treat every argument as nullable.
                    let nullables = arguments.iter().map(|_| true).collect::<Vec<_>>();
                    let args = ReturnTypeArgs {
                        arg_types: &coerced_types,
                        scalar_arguments: &arguments,
                        nullables: &nullables,
                    };
                    let data_type = func
                        .inner()
                        .return_type_from_args(args)?
                        .return_type()
                        .clone();
                    (data_type, coerced_types)
                }
            };
        let fun_expr =
            create_comet_physical_fun(fun_name, data_type.clone(), &self.session_ctx.state())?;
        // Insert (non-safe, i.e. erroring rather than null-producing) casts
        // wherever an argument's actual type differs from the coerced type.
        let args = args
            .into_iter()
            .zip(input_expr_types.into_iter().zip(coerced_input_types))
            .map(|(expr, (from_type, to_type))| {
                if from_type != to_type {
                    Arc::new(CastExpr::new(
                        expr,
                        to_type,
                        Some(CastOptions {
                            safe: false,
                            ..Default::default()
                        }),
                    ))
                } else {
                    expr
                }
            })
            .collect::<Vec<_>>();
        let scalar_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
            fun_name,
            fun_expr,
            args.to_vec(),
            data_type,
        ));
        Ok(scalar_expr)
    }
fn create_aggr_func_expr(
name: &str,
schema: SchemaRef,
children: Vec<Arc<dyn PhysicalExpr>>,
func: AggregateUDF,
) -> Result<AggregateFunctionExpr, ExecutionError> {
AggregateExprBuilder::new(Arc::new(func), children)
.schema(schema)
.alias(name)
.with_ignore_nulls(false)
.with_distinct(false)
.build()
.map_err(|e| e.into())
}
}
impl From<DataFusionError> for ExecutionError {
fn from(value: DataFusionError) -> Self {
ExecutionError::DataFusionError(value.message().to_string())
}
}
impl From<ExecutionError> for DataFusionError {
fn from(value: ExecutionError) -> Self {
DataFusionError::Execution(value.to_string())
}
}
impl From<ExpressionError> for DataFusionError {
fn from(value: ExpressionError) -> Self {
DataFusionError::Execution(value.to_string())
}
}
/// Returns true if given operator can return input array as output array without
/// modification. This is used to determine if we need to copy the input batch to avoid
/// data corruption from reusing the input batch.
fn can_reuse_input_batch(op: &Arc<dyn ExecutionPlan>) -> bool {
    // Projections and local limits are pass-through operators: walk down
    // through them to the underlying child.
    let mut current = op;
    while current.as_any().is::<ProjectionExec>() || current.as_any().is::<LocalLimitExec>() {
        current = current.children()[0];
    }
    // Only a bare scan hands its input arrays through unchanged.
    current.as_any().is::<ScanExec>()
}
/// Collects the indices of the columns in the input schema that are used in the expression
/// and returns them as a pair of vectors, one for the left side and one for the right side.
/// Right-side indices are re-based to start at zero. Both vectors are sorted ascending.
fn expr_to_columns(
    expr: &Arc<dyn PhysicalExpr>,
    left_field_len: usize,
    right_field_len: usize,
) -> Result<(Vec<usize>, Vec<usize>), ExecutionError> {
    let mut left_field_indices: Vec<usize> = vec![];
    let mut right_field_indices: Vec<usize> = vec![];
    expr.apply(&mut |expr: &Arc<dyn PhysicalExpr>| {
        Ok({
            if let Some(column) = expr.as_any().downcast_ref::<Column>() {
                // Valid indices are 0..left_field_len + right_field_len. Use
                // `>=` here: the previous `>` check let an index equal to the
                // combined length slip through and produce an out-of-range
                // right-side index below.
                if column.index() >= left_field_len + right_field_len {
                    return Err(DataFusionError::Internal(format!(
                        "Column index {} out of range",
                        column.index()
                    )));
                } else if column.index() < left_field_len {
                    left_field_indices.push(column.index());
                } else {
                    // Right-side columns are re-based to start at zero.
                    right_field_indices.push(column.index() - left_field_len);
                }
            }
            TreeNodeRecursion::Continue
        })
    })?;
    left_field_indices.sort();
    right_field_indices.sort();
    Ok((left_field_indices, right_field_indices))
}
/// A physical join filter rewriter which rewrites the column indices in the expression
/// to use the new column indices. See `rewrite_physical_expr`.
struct JoinFilterRewriter<'a> {
    // Number of fields in the join's left-side schema.
    left_field_len: usize,
    // Number of fields in the join's right-side schema.
    right_field_len: usize,
    // Original left-side column indices referenced by the filter (sorted).
    left_field_indices: &'a [usize],
    // Original right-side column indices referenced by the filter,
    // re-based to start at zero (sorted).
    right_field_indices: &'a [usize],
}
impl JoinFilterRewriter<'_> {
    /// Creates a rewriter from the two side lengths and the (sorted) original
    /// column indices that the join filter references on each side.
    fn new<'a>(
        left_field_len: usize,
        right_field_len: usize,
        left_field_indices: &'a [usize],
        right_field_indices: &'a [usize],
    ) -> JoinFilterRewriter<'a> {
        JoinFilterRewriter {
            left_field_len,
            right_field_len,
            left_field_indices,
            right_field_indices,
        }
    }
}
/// Rewrites `Column` nodes so their indices refer to the compacted schema that
/// contains only the columns actually referenced by the join filter.
impl TreeNodeRewriter for JoinFilterRewriter<'_> {
    type Node = Arc<dyn PhysicalExpr>;
    fn f_down(&mut self, node: Self::Node) -> datafusion::common::Result<Transformed<Self::Node>> {
        if let Some(column) = node.as_any().downcast_ref::<Column>() {
            let index = column.index();
            // Reject indices beyond the combined left ++ right schema.
            if index >= self.left_field_len + self.right_field_len {
                return Err(DataFusionError::Internal(format!(
                    "Column index {} out of range",
                    index
                )));
            }
            // Pick the side this column belongs to: the slice of retained
            // indices to search, the (re-based) index to look for, the offset
            // of that side in the compacted schema, and a label for errors.
            let (indices, target, base, side) = if index < self.left_field_len {
                (self.left_field_indices, index, 0, "left")
            } else {
                (
                    self.right_field_indices,
                    index - self.left_field_len,
                    self.left_field_indices.len(),
                    "right",
                )
            };
            let position = indices.iter().position(|&x| x == target).ok_or_else(|| {
                DataFusionError::Internal(format!(
                    "Column index {} not found in {} field indices",
                    index, side
                ))
            })?;
            Ok(Transformed::yes(Arc::new(Column::new(
                column.name(),
                base + position,
            ))))
        } else {
            // Non-column nodes pass through untouched.
            Ok(Transformed::no(node))
        }
    }
}
/// Rewrites the physical expression to use the new column indices.
/// This is necessary when the physical expression is used in a join filter, as the column
/// indices are different from the original schema.
fn rewrite_physical_expr(
    expr: Arc<dyn PhysicalExpr>,
    left_field_len: usize,
    right_field_len: usize,
    left_field_indices: &[usize],
    right_field_indices: &[usize],
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
    let mut rewriter = JoinFilterRewriter::new(
        left_field_len,
        right_field_len,
        left_field_indices,
        right_field_indices,
    );
    // `rewrite` walks the tree top-down, replacing column nodes as it goes.
    let rewritten = expr.rewrite(&mut rewriter)?;
    Ok(rewritten.data)
}
/// Map the protobuf eval-mode enum value onto the native `EvalMode`.
fn from_protobuf_eval_mode(value: i32) -> Result<EvalMode, prost::UnknownEnumValue> {
    let mode = spark_expression::EvalMode::try_from(value)?;
    Ok(match mode {
        spark_expression::EvalMode::Legacy => EvalMode::Legacy,
        spark_expression::EvalMode::Try => EvalMode::Try,
        spark_expression::EvalMode::Ansi => EvalMode::Ansi,
    })
}
fn convert_spark_types_to_arrow_schema(
spark_types: &[spark_operator::SparkStructField],
) -> SchemaRef {
let arrow_fields = spark_types
.iter()
.map(|spark_type| {
Field::new(
String::clone(&spark_type.name),
to_arrow_datatype(spark_type.data_type.as_ref().unwrap()),
spark_type.nullable,
)
})
.collect_vec();
let arrow_schema: SchemaRef = Arc::new(Schema::new(arrow_fields));
arrow_schema
}
/// Create CASE WHEN expression and add casting as needed
///
/// All THEN branches and the ELSE branch must share a single result type; when
/// a coerced type exists, each branch is wrapped in a Spark-compatible cast to
/// that type before building the `CaseExpr`.
fn create_case_expr(
    when_then_pairs: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>,
    else_expr: Option<Arc<dyn PhysicalExpr>>,
    input_schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
    // Result types of every THEN branch.
    let then_types: Vec<DataType> = when_then_pairs
        .iter()
        .map(|x| x.1.data_type(input_schema))
        .collect::<Result<Vec<_>, _>>()?;
    // A missing ELSE behaves like `ELSE NULL` for type-coercion purposes.
    // (Borrow `x` directly — no need to clone the Arc just to query its type.)
    let else_type: Option<DataType> = else_expr
        .as_ref()
        .map(|x| x.data_type(input_schema))
        .transpose()?
        .or(Some(DataType::Null));
    if let Some(coerce_type) = get_coerce_type_for_case_expression(&then_types, else_type.as_ref())
    {
        let cast_options = SparkCastOptions::new_without_timezone(EvalMode::Legacy, false);
        // Cast every THEN value to the coerced type.
        let when_then_pairs = when_then_pairs
            .iter()
            .map(|x| {
                let t: Arc<dyn PhysicalExpr> = Arc::new(Cast::new(
                    Arc::clone(&x.1),
                    coerce_type.clone(),
                    cast_options.clone(),
                ));
                (Arc::clone(&x.0), t)
            })
            .collect::<Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>>();
        // Cast the ELSE value as well when present; `else_expr` is owned here,
        // so it can be moved rather than cloned.
        let else_phy_expr: Option<Arc<dyn PhysicalExpr>> = else_expr.map(|x| {
            Arc::new(Cast::new(x, coerce_type.clone(), cast_options.clone()))
                as Arc<dyn PhysicalExpr>
        });
        Ok(Arc::new(CaseExpr::try_new(
            None,
            when_then_pairs,
            else_phy_expr,
        )?))
    } else {
        // Types already agree; build the CASE expression directly.
        Ok(Arc::new(CaseExpr::try_new(
            None,
            when_then_pairs,
            else_expr,
        )?))
    }
}
#[cfg(test)]
mod tests {
use std::{sync::Arc, task::Poll};
use futures::{poll, StreamExt};
use arrow::array::{Array, DictionaryArray, Int32Array, StringArray};
use arrow::datatypes::DataType;
use datafusion::logical_expr::ScalarUDF;
use datafusion::{assert_batches_eq, physical_plan::common::collect, prelude::SessionContext};
use tokio::sync::mpsc;
use crate::execution::{operators::InputBatch, planner::PhysicalPlanner};
use crate::execution::operators::ExecutionError;
use datafusion_comet_proto::spark_expression::expr::ExprStruct;
use datafusion_comet_proto::{
spark_expression::expr::ExprStruct::*,
spark_expression::Expr,
spark_expression::{self, literal},
spark_operator,
spark_operator::{operator::OpStruct, Operator},
};
#[test]
fn test_unpack_dictionary_primitive() {
let op_scan = Operator {
plan_id: 0,
children: vec![],
op_struct: Some(OpStruct::Scan(spark_operator::Scan {
fields: vec![spark_expression::DataType {
type_id: 3, // Int32
type_info: None,
}],
source: "".to_string(),
})),
};
let op = create_filter(op_scan, 3);
let planner = PhysicalPlanner::default();
let row_count = 100;
// Create a dictionary array with 100 values, and use it as input to the execution.
let keys = Int32Array::new((0..(row_count as i32)).map(|n| n % 4).collect(), None);
let values = Int32Array::from(vec![0, 1, 2, 3]);
let input_array = DictionaryArray::new(keys, Arc::new(values));
let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count);
let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap();
scans[0].set_input_batch(input_batch);
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let mut stream = datafusion_plan.native_plan.execute(0, task_ctx).unwrap();
let runtime = tokio::runtime::Runtime::new().unwrap();
let (tx, mut rx) = mpsc::channel(1);
// Separate thread to send the EOF signal once we've processed the only input batch
runtime.spawn(async move {
// Create a dictionary array with 100 values, and use it as input to the execution.
let keys = Int32Array::new((0..(row_count as i32)).map(|n| n % 4).collect(), None);
let values = Int32Array::from(vec![0, 1, 2, 3]);
let input_array = DictionaryArray::new(keys, Arc::new(values));
let input_batch1 = InputBatch::Batch(vec![Arc::new(input_array)], row_count);
let input_batch2 = InputBatch::EOF;
let batches = vec![input_batch1, input_batch2];
for batch in batches.into_iter() {
tx.send(batch).await.unwrap();
}
});
runtime.block_on(async move {
loop {
let batch = rx.recv().await.unwrap();
scans[0].set_input_batch(batch);
match poll!(stream.next()) {
Poll::Ready(Some(batch)) => {
assert!(batch.is_ok(), "got error {}", batch.unwrap_err());
let batch = batch.unwrap();
assert_eq!(batch.num_rows(), row_count / 4);
// dictionary should be unpacked
assert!(matches!(batch.column(0).data_type(), DataType::Int32));
}
Poll::Ready(None) => {
break;
}
_ => {}
}
}
});
}
const STRING_TYPE_ID: i32 = 7;
#[test]
fn test_unpack_dictionary_string() {
let op_scan = Operator {
plan_id: 0,
children: vec![],
op_struct: Some(OpStruct::Scan(spark_operator::Scan {
fields: vec![spark_expression::DataType {
type_id: STRING_TYPE_ID, // String
type_info: None,
}],
source: "".to_string(),
})),
};
let lit = spark_expression::Literal {
value: Some(literal::Value::StringVal("foo".to_string())),
datatype: Some(spark_expression::DataType {
type_id: STRING_TYPE_ID,
type_info: None,
}),
is_null: false,
};
let op = create_filter_literal(op_scan, STRING_TYPE_ID, lit);
let planner = PhysicalPlanner::default();
let row_count = 100;
let keys = Int32Array::new((0..(row_count as i32)).map(|n| n % 4).collect(), None);
let values = StringArray::from(vec!["foo", "bar", "hello", "comet"]);
let input_array = DictionaryArray::new(keys, Arc::new(values));
let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count);
let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap();
// Scan's schema is determined by the input batch, so we need to set it before execution.
scans[0].set_input_batch(input_batch);
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let mut stream = datafusion_plan.native_plan.execute(0, task_ctx).unwrap();
let runtime = tokio::runtime::Runtime::new().unwrap();
let (tx, mut rx) = mpsc::channel(1);
// Separate thread to send the EOF signal once we've processed the only input batch
runtime.spawn(async move {
// Create a dictionary array with 100 values, and use it as input to the execution.
let keys = Int32Array::new((0..(row_count as i32)).map(|n| n % 4).collect(), None);
let values = StringArray::from(vec!["foo", "bar", "hello", "comet"]);
let input_array = DictionaryArray::new(keys, Arc::new(values));
let input_batch1 = InputBatch::Batch(vec![Arc::new(input_array)], row_count);
let input_batch2 = InputBatch::EOF;
let batches = vec![input_batch1, input_batch2];
for batch in batches.into_iter() {
tx.send(batch).await.unwrap();
}
});
runtime.block_on(async move {
loop {
let batch = rx.recv().await.unwrap();
scans[0].set_input_batch(batch);
match poll!(stream.next()) {
Poll::Ready(Some(batch)) => {
assert!(batch.is_ok(), "got error {}", batch.unwrap_err());
let batch = batch.unwrap();
assert_eq!(batch.num_rows(), row_count / 4);
// string/binary should still be packed with dictionary
assert!(matches!(
batch.column(0).data_type(),
DataType::Dictionary(_, _)
));
}
Poll::Ready(None) => {
break;
}
_ => {}
}
}
});
}
#[tokio::test()]
#[allow(clippy::field_reassign_with_default)]
async fn to_datafusion_filter() {
let op_scan = create_scan();
let op = create_filter(op_scan, 0);
let planner = PhysicalPlanner::default();
let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap();
let scan = &mut scans[0];
scan.set_input_batch(InputBatch::EOF);
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let stream = datafusion_plan
.native_plan
.execute(0, Arc::clone(&task_ctx))
.unwrap();
let output = collect(stream).await.unwrap();
assert!(output.is_empty());
}
#[tokio::test()]
async fn from_datafusion_error_to_comet() {
let err_msg = "exec error";
let err = datafusion::common::DataFusionError::Execution(err_msg.to_string());
let comet_err: ExecutionError = err.into();
assert_eq!(comet_err.to_string(), "Error from DataFusion: exec error.");
}
// Creates a filter operator which takes an `Int32Array` and selects rows that are equal to
// `value`.
fn create_filter(child_op: spark_operator::Operator, value: i32) -> spark_operator::Operator {
let lit = spark_expression::Literal {
value: Some(literal::Value::IntVal(value)),
datatype: Some(spark_expression::DataType {
type_id: 3,
type_info: None,
}),
is_null: false,
};
create_filter_literal(child_op, 3, lit)
}
fn create_filter_literal(
child_op: spark_operator::Operator,
type_id: i32,
lit: spark_expression::Literal,
) -> spark_operator::Operator {
let left = spark_expression::Expr {
expr_struct: Some(Bound(spark_expression::BoundReference {
index: 0,
datatype: Some(spark_expression::DataType {
type_id,
type_info: None,
}),
})),
};
let right = spark_expression::Expr {
expr_struct: Some(Literal(lit)),
};
let expr = spark_expression::Expr {
expr_struct: Some(Eq(Box::new(spark_expression::BinaryExpr {
left: Some(Box::new(left)),
right: Some(Box::new(right)),
}))),
};
Operator {
plan_id: 0,
children: vec![child_op],
op_struct: Some(OpStruct::Filter(spark_operator::Filter {
predicate: Some(expr),
use_datafusion_filter: false,
})),
}
}
#[test]
fn spark_plan_metrics_filter() {
let op_scan = create_scan();
let op = create_filter(op_scan, 0);
let planner = PhysicalPlanner::default();
let (_scans, filter_exec) = planner.create_plan(&op, &mut vec![], 1).unwrap();
assert_eq!("CometFilterExec", filter_exec.native_plan.name());
assert_eq!(1, filter_exec.children.len());
assert_eq!(0, filter_exec.additional_native_plans.len());
}
#[test]
fn spark_plan_metrics_hash_join() {
let op_scan = create_scan();
let op_join = Operator {
plan_id: 0,
children: vec![op_scan.clone(), op_scan.clone()],
op_struct: Some(OpStruct::HashJoin(spark_operator::HashJoin {
left_join_keys: vec![create_bound_reference(0)],
right_join_keys: vec![create_bound_reference(0)],
join_type: 0,
condition: None,
build_side: 0,
})),
};
let planner = PhysicalPlanner::default();
let (_scans, hash_join_exec) = planner.create_plan(&op_join, &mut vec![], 1).unwrap();
assert_eq!("HashJoinExec", hash_join_exec.native_plan.name());
assert_eq!(2, hash_join_exec.children.len());
assert_eq!("ScanExec", hash_join_exec.children[0].native_plan.name());
assert_eq!("ScanExec", hash_join_exec.children[1].native_plan.name());
}
    /// Builds a bound-reference expression for Int32 column `index`.
    fn create_bound_reference(index: i32) -> Expr {
        Expr {
            expr_struct: Some(Bound(spark_expression::BoundReference {
                index,
                datatype: Some(create_proto_datatype()),
            })),
        }
    }
    /// Builds a leaf scan operator with a single Int32 column.
    fn create_scan() -> Operator {
        Operator {
            plan_id: 0,
            children: vec![],
            op_struct: Some(OpStruct::Scan(spark_operator::Scan {
                fields: vec![create_proto_datatype()],
                source: "".to_string(),
            })),
        }
    }
    /// Protobuf datatype for Int32 (type_id 3).
    fn create_proto_datatype() -> spark_expression::DataType {
        spark_expression::DataType {
            type_id: 3,
            type_info: None,
        }
    }
#[test]
fn test_create_array() {
let session_ctx = SessionContext::new();
session_ctx.register_udf(ScalarUDF::from(
datafusion_functions_nested::make_array::MakeArray::new(),
));
let task_ctx = session_ctx.task_ctx();
let planner = PhysicalPlanner::new(Arc::from(session_ctx));
// Create a plan for
// ProjectionExec: expr=[make_array(col_0@0) as col_0]
// ScanExec: source=[CometScan parquet (unknown)], schema=[col_0: Int32]
let op_scan = Operator {
plan_id: 0,
children: vec![],
op_struct: Some(OpStruct::Scan(spark_operator::Scan {
fields: vec![
spark_expression::DataType {
type_id: 3, // Int32
type_info: None,
},
spark_expression::DataType {
type_id: 3, // Int32
type_info: None,
},
spark_expression::DataType {
type_id: 3, // Int32
type_info: None,
},
],
source: "".to_string(),
})),
};
let array_col = spark_expression::Expr {
expr_struct: Some(Bound(spark_expression::BoundReference {
index: 0,
datatype: Some(spark_expression::DataType {
type_id: 3,
type_info: None,
}),
})),
};
let array_col_1 = spark_expression::Expr {
expr_struct: Some(Bound(spark_expression::BoundReference {
index: 1,
datatype: Some(spark_expression::DataType {
type_id: 3,
type_info: None,
}),
})),
};
let projection = Operator {
children: vec![op_scan],
plan_id: 0,
op_struct: Some(OpStruct::Projection(spark_operator::Projection {
project_list: vec![spark_expression::Expr {
expr_struct: Some(ExprStruct::ScalarFunc(spark_expression::ScalarFunc {
func: "make_array".to_string(),
args: vec![array_col, array_col_1],
return_type: None,
})),
}],
})),
};
let a = Int32Array::from(vec![0, 3]);
let b = Int32Array::from(vec![1, 4]);
let c = Int32Array::from(vec![2, 5]);
let input_batch = InputBatch::Batch(vec![Arc::new(a), Arc::new(b), Arc::new(c)], 2);
let (mut scans, datafusion_plan) =
planner.create_plan(&projection, &mut vec![], 1).unwrap();
scans[0].set_input_batch(input_batch);
let mut stream = datafusion_plan.native_plan.execute(0, task_ctx).unwrap();
let runtime = tokio::runtime::Runtime::new().unwrap();
let (tx, mut rx) = mpsc::channel(1);
// Separate thread to send the EOF signal once we've processed the only input batch
runtime.spawn(async move {
let a = Int32Array::from(vec![0, 3]);
let b = Int32Array::from(vec![1, 4]);
let c = Int32Array::from(vec![2, 5]);
let input_batch1 = InputBatch::Batch(vec![Arc::new(a), Arc::new(b), Arc::new(c)], 2);
let input_batch2 = InputBatch::EOF;
let batches = vec![input_batch1, input_batch2];
for batch in batches.into_iter() {
tx.send(batch).await.unwrap();
}
});
runtime.block_on(async move {
loop {
let batch = rx.recv().await.unwrap();
scans[0].set_input_batch(batch);
match poll!(stream.next()) {
Poll::Ready(Some(batch)) => {
assert!(batch.is_ok(), "got error {}", batch.unwrap_err());
let batch = batch.unwrap();
assert_eq!(batch.num_rows(), 2);
let expected = [
"+--------+",
"| col_0 |",
"+--------+",
"| [0, 1] |",
"| [3, 4] |",
"+--------+",
];
assert_batches_eq!(expected, &[batch]);
}
Poll::Ready(None) => {
break;
}
_ => {}
}
}
});
}
#[test]
fn test_array_repeat() {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let planner = PhysicalPlanner::new(Arc::from(session_ctx));
// Mock scan operator with 3 INT32 columns
let op_scan = Operator {
plan_id: 0,
children: vec![],
op_struct: Some(OpStruct::Scan(spark_operator::Scan {
fields: vec![
spark_expression::DataType {
type_id: 3, // Int32
type_info: None,
},
spark_expression::DataType {
type_id: 3, // Int32
type_info: None,
},
spark_expression::DataType {
type_id: 3, // Int32
type_info: None,
},
],
source: "".to_string(),
})),
};
// Mock expression to read a INT32 column with position 0
let array_col = spark_expression::Expr {
expr_struct: Some(Bound(spark_expression::BoundReference {
index: 0,
datatype: Some(spark_expression::DataType {
type_id: 3,
type_info: None,
}),
})),
};
// Mock expression to read a INT32 column with position 1
let array_col_1 = spark_expression::Expr {
expr_struct: Some(Bound(spark_expression::BoundReference {
index: 1,
datatype: Some(spark_expression::DataType {
type_id: 3,
type_info: None,
}),
})),
};
// Make a projection operator with array_repeat(array_col, array_col_1)
let projection = Operator {
children: vec![op_scan],
plan_id: 0,
op_struct: Some(OpStruct::Projection(spark_operator::Projection {
project_list: vec![spark_expression::Expr {
expr_struct: Some(ExprStruct::ScalarFunc(spark_expression::ScalarFunc {
func: "array_repeat".to_string(),
args: vec![array_col, array_col_1],
return_type: None,
})),
}],
})),
};
// Create a physical plan
let (mut scans, datafusion_plan) =
planner.create_plan(&projection, &mut vec![], 1).unwrap();
// Feed the data into plan
//scans[0].set_input_batch(input_batch);
// Start executing the plan in a separate thread
// The plan waits for incoming batches and emitting result as input comes
let mut stream = datafusion_plan.native_plan.execute(0, task_ctx).unwrap();
let runtime = tokio::runtime::Runtime::new().unwrap();
// create async channel
let (tx, mut rx) = mpsc::channel(1);
// Send data as input to the plan being executed in a separate thread
runtime.spawn(async move {
// create data batch
// 0, 1, 2
// 3, 4, 5
// 6, null, null
let a = Int32Array::from(vec![Some(0), Some(3), Some(6)]);
let b = Int32Array::from(vec![Some(1), Some(4), None]);
let c = Int32Array::from(vec![Some(2), Some(5), None]);
let input_batch1 = InputBatch::Batch(vec![Arc::new(a), Arc::new(b), Arc::new(c)], 3);
let input_batch2 = InputBatch::EOF;
let batches = vec![input_batch1, input_batch2];
for batch in batches.into_iter() {
tx.send(batch).await.unwrap();
}
});
// Wait for the plan to finish executing and assert the result
runtime.block_on(async move {
loop {
let batch = rx.recv().await.unwrap();
scans[0].set_input_batch(batch);
match poll!(stream.next()) {
Poll::Ready(Some(batch)) => {
assert!(batch.is_ok(), "got error {}", batch.unwrap_err());
let batch = batch.unwrap();
let expected = [
"+--------------+",
"| col_0 |",
"+--------------+",
"| [0] |",
"| [3, 3, 3, 3] |",
"| |",
"+--------------+",
];
assert_batches_eq!(expected, &[batch]);
}
Poll::Ready(None) => {
break;
}
_ => {}
}
}
});
}
}