datafusion/substrait/src/logical_plan/producer.rs (1,548 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::collections::HashMap;
use std::ops::Deref;
use std::sync::Arc;
use datafusion::logical_expr::Like;
use datafusion::{
arrow::datatypes::{DataType, TimeUnit},
error::{DataFusionError, Result},
logical_expr::{WindowFrame, WindowFrameBound},
prelude::{JoinType, SessionContext},
scalar::ScalarValue,
};
use datafusion::common::DFSchemaRef;
#[allow(unused_imports)]
use datafusion::logical_expr::aggregate_function;
use datafusion::logical_expr::expr::{
Alias, BinaryExpr, Case, Cast, InList, ScalarFunction as DFScalarFunction, Sort,
WindowFunction,
};
use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator};
use datafusion::prelude::Expr;
use prost_types::Any as ProtoAny;
use substrait::{
proto::{
aggregate_function::AggregationInvocation,
aggregate_rel::{Grouping, Measure},
expression::{
field_reference::ReferenceType,
if_then::IfClause,
literal::{Decimal, LiteralType},
mask_expression::{StructItem, StructSelect},
reference_segment,
window_function::bound as SubstraitBound,
window_function::bound::Kind as BoundKind,
window_function::Bound,
FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, RexType,
ScalarFunction, SingularOrList, WindowFunction as SubstraitWindowFunction,
},
extensions::{
self,
simple_extension_declaration::{ExtensionFunction, MappingType},
},
function_argument::ArgType,
join_rel, plan_rel, r#type,
read_rel::{NamedTable, ReadType},
rel::RelType,
set_rel,
sort_field::{SortDirection, SortKind},
AggregateFunction, AggregateRel, AggregationPhase, Expression, ExtensionLeafRel,
ExtensionMultiRel, ExtensionSingleRel, FetchRel, FilterRel, FunctionArgument,
JoinRel, NamedStruct, Plan, PlanRel, ProjectRel, ReadRel, Rel, RelRoot, SetRel,
SortField, SortRel,
},
version,
};
use crate::variation_const::{
DATE_32_TYPE_REF, DATE_64_TYPE_REF, DECIMAL_128_TYPE_REF, DECIMAL_256_TYPE_REF,
DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, LARGE_CONTAINER_TYPE_REF,
TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF,
TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF,
};
/// Convert a DataFusion [`LogicalPlan`] into a Substrait [`Plan`].
///
/// The plan is emitted as a single root relation whose output names are the
/// field names of `plan`'s schema. Every function registered while the
/// relation tree is translated is attached to the result as a simple
/// extension declaration.
///
/// # Errors
///
/// Propagates any error returned by [`to_substrait_rel`].
pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result<Box<Plan>> {
    // Extension declarations plus the function-name -> anchor lookup; both are
    // filled in while the relation tree is walked.
    let mut extension_info: (
        Vec<extensions::SimpleExtensionDeclaration>,
        HashMap<String, u32>,
    ) = (Vec::new(), HashMap::new());

    // Note: only one relation tree is currently supported.
    let root_input = to_substrait_rel(plan, ctx, &mut extension_info)?;
    let root = RelRoot {
        input: Some(*root_input),
        names: plan.schema().field_names(),
    };
    let relations = vec![PlanRel {
        rel_type: Some(plan_rel::RelType::Root(root)),
    }];

    // Only the declarations are part of the plan; the lookup table is dropped.
    let declarations = extension_info.0;

    Ok(Box::new(Plan {
        version: Some(version::version_with_producer("datafusion")),
        extension_uris: vec![],
        extensions: declarations,
        relations,
        advanced_extensions: None,
        expected_type_urls: vec![],
    }))
}
/// Convert a DataFusion [`LogicalPlan`] node, and recursively its inputs, to a Substrait [`Rel`].
///
/// # Arguments
///
/// * `plan` - logical plan node to convert
/// * `ctx` - session context; used to serialize `LogicalPlan::Extension` nodes
/// * `extension_info` - extension declarations and the function-name to anchor
///   lookup; updated as functions are encountered in expressions
///
/// # Errors
///
/// Returns `DataFusionError::NotImplemented` for operators that have no
/// translation here, for joins with a `USING` constraint, and for right semi /
/// right anti joins. Errors from expression conversion and from extension
/// serialization are propagated.
pub fn to_substrait_rel(
    plan: &LogicalPlan,
    ctx: &SessionContext,
    extension_info: &mut (
        Vec<extensions::SimpleExtensionDeclaration>,
        HashMap<String, u32>,
    ),
) -> Result<Box<Rel>> {
    match plan {
        LogicalPlan::TableScan(scan) => {
            // An optional column projection becomes a mask expression that
            // selects the projected field indices.
            let projection = scan.projection.as_ref().map(|p| {
                p.iter()
                    .map(|i| StructItem {
                        field: *i as i32,
                        child: None,
                    })
                    .collect()
            });
            let projection = projection.map(|struct_items| MaskExpression {
                select: Some(StructSelect { struct_items }),
                maintain_singular_struct: false,
            });
            Ok(Box::new(Rel {
                rel_type: Some(RelType::Read(Box::new(ReadRel {
                    common: None,
                    base_schema: Some(NamedStruct {
                        names: scan
                            .source
                            .schema()
                            .fields()
                            .iter()
                            .map(|f| f.name().to_owned())
                            .collect(),
                        r#struct: None,
                    }),
                    filter: None,
                    best_effort_filter: None,
                    projection,
                    advanced_extension: None,
                    read_type: Some(ReadType::NamedTable(NamedTable {
                        names: scan.table_name.to_vec(),
                        advanced_extension: None,
                    })),
                }))),
            }))
        }
        LogicalPlan::Projection(p) => {
            let expressions = p
                .expr
                .iter()
                .map(|e| to_substrait_rex(e, p.input.schema(), 0, extension_info))
                .collect::<Result<Vec<_>>>()?;
            Ok(Box::new(Rel {
                rel_type: Some(RelType::Project(Box::new(ProjectRel {
                    common: None,
                    input: Some(to_substrait_rel(p.input.as_ref(), ctx, extension_info)?),
                    expressions,
                    advanced_extension: None,
                }))),
            }))
        }
        LogicalPlan::Filter(filter) => {
            let input = to_substrait_rel(filter.input.as_ref(), ctx, extension_info)?;
            let filter_expr = to_substrait_rex(
                &filter.predicate,
                filter.input.schema(),
                0,
                extension_info,
            )?;
            Ok(Box::new(Rel {
                rel_type: Some(RelType::Filter(Box::new(FilterRel {
                    common: None,
                    input: Some(input),
                    condition: Some(Box::new(filter_expr)),
                    advanced_extension: None,
                }))),
            }))
        }
        LogicalPlan::Limit(limit) => {
            let input = to_substrait_rel(limit.input.as_ref(), ctx, extension_info)?;
            // `fetch: None` means "no limit". Protobuf cannot distinguish an
            // absent count from `0`, and a count of `0` would make the relation
            // return no rows at all, so an absent limit is encoded as `-1`
            // (all records) instead.
            let count = limit.fetch.map_or(-1, |fetch| fetch as i64);
            Ok(Box::new(Rel {
                rel_type: Some(RelType::Fetch(Box::new(FetchRel {
                    common: None,
                    input: Some(input),
                    offset: limit.skip as i64,
                    count,
                    advanced_extension: None,
                }))),
            }))
        }
        LogicalPlan::Sort(sort) => {
            let input = to_substrait_rel(sort.input.as_ref(), ctx, extension_info)?;
            let sort_fields = sort
                .expr
                .iter()
                .map(|e| substrait_sort_field(e, sort.input.schema(), extension_info))
                .collect::<Result<Vec<_>>>()?;
            Ok(Box::new(Rel {
                rel_type: Some(RelType::Sort(Box::new(SortRel {
                    common: None,
                    input: Some(input),
                    sorts: sort_fields,
                    advanced_extension: None,
                }))),
            }))
        }
        LogicalPlan::Aggregate(agg) => {
            let input = to_substrait_rel(agg.input.as_ref(), ctx, extension_info)?;
            // Translate aggregate expression to Substrait's groupings (repeated repeated Expression)
            let grouping = agg
                .group_expr
                .iter()
                .map(|e| to_substrait_rex(e, agg.input.schema(), 0, extension_info))
                .collect::<Result<Vec<_>>>()?;
            let measures = agg
                .aggr_expr
                .iter()
                .map(|e| to_substrait_agg_measure(e, agg.input.schema(), extension_info))
                .collect::<Result<Vec<_>>>()?;
            Ok(Box::new(Rel {
                rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
                    common: None,
                    input: Some(input),
                    groupings: vec![Grouping {
                        grouping_expressions: grouping,
                    }],
                    measures,
                    advanced_extension: None,
                }))),
            }))
        }
        LogicalPlan::Distinct(distinct) => {
            // Use Substrait's AggregateRel with empty measures to represent `select distinct`
            let input = to_substrait_rel(distinct.input.as_ref(), ctx, extension_info)?;
            // Group by every output field of the input relation
            let grouping = (0..distinct.input.schema().fields().len())
                .map(substrait_field_ref)
                .collect::<Result<Vec<_>>>()?;
            Ok(Box::new(Rel {
                rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
                    common: None,
                    input: Some(input),
                    groupings: vec![Grouping {
                        grouping_expressions: grouping,
                    }],
                    measures: vec![],
                    advanced_extension: None,
                }))),
            }))
        }
        LogicalPlan::Join(join) => {
            // Right semi / right anti joins have no translation. Report that as
            // an error here instead of reaching the `unimplemented!()` panic in
            // `to_substrait_jointype`.
            if matches!(join.join_type, JoinType::RightAnti | JoinType::RightSemi) {
                return Err(DataFusionError::NotImplemented(format!(
                    "join type: `{:?}`",
                    join.join_type
                )));
            }
            let left = to_substrait_rel(join.left.as_ref(), ctx, extension_info)?;
            let right = to_substrait_rel(join.right.as_ref(), ctx, extension_info)?;
            let join_type = to_substrait_jointype(join.join_type);
            // we only support basic joins so return an error for anything not yet supported
            match join.join_constraint {
                JoinConstraint::On => {}
                JoinConstraint::Using => {
                    return Err(DataFusionError::NotImplemented(
                        "join constraint: `using`".to_string(),
                    ))
                }
            }
            // parse filter if exists; it is resolved against the combined
            // (left ++ right) schema
            let in_join_schema = join.left.schema().join(join.right.schema())?;
            let join_filter = match &join.filter {
                Some(filter) => Some(Box::new(to_substrait_rex(
                    filter,
                    &Arc::new(in_join_schema),
                    0,
                    extension_info,
                )?)),
                None => None,
            };
            // map the left and right columns to binary expressions in the form `l = r`
            // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b`
            let eq_op = if join.null_equals_null {
                Operator::IsNotDistinctFrom
            } else {
                Operator::Eq
            };
            let join_expr = to_substrait_join_expr(
                &join.on,
                eq_op,
                join.left.schema(),
                join.right.schema(),
                extension_info,
            )?
            .map(Box::new);
            Ok(Box::new(Rel {
                rel_type: Some(RelType::Join(Box::new(JoinRel {
                    common: None,
                    left: Some(left),
                    right: Some(right),
                    r#type: join_type as i32,
                    expression: join_expr,
                    post_join_filter: join_filter,
                    advanced_extension: None,
                }))),
            }))
        }
        LogicalPlan::SubqueryAlias(alias) => {
            // Do nothing if encounters SubqueryAlias
            // since there is no corresponding relation type in Substrait
            to_substrait_rel(alias.input.as_ref(), ctx, extension_info)
        }
        LogicalPlan::Union(union) => {
            let input_rels = union
                .inputs
                .iter()
                .map(|input| to_substrait_rel(input.as_ref(), ctx, extension_info))
                .collect::<Result<Vec<_>>>()?
                .into_iter()
                .map(|ptr| *ptr)
                .collect();
            Ok(Box::new(Rel {
                rel_type: Some(substrait::proto::rel::RelType::Set(SetRel {
                    common: None,
                    inputs: input_rels,
                    op: set_rel::SetOp::UnionAll as i32, // UNION DISTINCT gets translated to AGGREGATION + UNION ALL
                    advanced_extension: None,
                })),
            }))
        }
        LogicalPlan::Window(window) => {
            let input = to_substrait_rel(window.input.as_ref(), ctx, extension_info)?;
            // If the input is a Project relation, we can just append the WindowFunction expressions
            // before returning
            // Otherwise, wrap the input in a Project relation before appending the WindowFunction
            // expressions
            let mut project_rel: Box<ProjectRel> = match &input.as_ref().rel_type {
                Some(RelType::Project(p)) => Box::new(*p.clone()),
                _ => {
                    // Create Projection with field referencing all output fields in the input relation
                    let expressions = (0..window.input.schema().fields().len())
                        .map(substrait_field_ref)
                        .collect::<Result<Vec<_>>>()?;
                    Box::new(ProjectRel {
                        common: None,
                        input: Some(input),
                        expressions,
                        advanced_extension: None,
                    })
                }
            };
            // Parse WindowFunction expression
            let mut window_exprs = vec![];
            for expr in &window.window_expr {
                window_exprs.push(to_substrait_rex(
                    expr,
                    window.input.schema(),
                    0,
                    extension_info,
                )?);
            }
            // Append parsed WindowFunction expressions
            project_rel.expressions.extend(window_exprs);
            Ok(Box::new(Rel {
                rel_type: Some(RelType::Project(project_rel)),
            }))
        }
        LogicalPlan::Extension(extension_plan) => {
            // The node itself is serialized by the session's serializer
            // registry and carried as an opaque `Any` payload.
            let extension_bytes = ctx
                .state()
                .serializer_registry()
                .serialize_logical_plan(extension_plan.node.as_ref())?;
            let detail = ProtoAny {
                type_url: extension_plan.node.name().to_string(),
                value: extension_bytes,
            };
            let mut inputs_rel = extension_plan
                .node
                .inputs()
                .into_iter()
                .map(|plan| to_substrait_rel(plan, ctx, extension_info))
                .collect::<Result<Vec<_>>>()?;
            // The extension relation kind is chosen by the number of inputs.
            let rel_type = match inputs_rel.len() {
                0 => RelType::ExtensionLeaf(ExtensionLeafRel {
                    common: None,
                    detail: Some(detail),
                }),
                1 => RelType::ExtensionSingle(Box::new(ExtensionSingleRel {
                    common: None,
                    detail: Some(detail),
                    // Cannot fail: this arm is only taken when exactly one input exists.
                    input: Some(inputs_rel.pop().unwrap()),
                })),
                _ => RelType::ExtensionMulti(ExtensionMultiRel {
                    common: None,
                    detail: Some(detail),
                    inputs: inputs_rel.into_iter().map(|r| *r).collect(),
                }),
            };
            Ok(Box::new(Rel {
                rel_type: Some(rel_type),
            }))
        }
        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported operator: {plan:?}"
        ))),
    }
}
/// Build the Substrait join condition for a list of join key pairs.
///
/// Each `(left, right)` pair becomes the binary expression `left <eq_op> right`;
/// the per-pair expressions are then combined, left to right, with `AND`.
/// Left-hand expressions are resolved against `left_schema`. Right-hand
/// expressions are resolved against `right_schema` and their field indices are
/// shifted by the number of fields in `left_schema`.
///
/// Takes the pairs as a slice, so callers holding a `Vec` can keep passing
/// `&vec`.
///
/// Returns `Ok(None)` when `join_conditions` is empty.
///
/// # Errors
///
/// Propagates any error from converting either side of a pair.
fn to_substrait_join_expr(
    join_conditions: &[(Expr, Expr)],
    eq_op: Operator,
    left_schema: &DFSchemaRef,
    right_schema: &DFSchemaRef,
    extension_info: &mut (
        Vec<extensions::SimpleExtensionDeclaration>,
        HashMap<String, u32>,
    ),
) -> Result<Option<Expression>> {
    // Only support AND conjunction for each binary expression in join conditions
    let mut exprs: Vec<Expression> = Vec::with_capacity(join_conditions.len());
    for (left, right) in join_conditions {
        let l = to_substrait_rex(left, left_schema, 0, extension_info)?;
        let r = to_substrait_rex(
            right,
            right_schema,
            left_schema.fields().len(), // offset to return the correct index
            extension_info,
        )?;
        exprs.push(make_binary_op_scalar_func(&l, &r, eq_op, extension_info));
    }
    // AND the per-pair comparisons together; `reduce` yields `None` for no pairs.
    let join_expr: Option<Expression> =
        exprs.into_iter().reduce(|acc: Expression, e: Expression| {
            make_binary_op_scalar_func(&acc, &e, Operator::And, extension_info)
        });
    Ok(join_expr)
}
/// Map a DataFusion [`JoinType`] to the corresponding Substrait join type.
///
/// # Panics
///
/// Panics (via `unimplemented!`) for `RightAnti` and `RightSemi`, which have
/// no mapping here.
fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType {
    match join_type {
        // No mapping is provided for the right-sided semi/anti variants.
        JoinType::RightAnti | JoinType::RightSemi => unimplemented!(),
        JoinType::LeftSemi => join_rel::JoinType::Semi,
        JoinType::LeftAnti => join_rel::JoinType::Anti,
        JoinType::Full => join_rel::JoinType::Outer,
        JoinType::Right => join_rel::JoinType::Right,
        JoinType::Left => join_rel::JoinType::Left,
        JoinType::Inner => join_rel::JoinType::Inner,
    }
}
pub fn operator_to_name(op: Operator) -> &'static str {
match op {
Operator::Eq => "equal",
Operator::NotEq => "not_equal",
Operator::Lt => "lt",
Operator::LtEq => "lte",
Operator::Gt => "gt",
Operator::GtEq => "gte",
Operator::Plus => "add",
Operator::Minus => "subtract",
Operator::Multiply => "multiply",
Operator::Divide => "divide",
Operator::Modulo => "mod",
Operator::And => "and",
Operator::Or => "or",
Operator::IsDistinctFrom => "is_distinct_from",
Operator::IsNotDistinctFrom => "is_not_distinct_from",
Operator::RegexMatch => "regex_match",
Operator::RegexIMatch => "regex_imatch",
Operator::RegexNotMatch => "regex_not_match",
Operator::RegexNotIMatch => "regex_not_imatch",
Operator::BitwiseAnd => "bitwise_and",
Operator::BitwiseOr => "bitwise_or",
Operator::StringConcat => "str_concat",
Operator::AtArrow => "at_arrow",
Operator::ArrowAt => "arrow_at",
Operator::BitwiseXor => "bitwise_xor",
Operator::BitwiseShiftRight => "bitwise_shift_right",
Operator::BitwiseShiftLeft => "bitwise_shift_left",
}
}
/// Convert a DataFusion aggregate expression into a Substrait [`Measure`].
///
/// Aliases are looked through. For an aggregate function call the `ORDER BY`
/// keys, the arguments, the function itself and finally the optional `FILTER`
/// predicate are translated in that order.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` when `expr` is neither an aggregate
/// function nor an alias of one; conversion errors of nested expressions are
/// propagated.
#[allow(deprecated)]
pub fn to_substrait_agg_measure(
    expr: &Expr,
    schema: &DFSchemaRef,
    extension_info: &mut (
        Vec<extensions::SimpleExtensionDeclaration>,
        HashMap<String, u32>,
    ),
) -> Result<Measure> {
    match expr {
        // An alias adds nothing to the measure; translate what it wraps.
        Expr::Alias(Alias { expr: inner, .. }) => {
            to_substrait_agg_measure(inner, schema, extension_info)
        }
        Expr::AggregateFunction(expr::AggregateFunction {
            fun,
            args,
            distinct,
            filter,
            order_by,
        }) => {
            let sorts = match order_by {
                Some(sort_exprs) => sort_exprs
                    .iter()
                    .map(|e| to_substrait_sort_field(e, schema, extension_info))
                    .collect::<Result<Vec<_>>>()?,
                None => Vec::new(),
            };
            let arguments = args
                .iter()
                .map(|arg| {
                    to_substrait_rex(arg, schema, 0, extension_info).map(|value| {
                        FunctionArgument {
                            arg_type: Some(ArgType::Value(value)),
                        }
                    })
                })
                .collect::<Result<Vec<_>>>()?;
            let function_anchor =
                _register_function(fun.to_string().to_lowercase(), extension_info);
            let invocation = if *distinct {
                AggregationInvocation::Distinct
            } else {
                AggregationInvocation::All
            };
            let measure = AggregateFunction {
                function_reference: function_anchor,
                arguments,
                sorts,
                output_type: None,
                invocation: invocation as i32,
                phase: AggregationPhase::Unspecified as i32,
                args: vec![],
                options: vec![],
            };
            // The filter is translated last so that functions it uses are
            // registered after the aggregate function itself.
            let measure_filter = filter
                .as_ref()
                .map(|predicate| to_substrait_rex(predicate, schema, 0, extension_info))
                .transpose()?;
            Ok(Measure {
                measure: Some(measure),
                filter: measure_filter,
            })
        }
        _ => Err(DataFusionError::Internal(format!(
            "Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}",
            expr,
            expr.variant_name()
        ))),
    }
}
/// Converts a DataFusion sort expression to the corresponding Substrait `SortField`.
///
/// The sort key is resolved against `schema` without any column offset, and
/// the ascending / nulls-first flags select the Substrait sort direction.
///
/// # Errors
///
/// Returns `DataFusionError::Execution` if `expr` is not an `Expr::Sort`;
/// errors from converting the sort key are propagated.
fn to_substrait_sort_field(
    expr: &Expr,
    schema: &DFSchemaRef,
    extension_info: &mut (
        Vec<extensions::SimpleExtensionDeclaration>,
        HashMap<String, u32>,
    ),
) -> Result<SortField> {
    let sort = match expr {
        Expr::Sort(sort) => sort,
        _ => {
            return Err(DataFusionError::Execution(
                "expects to receive sort expression".to_string(),
            ))
        }
    };
    let direction = if sort.asc {
        if sort.nulls_first {
            SortDirection::AscNullsFirst
        } else {
            SortDirection::AscNullsLast
        }
    } else if sort.nulls_first {
        SortDirection::DescNullsFirst
    } else {
        SortDirection::DescNullsLast
    };
    let key = to_substrait_rex(sort.expr.deref(), schema, 0, extension_info)?;
    Ok(SortField {
        expr: Some(key),
        sort_kind: Some(SortKind::Direction(direction as i32)),
    })
}
/// Return the anchor for `function_name`, registering the function first if
/// it has not been seen yet.
///
/// Names are lower-cased before lookup. A new function receives the next
/// anchor (the number of functions registered so far) and a matching
/// extension declaration is appended to the declaration list.
fn _register_function(
    function_name: String,
    extension_info: &mut (
        Vec<extensions::SimpleExtensionDeclaration>,
        HashMap<String, u32>,
    ),
) -> u32 {
    let name = function_name.to_lowercase();
    let (declarations, anchors) = extension_info;

    // To prevent ambiguous references between ScalarFunctions and AggregateFunctions,
    // a plan-relative identifier starting from 0 is used as the function_anchor.
    // The consumer is responsible for correctly registering <function_anchor,function_name>
    // mapping info stored in the extensions by the producer.
    if let Some(existing) = anchors.get(&name) {
        // Already registered: reuse its anchor.
        return *existing;
    }

    // First occurrence: allocate the next anchor and declare the function.
    let anchor = anchors.len() as u32;
    anchors.insert(name.clone(), anchor);
    declarations.push(extensions::SimpleExtensionDeclaration {
        mapping_type: Some(MappingType::ExtensionFunction(ExtensionFunction {
            extension_uri_reference: u32::MAX,
            function_anchor: anchor,
            name,
        })),
    });
    anchor
}
/// Build a Substrait scalar function call applying the binary operator `op`
/// to `lhs` and `rhs` (both operands are cloned into the call).
///
/// The operator's function name is registered in `extension_info` if needed.
#[allow(deprecated)]
pub fn make_binary_op_scalar_func(
    lhs: &Expression,
    rhs: &Expression,
    op: Operator,
    extension_info: &mut (
        Vec<extensions::SimpleExtensionDeclaration>,
        HashMap<String, u32>,
    ),
) -> Expression {
    let function_reference = _register_function(
        operator_to_name(op).to_string().to_lowercase(),
        extension_info,
    );
    // Wrap an operand as a value argument of the function call.
    let as_argument = |operand: &Expression| FunctionArgument {
        arg_type: Some(ArgType::Value(operand.clone())),
    };
    let call = ScalarFunction {
        function_reference,
        arguments: vec![as_argument(lhs), as_argument(rhs)],
        output_type: None,
        args: vec![],
        options: vec![],
    };
    Expression {
        rex_type: Some(RexType::ScalarFunction(call)),
    }
}
/// Convert DataFusion Expr to Substrait Rex
///
/// # Arguments
///
/// * `expr` - DataFusion expression to be parsed into a Substrait expression
/// * `schema` - DataFusion input schema used to look up the index of each referenced column
/// * `col_ref_offset` - Offset for calculating Substrait field reference indices.
///   This should only be set by callers with more than one input relation, i.e. Join.
///   Substrait expects one set of indices when joining two relations.
///   Let's say `left` and `right` have `m` and `n` columns, respectively. The `right`
///   relation will have column indices from `0` to `n-1`, however, Substrait will expect
///   the `right` indices to be offset by the number of `left` columns. This means
///   Substrait will expect to evaluate the join condition expression on indices
///   [0 .. m-1, m .. m+n-1]. For example:
///   ```text
///   SELECT *
///   FROM t1
///   JOIN t2
///   ON t1.c1 = t2.c0;
///   ```
///   where t1 consists of columns [c0, c1, c2], and t2 = columns [c0, c1]
///   the join condition should become
///   `col_ref(1) = col_ref(3 + 0)`
///   , where `3` is the number of `left` columns (`col_ref_offset`) and `0` is the index
///   of the join key column from `right`
/// * `extension_info` - Substrait extension info. Contains registered function information
///
/// # Errors
///
/// Returns `DataFusionError::NotImplemented` for expression variants that are not
/// handled below; errors from nested conversions are propagated.
#[allow(deprecated)]
pub fn to_substrait_rex(
expr: &Expr,
schema: &DFSchemaRef,
col_ref_offset: usize,
extension_info: &mut (
Vec<extensions::SimpleExtensionDeclaration>,
HashMap<String, u32>,
),
) -> Result<Expression> {
match expr {
// `expr [NOT] IN (list)` becomes a SingularOrList, wrapped in a `not`
// function call when negated
Expr::InList(InList {
expr,
list,
negated,
}) => {
let substrait_list = list
.iter()
.map(|x| to_substrait_rex(x, schema, col_ref_offset, extension_info))
.collect::<Result<Vec<Expression>>>()?;
let substrait_expr =
to_substrait_rex(expr, schema, col_ref_offset, extension_info)?;
let substrait_or_list = Expression {
rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList {
value: Some(Box::new(substrait_expr)),
options: substrait_list,
}))),
};
if *negated {
let function_anchor =
_register_function("not".to_string(), extension_info);
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
arguments: vec![FunctionArgument {
arg_type: Some(ArgType::Value(substrait_or_list)),
}],
output_type: None,
args: vec![],
options: vec![],
})),
})
} else {
Ok(substrait_or_list)
}
}
// Scalar function call: arguments are converted first, then the function
// is registered under its lower-cased name
Expr::ScalarFunction(DFScalarFunction { fun, args }) => {
let mut arguments: Vec<FunctionArgument> = vec![];
for arg in args {
arguments.push(FunctionArgument {
arg_type: Some(ArgType::Value(to_substrait_rex(
arg,
schema,
col_ref_offset,
extension_info,
)?)),
});
}
let function_name = fun.to_string().to_lowercase();
let function_anchor = _register_function(function_name, extension_info);
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
arguments,
output_type: None,
args: vec![],
options: vec![],
})),
})
}
// BETWEEN is expanded into two comparisons joined by AND (or by OR when negated)
Expr::Between(Between {
expr,
negated,
low,
high,
}) => {
if *negated {
// `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr)
let substrait_expr =
to_substrait_rex(expr, schema, col_ref_offset, extension_info)?;
let substrait_low =
to_substrait_rex(low, schema, col_ref_offset, extension_info)?;
let substrait_high =
to_substrait_rex(high, schema, col_ref_offset, extension_info)?;
let l_expr = make_binary_op_scalar_func(
&substrait_expr,
&substrait_low,
Operator::Lt,
extension_info,
);
let r_expr = make_binary_op_scalar_func(
&substrait_high,
&substrait_expr,
Operator::Lt,
extension_info,
);
Ok(make_binary_op_scalar_func(
&l_expr,
&r_expr,
Operator::Or,
extension_info,
))
} else {
// `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high)
let substrait_expr =
to_substrait_rex(expr, schema, col_ref_offset, extension_info)?;
let substrait_low =
to_substrait_rex(low, schema, col_ref_offset, extension_info)?;
let substrait_high =
to_substrait_rex(high, schema, col_ref_offset, extension_info)?;
let l_expr = make_binary_op_scalar_func(
&substrait_low,
&substrait_expr,
Operator::LtEq,
extension_info,
);
let r_expr = make_binary_op_scalar_func(
&substrait_expr,
&substrait_high,
Operator::LtEq,
extension_info,
);
Ok(make_binary_op_scalar_func(
&l_expr,
&r_expr,
Operator::And,
extension_info,
))
}
}
// Column reference: index of the column in `schema`, shifted by `col_ref_offset`
Expr::Column(col) => {
let index = schema.index_of_column(col)?;
substrait_field_ref(index + col_ref_offset)
}
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
let l = to_substrait_rex(left, schema, col_ref_offset, extension_info)?;
let r = to_substrait_rex(right, schema, col_ref_offset, extension_info)?;
Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info))
}
Expr::Case(Case {
expr,
when_then_expr,
else_expr,
}) => {
let mut ifs: Vec<IfClause> = vec![];
// Parse base
if let Some(e) = expr {
// Base expression exists: it is emitted as a leading clause that has
// an `if` but no `then`
ifs.push(IfClause {
r#if: Some(to_substrait_rex(
e,
schema,
col_ref_offset,
extension_info,
)?),
then: None,
});
}
// Parse `when`s
for (r#if, then) in when_then_expr {
ifs.push(IfClause {
r#if: Some(to_substrait_rex(
r#if,
schema,
col_ref_offset,
extension_info,
)?),
then: Some(to_substrait_rex(
then,
schema,
col_ref_offset,
extension_info,
)?),
});
}
// Parse outer `else`
let r#else: Option<Box<Expression>> = match else_expr {
Some(e) => Some(Box::new(to_substrait_rex(
e,
schema,
col_ref_offset,
extension_info,
)?)),
None => None,
};
Ok(Expression {
rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))),
})
}
Expr::Cast(Cast { expr, data_type }) => {
Ok(Expression {
rex_type: Some(RexType::Cast(Box::new(
substrait::proto::expression::Cast {
r#type: Some(to_substrait_type(data_type)?),
input: Some(Box::new(to_substrait_rex(
expr,
schema,
col_ref_offset,
extension_info,
)?)),
failure_behavior: 0, // FAILURE_BEHAVIOR_UNSPECIFIED
},
))),
})
}
Expr::Literal(value) => to_substrait_literal(value),
// The alias name is dropped; only the wrapped expression is converted
Expr::Alias(Alias { expr, .. }) => {
to_substrait_rex(expr, schema, col_ref_offset, extension_info)
}
Expr::WindowFunction(WindowFunction {
fun,
args,
partition_by,
order_by,
window_frame,
}) => {
// function reference
let function_name = fun.to_string().to_lowercase();
let function_anchor = _register_function(function_name, extension_info);
// arguments
let mut arguments: Vec<FunctionArgument> = vec![];
for arg in args {
arguments.push(FunctionArgument {
arg_type: Some(ArgType::Value(to_substrait_rex(
arg,
schema,
col_ref_offset,
extension_info,
)?)),
});
}
// partition by expressions
let partition_by = partition_by
.iter()
.map(|e| to_substrait_rex(e, schema, col_ref_offset, extension_info))
.collect::<Result<Vec<_>>>()?;
// order by expressions
let order_by = order_by
.iter()
.map(|e| substrait_sort_field(e, schema, extension_info))
.collect::<Result<Vec<_>>>()?;
// window frame
let bounds = to_substrait_bounds(window_frame)?;
Ok(make_substrait_window_function(
function_anchor,
arguments,
partition_by,
order_by,
bounds,
))
}
Expr::Like(Like {
negated,
expr,
pattern,
escape_char,
case_insensitive,
}) => make_substrait_like_expr(
*case_insensitive,
*negated,
expr,
pattern,
*escape_char,
schema,
col_ref_offset,
extension_info,
),
// Every other expression variant is reported as unsupported
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported expression: {expr:?}"
))),
}
}
fn to_substrait_type(dt: &DataType) -> Result<substrait::proto::Type> {
let default_nullability = r#type::Nullability::Required as i32;
match dt {
DataType::Null => Err(DataFusionError::Internal(
"Null cast is not valid".to_string(),
)),
DataType::Boolean => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Bool(r#type::Boolean {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
}),
DataType::Int8 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::I8(r#type::I8 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
}),
DataType::UInt8 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::I8(r#type::I8 {
type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
nullability: default_nullability,
})),
}),
DataType::Int16 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::I16(r#type::I16 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
}),
DataType::UInt16 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::I16(r#type::I16 {
type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
nullability: default_nullability,
})),
}),
DataType::Int32 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::I32(r#type::I32 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
}),
DataType::UInt32 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::I32(r#type::I32 {
type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
nullability: default_nullability,
})),
}),
DataType::Int64 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::I64(r#type::I64 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
}),
DataType::UInt64 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::I64(r#type::I64 {
type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
nullability: default_nullability,
})),
}),
// Float16 is not supported in Substrait
DataType::Float32 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Fp32(r#type::Fp32 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
}),
DataType::Float64 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Fp64(r#type::Fp64 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
}),
// Timezone is ignored.
DataType::Timestamp(unit, _) => {
let type_variation_reference = match unit {
TimeUnit::Second => TIMESTAMP_SECOND_TYPE_REF,
TimeUnit::Millisecond => TIMESTAMP_MILLI_TYPE_REF,
TimeUnit::Microsecond => TIMESTAMP_MICRO_TYPE_REF,
TimeUnit::Nanosecond => TIMESTAMP_NANO_TYPE_REF,
};
Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Timestamp(r#type::Timestamp {
type_variation_reference,
nullability: default_nullability,
})),
})
}
DataType::Date32 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Date(r#type::Date {
type_variation_reference: DATE_32_TYPE_REF,
nullability: default_nullability,
})),
}),
DataType::Date64 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Date(r#type::Date {
type_variation_reference: DATE_64_TYPE_REF,
nullability: default_nullability,
})),
}),
DataType::Binary => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Binary(r#type::Binary {
type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
nullability: default_nullability,
})),
}),
DataType::FixedSizeBinary(length) => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::FixedBinary(r#type::FixedBinary {
length: *length,
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
}),
DataType::LargeBinary => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Binary(r#type::Binary {
type_variation_reference: LARGE_CONTAINER_TYPE_REF,
nullability: default_nullability,
})),
}),
DataType::Utf8 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::String(r#type::String {
type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
nullability: default_nullability,
})),
}),
DataType::LargeUtf8 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::String(r#type::String {
type_variation_reference: LARGE_CONTAINER_TYPE_REF,
nullability: default_nullability,
})),
}),
DataType::List(inner) => {
let inner_type = to_substrait_type(inner.data_type())?;
Ok(substrait::proto::Type {
kind: Some(r#type::Kind::List(Box::new(r#type::List {
r#type: Some(Box::new(inner_type)),
type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
nullability: default_nullability,
}))),
})
}
DataType::LargeList(inner) => {
let inner_type = to_substrait_type(inner.data_type())?;
Ok(substrait::proto::Type {
kind: Some(r#type::Kind::List(Box::new(r#type::List {
r#type: Some(Box::new(inner_type)),
type_variation_reference: LARGE_CONTAINER_TYPE_REF,
nullability: default_nullability,
}))),
})
}
DataType::Struct(fields) => {
let field_types = fields
.iter()
.map(|field| to_substrait_type(field.data_type()))
.collect::<Result<Vec<_>>>()?;
Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Struct(r#type::Struct {
types: field_types,
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
})
}
DataType::Decimal128(p, s) => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Decimal(r#type::Decimal {
type_variation_reference: DECIMAL_128_TYPE_REF,
nullability: default_nullability,
scale: *s as i32,
precision: *p as i32,
})),
}),
DataType::Decimal256(p, s) => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Decimal(r#type::Decimal {
type_variation_reference: DECIMAL_256_TYPE_REF,
nullability: default_nullability,
scale: *s as i32,
precision: *p as i32,
})),
}),
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported cast type: {dt:?}"
))),
}
}
/// Assemble a Substrait window function expression from already-converted parts.
///
/// `bounds` holds the `(lower, upper)` frame bounds. Phase and invocation are
/// both left at `0`.
#[allow(deprecated)]
fn make_substrait_window_function(
    function_reference: u32,
    arguments: Vec<FunctionArgument>,
    partitions: Vec<Expression>,
    sorts: Vec<SortField>,
    bounds: (Bound, Bound),
) -> Expression {
    let (lower, upper) = bounds;
    let window = SubstraitWindowFunction {
        function_reference,
        arguments,
        partitions,
        sorts,
        options: vec![],
        output_type: None,
        phase: 0,      // default to AGGREGATION_PHASE_UNSPECIFIED
        invocation: 0, // TODO: fix
        lower_bound: Some(lower),
        upper_bound: Some(upper),
        args: vec![],
    };
    Expression {
        rex_type: Some(RexType::WindowFunction(window)),
    }
}
/// Build the Substrait expression for `expr [NOT] LIKE / ILIKE pattern`.
///
/// The call is emitted as a scalar function named `ilike` (when `ignore_case`
/// is set) or `like`, with three arguments: the value, the pattern and the
/// escape character as a Utf8 literal (a null literal when `escape_char` is
/// `None`). When `negated` is set, the call is wrapped in a `not` function.
///
/// # Errors
///
/// Propagates errors from converting `expr`, `pattern` or the escape literal.
#[allow(deprecated)]
#[allow(clippy::too_many_arguments)]
fn make_substrait_like_expr(
    ignore_case: bool,
    negated: bool,
    expr: &Expr,
    pattern: &Expr,
    escape_char: Option<char>,
    schema: &DFSchemaRef,
    col_ref_offset: usize,
    extension_info: &mut (
        Vec<extensions::SimpleExtensionDeclaration>,
        HashMap<String, u32>,
    ),
) -> Result<Expression> {
    // The like/ilike function is registered before its operands are converted.
    let like_name = if ignore_case { "ilike" } else { "like" };
    let like_anchor = _register_function(like_name.to_string(), extension_info);

    let value = to_substrait_rex(expr, schema, col_ref_offset, extension_info)?;
    let pattern = to_substrait_rex(pattern, schema, col_ref_offset, extension_info)?;
    let escape =
        to_substrait_literal(&ScalarValue::Utf8(escape_char.map(|c| c.to_string())))?;

    let like_call = Expression {
        rex_type: Some(RexType::ScalarFunction(ScalarFunction {
            function_reference: like_anchor,
            arguments: vec![value, pattern, escape]
                .into_iter()
                .map(|operand| FunctionArgument {
                    arg_type: Some(ArgType::Value(operand)),
                })
                .collect(),
            output_type: None,
            args: vec![],
            options: vec![],
        })),
    };
    if !negated {
        return Ok(like_call);
    }

    // NOT LIKE / NOT ILIKE: wrap the call in a unary `not`.
    let not_anchor = _register_function("not".to_string(), extension_info);
    Ok(Expression {
        rex_type: Some(RexType::ScalarFunction(ScalarFunction {
            function_reference: not_anchor,
            arguments: vec![FunctionArgument {
                arg_type: Some(ArgType::Value(like_call)),
            }],
            output_type: None,
            args: vec![],
            options: vec![],
        })),
    })
}
fn to_substrait_bound(bound: &WindowFrameBound) -> Bound {
match bound {
WindowFrameBound::CurrentRow => Bound {
kind: Some(BoundKind::CurrentRow(SubstraitBound::CurrentRow {})),
},
WindowFrameBound::Preceding(s) => match s {
ScalarValue::UInt8(Some(v)) => Bound {
kind: Some(BoundKind::Preceding(SubstraitBound::Preceding {
offset: *v as i64,
})),
},
ScalarValue::UInt16(Some(v)) => Bound {
kind: Some(BoundKind::Preceding(SubstraitBound::Preceding {
offset: *v as i64,
})),
},
ScalarValue::UInt32(Some(v)) => Bound {
kind: Some(BoundKind::Preceding(SubstraitBound::Preceding {
offset: *v as i64,
})),
},
ScalarValue::UInt64(Some(v)) => Bound {
kind: Some(BoundKind::Preceding(SubstraitBound::Preceding {
offset: *v as i64,
})),
},
ScalarValue::Int8(Some(v)) => Bound {
kind: Some(BoundKind::Preceding(SubstraitBound::Preceding {
offset: *v as i64,
})),
},
ScalarValue::Int16(Some(v)) => Bound {
kind: Some(BoundKind::Preceding(SubstraitBound::Preceding {
offset: *v as i64,
})),
},
ScalarValue::Int32(Some(v)) => Bound {
kind: Some(BoundKind::Preceding(SubstraitBound::Preceding {
offset: *v as i64,
})),
},
ScalarValue::Int64(Some(v)) => Bound {
kind: Some(BoundKind::Preceding(SubstraitBound::Preceding {
offset: *v,
})),
},
_ => Bound {
kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})),
},
},
WindowFrameBound::Following(s) => match s {
ScalarValue::UInt8(Some(v)) => Bound {
kind: Some(BoundKind::Following(SubstraitBound::Following {
offset: *v as i64,
})),
},
ScalarValue::UInt16(Some(v)) => Bound {
kind: Some(BoundKind::Following(SubstraitBound::Following {
offset: *v as i64,
})),
},
ScalarValue::UInt32(Some(v)) => Bound {
kind: Some(BoundKind::Following(SubstraitBound::Following {
offset: *v as i64,
})),
},
ScalarValue::UInt64(Some(v)) => Bound {
kind: Some(BoundKind::Following(SubstraitBound::Following {
offset: *v as i64,
})),
},
ScalarValue::Int8(Some(v)) => Bound {
kind: Some(BoundKind::Following(SubstraitBound::Following {
offset: *v as i64,
})),
},
ScalarValue::Int16(Some(v)) => Bound {
kind: Some(BoundKind::Following(SubstraitBound::Following {
offset: *v as i64,
})),
},
ScalarValue::Int32(Some(v)) => Bound {
kind: Some(BoundKind::Following(SubstraitBound::Following {
offset: *v as i64,
})),
},
ScalarValue::Int64(Some(v)) => Bound {
kind: Some(BoundKind::Following(SubstraitBound::Following {
offset: *v,
})),
},
_ => Bound {
kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})),
},
},
}
}
/// Converts both ends of a window frame, returning `(start, end)` as Substrait bounds.
///
/// The per-bound conversion cannot fail, so this currently always returns `Ok`.
fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> {
    let lower = to_substrait_bound(&window_frame.start_bound);
    let upper = to_substrait_bound(&window_frame.end_bound);
    Ok((lower, upper))
}
/// Converts a DataFusion scalar into a Substrait literal `Expression`.
///
/// Each supported non-null scalar is mapped to a `LiteralType` plus a type-variation
/// reference that distinguishes DataFusion types sharing one Substrait literal kind
/// (signed vs. unsigned integers, timestamp units, regular vs. large containers).
/// Every scalar not matched here — including all null values — is delegated to
/// `try_to_substrait_null`.
///
/// The produced literal is always marked `nullable: true`.
///
/// # Errors
///
/// Returns whatever `try_to_substrait_null` returns for scalars it does not support.
fn to_substrait_literal(value: &ScalarValue) -> Result<Expression> {
    let (literal_type, type_variation_reference) = match value {
        ScalarValue::Boolean(Some(b)) => (LiteralType::Boolean(*b), DEFAULT_TYPE_REF),
        // Unsigned integers reuse the signed literal slot of the same width; the
        // `as` cast keeps the bit pattern (so e.g. `u32::MAX` wraps) and the
        // variation reference tags the value as unsigned.
        ScalarValue::Int8(Some(n)) => (LiteralType::I8(*n as i32), DEFAULT_TYPE_REF),
        ScalarValue::UInt8(Some(n)) => {
            (LiteralType::I8(*n as i32), UNSIGNED_INTEGER_TYPE_REF)
        }
        ScalarValue::Int16(Some(n)) => (LiteralType::I16(*n as i32), DEFAULT_TYPE_REF),
        ScalarValue::UInt16(Some(n)) => {
            (LiteralType::I16(*n as i32), UNSIGNED_INTEGER_TYPE_REF)
        }
        ScalarValue::Int32(Some(n)) => (LiteralType::I32(*n), DEFAULT_TYPE_REF),
        ScalarValue::UInt32(Some(n)) => {
            (LiteralType::I32(*n as i32), UNSIGNED_INTEGER_TYPE_REF)
        }
        ScalarValue::Int64(Some(n)) => (LiteralType::I64(*n), DEFAULT_TYPE_REF),
        ScalarValue::UInt64(Some(n)) => {
            (LiteralType::I64(*n as i64), UNSIGNED_INTEGER_TYPE_REF)
        }
        ScalarValue::Float32(Some(f)) => (LiteralType::Fp32(*f), DEFAULT_TYPE_REF),
        ScalarValue::Float64(Some(f)) => (LiteralType::Fp64(*f), DEFAULT_TYPE_REF),
        // The timestamp unit is carried by the variation reference; the timezone
        // component of the scalar is not encoded.
        ScalarValue::TimestampSecond(Some(t), _) => {
            (LiteralType::Timestamp(*t), TIMESTAMP_SECOND_TYPE_REF)
        }
        ScalarValue::TimestampMillisecond(Some(t), _) => {
            (LiteralType::Timestamp(*t), TIMESTAMP_MILLI_TYPE_REF)
        }
        ScalarValue::TimestampMicrosecond(Some(t), _) => {
            (LiteralType::Timestamp(*t), TIMESTAMP_MICRO_TYPE_REF)
        }
        ScalarValue::TimestampNanosecond(Some(t), _) => {
            (LiteralType::Timestamp(*t), TIMESTAMP_NANO_TYPE_REF)
        }
        ScalarValue::Date32(Some(d)) => (LiteralType::Date(*d), DATE_32_TYPE_REF),
        // Date64 literal is not supported in Substrait
        ScalarValue::Binary(Some(b)) => {
            (LiteralType::Binary(b.clone()), DEFAULT_CONTAINER_TYPE_REF)
        }
        ScalarValue::LargeBinary(Some(b)) => {
            (LiteralType::Binary(b.clone()), LARGE_CONTAINER_TYPE_REF)
        }
        ScalarValue::FixedSizeBinary(_, Some(b)) => {
            (LiteralType::FixedBinary(b.clone()), DEFAULT_TYPE_REF)
        }
        ScalarValue::Utf8(Some(s)) => {
            (LiteralType::String(s.clone()), DEFAULT_CONTAINER_TYPE_REF)
        }
        ScalarValue::LargeUtf8(Some(s)) => {
            (LiteralType::String(s.clone()), LARGE_CONTAINER_TYPE_REF)
        }
        // Destructure `Some` directly so no `unwrap` is needed; a `None` decimal
        // falls through to the null handling below.
        ScalarValue::Decimal128(Some(v), p, s) => (
            LiteralType::Decimal(Decimal {
                // Little-endian bytes of the unscaled value.
                value: v.to_le_bytes().to_vec(),
                precision: *p as i32,
                scale: *s as i32,
            }),
            DECIMAL_128_TYPE_REF,
        ),
        _ => (try_to_substrait_null(value)?, DEFAULT_TYPE_REF),
    };

    Ok(Expression {
        rex_type: Some(RexType::Literal(Literal {
            nullable: true,
            type_variation_reference,
            literal_type: Some(literal_type),
        })),
    })
}
fn try_to_substrait_null(v: &ScalarValue) -> Result<LiteralType> {
let default_nullability = r#type::Nullability::Nullable as i32;
match v {
ScalarValue::Boolean(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Bool(r#type::Boolean {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::Int8(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I8(r#type::I8 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::UInt8(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I8(r#type::I8 {
type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::Int16(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I16(r#type::I16 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::UInt16(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I16(r#type::I16 {
type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::Int32(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I32(r#type::I32 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::UInt32(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I32(r#type::I32 {
type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::Int64(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I64(r#type::I64 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::UInt64(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I64(r#type::I64 {
type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::Float32(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Fp32(r#type::Fp32 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::Float64(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Fp64(r#type::Fp64 {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::TimestampSecond(None, _) => {
Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Timestamp(r#type::Timestamp {
type_variation_reference: TIMESTAMP_SECOND_TYPE_REF,
nullability: default_nullability,
})),
}))
}
ScalarValue::TimestampMillisecond(None, _) => {
Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Timestamp(r#type::Timestamp {
type_variation_reference: TIMESTAMP_MILLI_TYPE_REF,
nullability: default_nullability,
})),
}))
}
ScalarValue::TimestampMicrosecond(None, _) => {
Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Timestamp(r#type::Timestamp {
type_variation_reference: TIMESTAMP_MICRO_TYPE_REF,
nullability: default_nullability,
})),
}))
}
ScalarValue::TimestampNanosecond(None, _) => {
Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Timestamp(r#type::Timestamp {
type_variation_reference: TIMESTAMP_NANO_TYPE_REF,
nullability: default_nullability,
})),
}))
}
ScalarValue::Date32(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Date(r#type::Date {
type_variation_reference: DATE_32_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::Date64(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Date(r#type::Date {
type_variation_reference: DATE_64_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::Binary(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Binary(r#type::Binary {
type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::LargeBinary(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Binary(r#type::Binary {
type_variation_reference: LARGE_CONTAINER_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::FixedSizeBinary(_, None) => {
Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Binary(r#type::Binary {
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
}))
}
ScalarValue::Utf8(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::String(r#type::String {
type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::LargeUtf8(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::String(r#type::String {
type_variation_reference: LARGE_CONTAINER_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::Decimal128(None, p, s) => {
Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Decimal(r#type::Decimal {
scale: *s as i32,
precision: *p as i32,
type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
}))
}
// TODO: Extend support for remaining data types
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported literal: {v:?}"
))),
}
}
/// Converts a DataFusion sort expression into a Substrait `SortField`.
///
/// The sort key is converted with a column-reference offset of `0`, and the
/// `(asc, nulls_first)` pair selects one of the four `SortDirection` values.
///
/// # Errors
///
/// Returns `DataFusionError::NotImplemented` when `expr` is not an `Expr::Sort`,
/// and propagates any error from converting the sort key.
fn substrait_sort_field(
    expr: &Expr,
    schema: &DFSchemaRef,
    extension_info: &mut (
        Vec<extensions::SimpleExtensionDeclaration>,
        HashMap<String, u32>,
    ),
) -> Result<SortField> {
    let Expr::Sort(Sort {
        expr: sort_key,
        asc,
        nulls_first,
    }) = expr
    else {
        return Err(DataFusionError::NotImplemented(format!(
            "Expecting sort expression but got {expr:?}"
        )));
    };

    let key_rex = to_substrait_rex(sort_key, schema, 0, extension_info)?;
    let direction = match (*asc, *nulls_first) {
        (true, true) => SortDirection::AscNullsFirst,
        (true, false) => SortDirection::AscNullsLast,
        (false, true) => SortDirection::DescNullsFirst,
        (false, false) => SortDirection::DescNullsLast,
    };

    Ok(SortField {
        expr: Some(key_rex),
        sort_kind: Some(SortKind::Direction(direction as i32)),
    })
}
/// Builds a Substrait direct field reference selecting the struct field at `index`.
///
/// The reference has no child segment and no root type. `index` is narrowed to
/// `i32` with a plain `as` cast. This currently always returns `Ok`.
fn substrait_field_ref(index: usize) -> Result<Expression> {
    let struct_field = reference_segment::StructField {
        field: index as i32,
        child: None,
    };
    let segment = ReferenceSegment {
        reference_type: Some(reference_segment::ReferenceType::StructField(Box::new(
            struct_field,
        ))),
    };
    let field_reference = FieldReference {
        reference_type: Some(ReferenceType::DirectReference(segment)),
        root_type: None,
    };

    Ok(Expression {
        rex_type: Some(RexType::Selection(Box::new(field_reference))),
    })
}
#[cfg(test)]
mod test {
    use crate::logical_plan::consumer::from_substrait_literal;

    use super::*;

    /// Round-trips null, minimum, and maximum values of the boolean and integer
    /// scalar types through the producer and the consumer.
    #[test]
    fn round_trip_literals() -> Result<()> {
        let cases = vec![
            ScalarValue::Boolean(None),
            ScalarValue::Boolean(Some(true)),
            ScalarValue::Boolean(Some(false)),
            ScalarValue::Int8(None),
            ScalarValue::Int8(Some(i8::MIN)),
            ScalarValue::Int8(Some(i8::MAX)),
            ScalarValue::UInt8(None),
            ScalarValue::UInt8(Some(u8::MIN)),
            ScalarValue::UInt8(Some(u8::MAX)),
            ScalarValue::Int16(None),
            ScalarValue::Int16(Some(i16::MIN)),
            ScalarValue::Int16(Some(i16::MAX)),
            ScalarValue::UInt16(None),
            ScalarValue::UInt16(Some(u16::MIN)),
            ScalarValue::UInt16(Some(u16::MAX)),
            ScalarValue::Int32(None),
            ScalarValue::Int32(Some(i32::MIN)),
            ScalarValue::Int32(Some(i32::MAX)),
            ScalarValue::UInt32(None),
            ScalarValue::UInt32(Some(u32::MIN)),
            ScalarValue::UInt32(Some(u32::MAX)),
            ScalarValue::Int64(None),
            ScalarValue::Int64(Some(i64::MIN)),
            ScalarValue::Int64(Some(i64::MAX)),
            ScalarValue::UInt64(None),
            ScalarValue::UInt64(Some(u64::MIN)),
            ScalarValue::UInt64(Some(u64::MAX)),
        ];

        for scalar in cases {
            round_trip_literal(scalar)?;
        }
        Ok(())
    }

    /// Converts `scalar` to a Substrait literal and back, asserting that the
    /// value is unchanged. Panics if the producer emits a non-literal expression.
    fn round_trip_literal(scalar: ScalarValue) -> Result<()> {
        println!("Checking round trip of {scalar:?}");

        let substrait_literal = match to_substrait_literal(&scalar)? {
            Expression {
                rex_type: Some(RexType::Literal(literal)),
            } => literal,
            other => panic!("Expected Literal expression, got {other:?}"),
        };

        let roundtrip_scalar = from_substrait_literal(&substrait_literal)?;
        assert_eq!(scalar, roundtrip_scalar);
        Ok(())
    }
}