in cpp/src/arrow/engine/substrait/relation_internal.cc [377:981]
// Convert a Substrait relation into an Acero declaration plus its output schema.
//
// Dispatches on rel.rel_type_case() and handles the Read, Filter, Project, Join, Fetch,
// Sort, Aggregate, Set and Extension{Leaf,Single,Multi} relations. Child relations are
// converted by calling this function recursively. Every non-extension case validates
// the relation with CheckRelCommon and routes its result through ProcessEmit;
// extension relations go through GetExtensionInfo / ProcessExtensionEmit instead.
//
// \param rel the Substrait relation to convert
// \param ext_set extension set forwarded to every nested FromProto call
// \param conversion_options options forwarded to nested conversions; its
//        named_table_provider member is used to resolve ReadRel named tables
// \return the converted declaration and schema; Status::Invalid for malformed
//         relations; Status::NotImplemented for unsupported features or relation types
Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet& ext_set,
                                  const ConversionOptions& conversion_options) {
  // One-time dataset initialization. Initialization of a function-local static is
  // guaranteed to be thread-safe (C++11 and later), unlike a hand-rolled
  // "check flag, set flag" sequence on a plain static bool.
  static const bool dataset_initialized = [] {
    dataset::internal::Initialize();
    return true;
  }();
  static_cast<void>(dataset_initialized);

  switch (rel.rel_type_case()) {
    case substrait::Rel::RelTypeCase::kRead: {
      const auto& read = rel.read();
      RETURN_NOT_OK(CheckRelCommon(read, conversion_options));

      // Get the base schema for the read relation
      ARROW_ASSIGN_OR_RAISE(auto base_schema,
                            FromProto(read.base_schema(), ext_set, conversion_options));

      auto scan_options = std::make_shared<dataset::ScanOptions>();
      scan_options->use_threads = true;
      scan_options->add_augmented_fields = false;

      if (read.has_filter()) {
        ARROW_ASSIGN_OR_RAISE(scan_options->filter,
                              FromProto(read.filter(), ext_set, conversion_options));
      }

      if (read.has_projection()) {
        return Status::NotImplemented("substrait::ReadRel::projection");
      }

      if (read.has_named_table()) {
        // NOTE(review): scan_options (including any filter parsed above) is not used
        // on this path; the declaration comes entirely from the provider.
        if (!conversion_options.named_table_provider) {
          return Status::Invalid(
              "plan contained a named table but a NamedTableProvider has not been "
              "configured");
        }

        if (read.named_table().names().empty()) {
          return Status::Invalid("names for NamedTable not provided");
        }

        const NamedTableProvider& named_table_provider =
            conversion_options.named_table_provider;
        const substrait::ReadRel::NamedTable& named_table = read.named_table();
        std::vector<std::string> table_names(named_table.names().begin(),
                                             named_table.names().end());
        ARROW_ASSIGN_OR_RAISE(acero::Declaration source_decl,
                              named_table_provider(table_names, *base_schema));

        if (!source_decl.IsValid()) {
          return Status::Invalid("Invalid NamedTable Source");
        }

        return ProcessEmit(read, DeclarationInfo{std::move(source_decl), base_schema},
                           base_schema);
      }

      if (!read.has_local_files()) {
        return Status::NotImplemented(
            "substrait::ReadRel with read_type other than LocalFiles");
      }

      if (read.local_files().has_advanced_extension()) {
        return Status::NotImplemented(
            "substrait::ReadRel::LocalFiles::advanced_extension");
      }

      // NOTE(review): `format` is reassigned for every item, so the dataset factory
      // below receives the format of the last item only.
      std::shared_ptr<dataset::FileFormat> format;
      auto filesystem = std::make_shared<fs::LocalFileSystem>();
      std::vector<fs::FileInfo> files;

      for (const auto& item : read.local_files().items()) {
        // Validate properties of the `FileOrFiles` item
        if (item.partition_index() != 0) {
          return Status::NotImplemented(
              "non-default "
              "substrait::ReadRel::LocalFiles::FileOrFiles::partition_index");
        }

        if (item.start() != 0) {
          return Status::NotImplemented(
              "non-default substrait::ReadRel::LocalFiles::FileOrFiles::start offset");
        }

        if (item.length() != 0) {
          return Status::NotImplemented(
              "non-default substrait::ReadRel::LocalFiles::FileOrFiles::length");
        }

        // Extract and parse the read relation's source URI
        ::arrow::util::Uri item_uri;
        switch (item.path_type_case()) {
          case substrait::ReadRel::LocalFiles::FileOrFiles::kUriPath:
            RETURN_NOT_OK(item_uri.Parse(item.uri_path()));
            break;

          case substrait::ReadRel::LocalFiles::FileOrFiles::kUriFile:
            RETURN_NOT_OK(item_uri.Parse(item.uri_file()));
            break;

          case substrait::ReadRel::LocalFiles::FileOrFiles::kUriFolder:
            RETURN_NOT_OK(item_uri.Parse(item.uri_folder()));
            break;

          default:
            // Any other path type is parsed as a glob here; path types that are not
            // actually supported are rejected by the second switch further down.
            RETURN_NOT_OK(item_uri.Parse(item.uri_path_glob()));
            break;
        }

        // Validate the URI before processing
        if (!item_uri.is_file_scheme()) {
          return Status::NotImplemented("substrait::ReadRel::LocalFiles item (",
                                        item_uri.ToString(),
                                        ") does not have file scheme (file:///)");
        }

        if (item_uri.port() != -1) {
          return Status::NotImplemented("substrait::ReadRel::LocalFiles item (",
                                        item_uri.ToString(),
                                        ") should not have a port number in path");
        }

        if (!item_uri.query_string().empty()) {
          return Status::NotImplemented("substrait::ReadRel::LocalFiles item (",
                                        item_uri.ToString(),
                                        ") should not have a query string in path");
        }

        switch (item.file_format_case()) {
          case substrait::ReadRel::LocalFiles::FileOrFiles::kParquet:
            format = std::make_shared<dataset::ParquetFileFormat>();
            break;
          case substrait::ReadRel::LocalFiles::FileOrFiles::kArrow:
            format = std::make_shared<dataset::IpcFileFormat>();
            break;
          default:
            return Status::NotImplemented(
                "unsupported file format ",
                "(see substrait::ReadRel::LocalFiles::FileOrFiles::file_format)");
        }

        // Handle the URI as appropriate
        switch (item.path_type_case()) {
          case substrait::ReadRel::LocalFiles::FileOrFiles::kUriFile: {
            files.emplace_back(item_uri.path(), fs::FileType::File);
            break;
          }

          case substrait::ReadRel::LocalFiles::FileOrFiles::kUriFolder: {
            RETURN_NOT_OK(DiscoverFilesFromDir(filesystem, item_uri.path(), &files));
            break;
          }

          case substrait::ReadRel::LocalFiles::FileOrFiles::kUriPath: {
            // A plain path may name either a file or a directory; ask the filesystem.
            ARROW_ASSIGN_OR_RAISE(auto file_info,
                                  filesystem->GetFileInfo(item_uri.path()));

            switch (file_info.type()) {
              case fs::FileType::File: {
                files.push_back(std::move(file_info));
                break;
              }
              case fs::FileType::Directory: {
                RETURN_NOT_OK(DiscoverFilesFromDir(filesystem, item_uri.path(), &files));
                break;
              }
              case fs::FileType::NotFound:
                return Status::Invalid("Unable to find file for URI path");
              case fs::FileType::Unknown:
                [[fallthrough]];
              default:
                return Status::NotImplemented("URI path is of unknown file type.");
            }
            break;
          }

          case substrait::ReadRel::LocalFiles::FileOrFiles::kUriPathGlob: {
            ARROW_ASSIGN_OR_RAISE(auto globbed_files,
                                  fs::internal::GlobFiles(filesystem, item_uri.path()));
            std::move(globbed_files.begin(), globbed_files.end(),
                      std::back_inserter(files));
            break;
          }

          default: {
            return Status::Invalid("Unrecognized file type in LocalFiles");
          }
        }
      }

      ARROW_ASSIGN_OR_RAISE(auto ds_factory, dataset::FileSystemDatasetFactory::Make(
                                                 std::move(filesystem), std::move(files),
                                                 std::move(format), {}));

      ARROW_ASSIGN_OR_RAISE(auto ds, ds_factory->Finish(base_schema));

      DeclarationInfo scan_declaration{
          acero::Declaration{"scan", dataset::ScanNodeOptions{ds, scan_options}},
          base_schema};

      return ProcessEmit(read, scan_declaration, base_schema);
    }

    case substrait::Rel::RelTypeCase::kFilter: {
      const auto& filter = rel.filter();
      RETURN_NOT_OK(CheckRelCommon(filter, conversion_options));

      if (!filter.has_input()) {
        return Status::Invalid("substrait::FilterRel with no input relation");
      }
      ARROW_ASSIGN_OR_RAISE(auto input,
                            FromProto(filter.input(), ext_set, conversion_options));

      if (!filter.has_condition()) {
        return Status::Invalid("substrait::FilterRel with no condition expression");
      }
      ARROW_ASSIGN_OR_RAISE(auto condition,
                            FromProto(filter.condition(), ext_set, conversion_options));

      // A filter does not change the schema, so the input schema is reused.
      DeclarationInfo filter_declaration{
          acero::Declaration::Sequence({
              std::move(input.declaration),
              {"filter", acero::FilterNodeOptions{std::move(condition)}},
          }),
          input.output_schema};

      return ProcessEmit(filter, filter_declaration, input.output_schema);
    }

    case substrait::Rel::RelTypeCase::kProject: {
      const auto& project = rel.project();
      RETURN_NOT_OK(CheckRelCommon(project, conversion_options));
      if (!project.has_input()) {
        return Status::Invalid("substrait::ProjectRel with no input relation");
      }
      ARROW_ASSIGN_OR_RAISE(auto input,
                            FromProto(project.input(), ext_set, conversion_options));

      // NOTE: Substrait ProjectRels *append* columns, while Acero's project node replaces
      // them. Therefore, we need to prefix all the current columns for compatibility.
      std::vector<compute::Expression> expressions;
      int num_columns = input.output_schema->num_fields();
      expressions.reserve(num_columns + project.expressions().size());
      for (int i = 0; i < num_columns; i++) {
        expressions.emplace_back(compute::field_ref(FieldRef(i)));
      }

      // `i` counts the appended expressions; the new field for expression i is placed
      // at index num_columns + i of the projected schema.
      int i = 0;
      auto project_schema = input.output_schema;
      for (const auto& expr : project.expressions()) {
        std::shared_ptr<Field> project_field;
        ARROW_ASSIGN_OR_RAISE(compute::Expression des_expr,
                              FromProto(expr, ext_set, conversion_options));
        ARROW_ASSIGN_OR_RAISE(compute::Expression bound_expr,
                              des_expr.Bind(*input.output_schema));
        if (auto* expr_call = bound_expr.call()) {
          // Function calls: the field is named after the function.
          // NOTE(review): the type is read from the kernel signature's declared output
          // type; confirm this is populated for every kernel that can appear here.
          project_field = field(expr_call->function_name,
                                expr_call->kernel->signature->out_type().type());
        } else if (auto* field_ref = des_expr.field_ref()) {
          // Field references: reuse the referenced input field as-is.
          ARROW_ASSIGN_OR_RAISE(FieldPath field_path,
                                field_ref->FindOne(*input.output_schema));
          ARROW_ASSIGN_OR_RAISE(project_field, field_path.Get(*input.output_schema));
        } else if (auto* literal = des_expr.literal()) {
          // Literals: synthesize a positional name.
          project_field = field("field_" + ToChars(num_columns + i), literal->type());
        } else {
          // Defensive: never hand a null field to Schema::AddField.
          return Status::Invalid(
              "substrait::ProjectRel expression is neither a call, a field reference "
              "nor a literal: ",
              des_expr.ToString());
        }
        ARROW_ASSIGN_OR_RAISE(
            project_schema,
            project_schema->AddField(num_columns + i, std::move(project_field)));
        i++;
        expressions.emplace_back(des_expr);
      }

      DeclarationInfo project_declaration{
          acero::Declaration::Sequence({
              std::move(input.declaration),
              {"project", acero::ProjectNodeOptions{std::move(expressions)}},
          }),
          project_schema};

      return ProcessEmit(project, project_declaration, project_schema);
    }

    case substrait::Rel::RelTypeCase::kJoin: {
      const auto& join = rel.join();
      RETURN_NOT_OK(CheckRelCommon(join, conversion_options));

      if (!join.has_left()) {
        return Status::Invalid("substrait::JoinRel with no left relation");
      }

      if (!join.has_right()) {
        return Status::Invalid("substrait::JoinRel with no right relation");
      }

      acero::JoinType join_type;
      switch (join.type()) {
        case substrait::JoinRel::JOIN_TYPE_UNSPECIFIED:
          return Status::NotImplemented("Unspecified join type is not supported");
        case substrait::JoinRel::JOIN_TYPE_INNER:
          join_type = acero::JoinType::INNER;
          break;
        case substrait::JoinRel::JOIN_TYPE_OUTER:
          join_type = acero::JoinType::FULL_OUTER;
          break;
        case substrait::JoinRel::JOIN_TYPE_LEFT:
          join_type = acero::JoinType::LEFT_OUTER;
          break;
        case substrait::JoinRel::JOIN_TYPE_RIGHT:
          join_type = acero::JoinType::RIGHT_OUTER;
          break;
        case substrait::JoinRel::JOIN_TYPE_SEMI:
          join_type = acero::JoinType::LEFT_SEMI;
          break;
        case substrait::JoinRel::JOIN_TYPE_ANTI:
          join_type = acero::JoinType::LEFT_ANTI;
          break;
        default:
          return Status::Invalid("Unsupported join type");
      }

      ARROW_ASSIGN_OR_RAISE(auto left,
                            FromProto(join.left(), ext_set, conversion_options));
      ARROW_ASSIGN_OR_RAISE(auto right,
                            FromProto(join.right(), ext_set, conversion_options));

      if (!join.has_expression()) {
        return Status::Invalid("substrait::JoinRel with no expression");
      }

      ARROW_ASSIGN_OR_RAISE(auto expression,
                            FromProto(join.expression(), ext_set, conversion_options));

      const auto* callptr = expression.call();
      if (!callptr) {
        return Status::Invalid(
            "A join rel's expression must be a simple equality between keys but got ",
            expression.ToString());
      }

      acero::JoinKeyCmp join_key_cmp;
      if (callptr->function_name == "equal") {
        join_key_cmp = acero::JoinKeyCmp::EQ;
      } else if (callptr->function_name == "is_not_distinct_from") {
        join_key_cmp = acero::JoinKeyCmp::IS;
      } else {
        return Status::Invalid(
            "Only `equal` or `is_not_distinct_from` are supported for join key "
            "comparison but got ",
            callptr->function_name);
      }

      // The comparison is indexed as (left key, right key) below, so make sure both
      // arguments exist before touching them.
      if (callptr->arguments.size() != 2) {
        return Status::Invalid(
            "A join rel's key comparison must have exactly two arguments but got ",
            expression.ToString());
      }

      // Create output schema from left, right relations and join keys
      FieldVector combined_fields = left.output_schema->fields();
      const FieldVector& right_fields = right.output_schema->fields();
      combined_fields.insert(combined_fields.end(), right_fields.begin(),
                             right_fields.end());
      std::shared_ptr<Schema> join_schema = schema(std::move(combined_fields));

      // adjust the join_keys according to Substrait definition where
      // the join fields are defined by considering the `join_schema` which
      // is the combination of the left and right relation schema.

      // TODO: ARROW-16624 Add Suffix support for Substrait
      const auto* left_keys = callptr->arguments[0].field_ref();
      const auto* right_keys = callptr->arguments[1].field_ref();
      // Validating JoinKeys
      if (!left_keys || !right_keys) {
        return Status::Invalid(
            "join condition must include references to both left and right inputs");
      }

      int num_left_fields = left.output_schema->num_fields();
      // The right key must be an index-based reference so that it can be rebased from
      // the combined schema onto the right input's own schema.
      const auto* right_field_path = right_keys->field_path();
      if (!right_field_path || right_field_path->indices().empty()) {
        return Status::Invalid(
            "join condition must include references to both left and right inputs");
      }
      std::vector<int> adjusted_field_indices(right_field_path->indices());
      if (adjusted_field_indices[0] < num_left_fields) {
        // Rebasing would yield a negative index: the key points into the left input.
        return Status::Invalid(
            "join condition's second key must reference a field of the right input");
      }
      adjusted_field_indices[0] -= num_left_fields;
      FieldPath adjusted_right_keys(adjusted_field_indices);

      // `left_keys` points to a const FieldRef owned by `expression`, so it is copied.
      acero::HashJoinNodeOptions join_options{{*left_keys},
                                              {std::move(adjusted_right_keys)}};
      join_options.join_type = join_type;
      join_options.key_cmp = {join_key_cmp};
      acero::Declaration join_dec{"hashjoin", std::move(join_options)};
      join_dec.inputs.emplace_back(std::move(left.declaration));
      join_dec.inputs.emplace_back(std::move(right.declaration));

      DeclarationInfo join_declaration{std::move(join_dec), join_schema};

      return ProcessEmit(join, join_declaration, join_schema);
    }

    case substrait::Rel::RelTypeCase::kFetch: {
      const auto& fetch = rel.fetch();
      RETURN_NOT_OK(CheckRelCommon(fetch, conversion_options));

      if (!fetch.has_input()) {
        return Status::Invalid("substrait::FetchRel with no input relation");
      }

      ARROW_ASSIGN_OR_RAISE(auto input,
                            FromProto(fetch.input(), ext_set, conversion_options));

      int64_t offset = fetch.offset();
      int64_t count = fetch.count();

      acero::Declaration fetch_dec{
          "fetch", {input.declaration}, acero::FetchNodeOptions(offset, count)};

      DeclarationInfo fetch_declaration{std::move(fetch_dec), input.output_schema};
      // Pass the schema from `input` (the same pointer that was copied into
      // fetch_declaration) rather than reading a member of an object that is being
      // moved in the same call.
      return ProcessEmit(fetch, std::move(fetch_declaration), input.output_schema);
    }

    case substrait::Rel::RelTypeCase::kSort: {
      const auto& sort = rel.sort();
      RETURN_NOT_OK(CheckRelCommon(sort, conversion_options));

      if (!sort.has_input()) {
        return Status::Invalid("substrait::SortRel with no input relation");
      }

      ARROW_ASSIGN_OR_RAISE(auto input,
                            FromProto(sort.input(), ext_set, conversion_options));

      if (sort.sorts_size() == 0) {
        return Status::Invalid("substrait::SortRel with no sorts");
      }

      std::vector<compute::SortKey> sort_keys;
      sort_keys.reserve(sort.sorts_size());
      // Substrait allows null placement to differ for each field.  Acero expects it to
      // be consistent across all fields.  So we grab the null placement from the first
      // key and verify all other keys have the same null placement
      std::optional<SortBehavior> sample_sort_behavior;
      for (const auto& sort_field : sort.sorts()) {
        // NOTE(review): the direction is interpreted before the sort kind is checked
        // below; confirm what SortBehavior::Make returns for a custom-sort field.
        ARROW_ASSIGN_OR_RAISE(SortBehavior sort_behavior,
                              SortBehavior::Make(sort_field.direction()));
        if (sample_sort_behavior) {
          if (sample_sort_behavior->null_placement != sort_behavior.null_placement) {
            return Status::NotImplemented(
                "substrait::SortRel with ordering with mixed null placement");
          }
        } else {
          sample_sort_behavior = sort_behavior;
        }
        if (sort_field.sort_kind_case() !=
            substrait::SortField::SortKindCase::kDirection) {
          return Status::NotImplemented("substrait::SortRel with custom sort function");
        }
        ARROW_ASSIGN_OR_RAISE(compute::Expression expr,
                              FromProto(sort_field.expr(), ext_set, conversion_options));
        const FieldRef* field_ref = expr.field_ref();
        if (field_ref) {
          sort_keys.push_back(compute::SortKey(*field_ref, sort_behavior.sort_order));
        } else {
          return Status::Invalid("Sort key expressions must be a direct reference.");
        }
      }

      // sorts_size() > 0 was verified above, so the loop ran at least once.
      DCHECK(sample_sort_behavior.has_value());
      acero::Declaration sort_dec{
          "order_by",
          {input.declaration},
          acero::OrderByNodeOptions(compute::Ordering(
              std::move(sort_keys), sample_sort_behavior->null_placement))};

      DeclarationInfo sort_declaration{std::move(sort_dec), input.output_schema};
      // As for FetchRel: take the schema from `input` instead of from the declaration
      // that is moved in the same call.
      return ProcessEmit(sort, std::move(sort_declaration), input.output_schema);
    }

    case substrait::Rel::RelTypeCase::kAggregate: {
      const auto& aggregate = rel.aggregate();
      RETURN_NOT_OK(CheckRelCommon(aggregate, conversion_options));

      if (!aggregate.has_input()) {
        return Status::Invalid("substrait::AggregateRel with no input relation");
      }

      ARROW_ASSIGN_OR_RAISE(auto input,
                            FromProto(aggregate.input(), ext_set, conversion_options));

      if (aggregate.groupings_size() > 1) {
        return Status::NotImplemented(
            "Grouping sets not supported.  AggregateRel::groupings may not have more "
            "than one item");
      }

      // prepare output schema from aggregates
      auto input_schema = input.output_schema;
      std::vector<FieldRef> keys;
      if (aggregate.groupings_size() > 0) {
        const substrait::AggregateRel::Grouping& group = aggregate.groupings(0);
        int grouping_expr_size = group.grouping_expressions_size();
        keys.reserve(grouping_expr_size);
        for (int exp_id = 0; exp_id < grouping_expr_size; exp_id++) {
          ARROW_ASSIGN_OR_RAISE(
              compute::Expression expr,
              FromProto(group.grouping_expressions(exp_id), ext_set, conversion_options));
          const FieldRef* field_ref = expr.field_ref();
          if (field_ref) {
            // `field_ref` points to a const FieldRef, so this is a copy.
            keys.emplace_back(*field_ref);
          } else {
            return Status::Invalid(
                "The grouping expression for an aggregate must be a direct reference.");
          }
        }
      }

      const int measure_size = aggregate.measures_size();
      std::vector<compute::Aggregate> aggregates;
      aggregates.reserve(measure_size);
      for (int measure_id = 0; measure_id < measure_size; measure_id++) {
        const auto& agg_measure = aggregate.measures(measure_id);
        // Measures are parsed as hash aggregates whenever grouping keys are present.
        ARROW_ASSIGN_OR_RAISE(
            auto parsed_measure,
            internal::ParseAggregateMeasure(agg_measure, ext_set, conversion_options,
                                            /*is_hash=*/!keys.empty(), input_schema));
        aggregates.push_back(std::move(parsed_measure));
      }

      ARROW_ASSIGN_OR_RAISE(auto aggregate_schema,
                            acero::aggregate::MakeOutputSchema(
                                input_schema, keys, /*segment_keys=*/{}, aggregates));

      ARROW_ASSIGN_OR_RAISE(
          auto aggregate_declaration,
          internal::MakeAggregateDeclaration(std::move(input.declaration),
                                             aggregate_schema, std::move(aggregates),
                                             std::move(keys), /*segment_keys=*/{}));

      // `aggregate` is a const reference to the protobuf message; moving it was a no-op.
      return ProcessEmit(aggregate, std::move(aggregate_declaration),
                         std::move(aggregate_schema));
    }

    case substrait::Rel::RelTypeCase::kExtensionLeaf:
    case substrait::Rel::RelTypeCase::kExtensionSingle:
    case substrait::Rel::RelTypeCase::kExtensionMulti: {
      std::vector<DeclarationInfo> ext_rel_inputs;
      ARROW_ASSIGN_OR_RAISE(
          auto ext_decl_info,
          GetExtensionInfo(rel, ext_set, conversion_options, &ext_rel_inputs));
      auto ext_common_opt = GetExtensionRelCommon(rel);
      bool has_emit = ext_common_opt && ext_common_opt->emit_kind_case() ==
                                            substrait::RelCommon::EmitKindCase::kEmit;
      // Set up the emit order - an ordered list of indices that specifies an output
      // mapping as expected by Substrait. This is a sublist of [0..N), where N is the
      // total number of input fields across all inputs of the relation, that selects
      // from these input fields.
      if (has_emit) {
        std::vector<int> emit_order;
        // the emit order is defined in the Substrait plan - pick it up
        const auto& emit_info = ext_common_opt->emit();
        emit_order.reserve(emit_info.output_mapping_size());
        for (const auto& emit_idx : emit_info.output_mapping()) {
          emit_order.push_back(emit_idx);
        }
        return ProcessExtensionEmit(std::move(ext_decl_info), emit_order);
      } else {
        return ext_decl_info;
      }
    }

    case substrait::Rel::RelTypeCase::kSet: {
      const auto& set = rel.set();
      RETURN_NOT_OK(CheckRelCommon(set, conversion_options));

      if (set.inputs_size() < 2) {
        return Status::Invalid(
            "substrait::SetRel with inadequate number of input relations, ",
            set.inputs_size());
      }
      substrait::SetRel_SetOp op = set.op();
      // Note: at the moment Acero only supports UNION_ALL operation
      switch (op) {
        case substrait::SetRel::SET_OP_UNSPECIFIED:
        case substrait::SetRel::SET_OP_MINUS_PRIMARY:
        case substrait::SetRel::SET_OP_MINUS_MULTISET:
        case substrait::SetRel::SET_OP_INTERSECTION_PRIMARY:
        case substrait::SetRel::SET_OP_INTERSECTION_MULTISET:
        case substrait::SetRel::SET_OP_UNION_DISTINCT:
          return Status::NotImplemented(
              "NotImplemented union type : ",
              EnumToString(op, *substrait::SetRel_SetOp_descriptor()));
        case substrait::SetRel::SET_OP_UNION_ALL:
          break;
        default:
          return Status::Invalid("Unknown union type");
      }
      int input_size = set.inputs_size();
      acero::Declaration union_declr{"union", acero::ExecNodeOptions{}};
      // The union's output schema is taken from the first input.
      std::shared_ptr<Schema> union_schema;
      for (int input_id = 0; input_id < input_size; input_id++) {
        ARROW_ASSIGN_OR_RAISE(
            auto input, FromProto(set.inputs(input_id), ext_set, conversion_options));
        union_declr.inputs.emplace_back(std::move(input.declaration));
        if (union_schema == nullptr) {
          union_schema = input.output_schema;
        }
      }

      // union_declr is not used again, so move it instead of copying all its inputs.
      auto set_declaration = DeclarationInfo{std::move(union_declr), union_schema};
      return ProcessEmit(set, set_declaration, union_schema);
    }

    default:
      break;
  }

  return Status::NotImplemented(
      "conversion to arrow::acero::Declaration from Substrait relation ",
      rel.DebugString());
}