Result FromProto()

in cpp/src/arrow/engine/substrait/relation_internal.cc [377:981]


/// \brief Convert a Substrait relation into an Acero declaration plus its output schema.
///
/// Dispatches on `rel.rel_type_case()` and handles Read, Filter, Project, Join, Fetch,
/// Sort, Aggregate, Extension (leaf/single/multi) and Set relations.  Input relations
/// are converted by recursive calls to this function.  Any other relation type yields
/// Status::NotImplemented carrying the relation's debug string.
///
/// \param[in] rel the Substrait relation to convert
/// \param[in] ext_set extension set forwarded to every nested FromProto call
/// \param[in] conversion_options options forwarded to nested conversions; also the
///            source of the NamedTableProvider used for named-table reads
/// \return the declaration together with its output schema, or an error Status
Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet& ext_set,
                                  const ConversionOptions& conversion_options) {
  // Run the dataset initialization exactly once.  Initialization of a function-local
  // static is thread-safe, unlike a plain boolean flag that is tested and then set.
  static const bool dataset_initialized = [] {
    dataset::internal::Initialize();
    return true;
  }();
  static_cast<void>(dataset_initialized);

  switch (rel.rel_type_case()) {
    case substrait::Rel::RelTypeCase::kRead: {
      const auto& read = rel.read();
      RETURN_NOT_OK(CheckRelCommon(read, conversion_options));

      // Get the base schema for the read relation
      ARROW_ASSIGN_OR_RAISE(auto base_schema,
                            FromProto(read.base_schema(), ext_set, conversion_options));

      auto scan_options = std::make_shared<dataset::ScanOptions>();
      scan_options->use_threads = true;
      scan_options->add_augmented_fields = false;

      if (read.has_filter()) {
        ARROW_ASSIGN_OR_RAISE(scan_options->filter,
                              FromProto(read.filter(), ext_set, conversion_options));
      }

      if (read.has_projection()) {
        return Status::NotImplemented("substrait::ReadRel::projection");
      }

      // Named tables are resolved through the user-supplied provider and bypass the
      // file-based scan path below entirely.
      if (read.has_named_table()) {
        if (!conversion_options.named_table_provider) {
          return Status::Invalid(
              "plan contained a named table but a NamedTableProvider has not been "
              "configured");
        }

        if (read.named_table().names().empty()) {
          return Status::Invalid("names for NamedTable not provided");
        }

        const NamedTableProvider& named_table_provider =
            conversion_options.named_table_provider;
        const substrait::ReadRel::NamedTable& named_table = read.named_table();
        std::vector<std::string> table_names(named_table.names().begin(),
                                             named_table.names().end());
        ARROW_ASSIGN_OR_RAISE(acero::Declaration source_decl,
                              named_table_provider(table_names, *base_schema));

        if (!source_decl.IsValid()) {
          return Status::Invalid("Invalid NamedTable Source");
        }

        return ProcessEmit(read, DeclarationInfo{std::move(source_decl), base_schema},
                           base_schema);
      }

      if (!read.has_local_files()) {
        return Status::NotImplemented(
            "substrait::ReadRel with read_type other than LocalFiles");
      }

      if (read.local_files().has_advanced_extension()) {
        return Status::NotImplemented(
            "substrait::ReadRel::LocalFiles::advanced_extension");
      }

      // NOTE(review): `format` is reassigned for every item, so the dataset factory
      // below receives the format of the last item only (and a null format when there
      // are no items).  Confirm that mixed-format / empty item lists are acceptable.
      std::shared_ptr<dataset::FileFormat> format;
      auto filesystem = std::make_shared<fs::LocalFileSystem>();
      std::vector<fs::FileInfo> files;

      for (const auto& item : read.local_files().items()) {
        // Validate properties of the `FileOrFiles` item
        if (item.partition_index() != 0) {
          return Status::NotImplemented(
              "non-default "
              "substrait::ReadRel::LocalFiles::FileOrFiles::partition_index");
        }

        if (item.start() != 0) {
          return Status::NotImplemented(
              "non-default substrait::ReadRel::LocalFiles::FileOrFiles::start offset");
        }

        if (item.length() != 0) {
          return Status::NotImplemented(
              "non-default substrait::ReadRel::LocalFiles::FileOrFiles::length");
        }

        // Extract and parse the read relation's source URI
        ::arrow::util::Uri item_uri;
        switch (item.path_type_case()) {
          case substrait::ReadRel::LocalFiles::FileOrFiles::kUriPath:
            RETURN_NOT_OK(item_uri.Parse(item.uri_path()));
            break;

          case substrait::ReadRel::LocalFiles::FileOrFiles::kUriFile:
            RETURN_NOT_OK(item_uri.Parse(item.uri_file()));
            break;

          case substrait::ReadRel::LocalFiles::FileOrFiles::kUriFolder:
            RETURN_NOT_OK(item_uri.Parse(item.uri_folder()));
            break;

          default:
            RETURN_NOT_OK(item_uri.Parse(item.uri_path_glob()));
            break;
        }

        // Validate the URI before processing
        if (!item_uri.is_file_scheme()) {
          return Status::NotImplemented("substrait::ReadRel::LocalFiles item (",
                                        item_uri.ToString(),
                                        ") does not have file scheme (file:///)");
        }

        if (item_uri.port() != -1) {
          return Status::NotImplemented("substrait::ReadRel::LocalFiles item (",
                                        item_uri.ToString(),
                                        ") should not have a port number in path");
        }

        if (!item_uri.query_string().empty()) {
          return Status::NotImplemented("substrait::ReadRel::LocalFiles item (",
                                        item_uri.ToString(),
                                        ") should not have a query string in path");
        }

        // Only Parquet and Arrow IPC files are supported
        switch (item.file_format_case()) {
          case substrait::ReadRel::LocalFiles::FileOrFiles::kParquet:
            format = std::make_shared<dataset::ParquetFileFormat>();
            break;
          case substrait::ReadRel::LocalFiles::FileOrFiles::kArrow:
            format = std::make_shared<dataset::IpcFileFormat>();
            break;
          default:
            return Status::NotImplemented(
                "unsupported file format ",
                "(see substrait::ReadRel::LocalFiles::FileOrFiles::file_format)");
        }

        // Handle the URI as appropriate
        switch (item.path_type_case()) {
          case substrait::ReadRel::LocalFiles::FileOrFiles::kUriFile: {
            files.emplace_back(item_uri.path(), fs::FileType::File);
            break;
          }

          case substrait::ReadRel::LocalFiles::FileOrFiles::kUriFolder: {
            RETURN_NOT_OK(DiscoverFilesFromDir(filesystem, item_uri.path(), &files));
            break;
          }

          case substrait::ReadRel::LocalFiles::FileOrFiles::kUriPath: {
            // A plain path may be either a file or a directory; ask the filesystem
            ARROW_ASSIGN_OR_RAISE(auto file_info,
                                  filesystem->GetFileInfo(item_uri.path()));

            switch (file_info.type()) {
              case fs::FileType::File: {
                files.push_back(std::move(file_info));
                break;
              }
              case fs::FileType::Directory: {
                RETURN_NOT_OK(DiscoverFilesFromDir(filesystem, item_uri.path(), &files));
                break;
              }
              case fs::FileType::NotFound:
                return Status::Invalid("Unable to find file for URI path");
              case fs::FileType::Unknown:
                [[fallthrough]];
              default:
                return Status::NotImplemented("URI path is of unknown file type.");
            }
            break;
          }

          case substrait::ReadRel::LocalFiles::FileOrFiles::kUriPathGlob: {
            ARROW_ASSIGN_OR_RAISE(auto globbed_files,
                                  fs::internal::GlobFiles(filesystem, item_uri.path()));
            std::move(globbed_files.begin(), globbed_files.end(),
                      std::back_inserter(files));
            break;
          }

          default: {
            return Status::Invalid("Unrecognized file type in LocalFiles");
          }
        }
      }

      ARROW_ASSIGN_OR_RAISE(auto ds_factory, dataset::FileSystemDatasetFactory::Make(
                                                 std::move(filesystem), std::move(files),
                                                 std::move(format), {}));

      ARROW_ASSIGN_OR_RAISE(auto ds, ds_factory->Finish(base_schema));

      DeclarationInfo scan_declaration{
          acero::Declaration{"scan", dataset::ScanNodeOptions{ds, scan_options}},
          base_schema};

      return ProcessEmit(read, scan_declaration, base_schema);
    }

    case substrait::Rel::RelTypeCase::kFilter: {
      const auto& filter = rel.filter();
      RETURN_NOT_OK(CheckRelCommon(filter, conversion_options));

      if (!filter.has_input()) {
        return Status::Invalid("substrait::FilterRel with no input relation");
      }
      ARROW_ASSIGN_OR_RAISE(auto input,
                            FromProto(filter.input(), ext_set, conversion_options));

      if (!filter.has_condition()) {
        return Status::Invalid("substrait::FilterRel with no condition expression");
      }
      ARROW_ASSIGN_OR_RAISE(auto condition,
                            FromProto(filter.condition(), ext_set, conversion_options));
      // A filter does not change the schema, so the input schema is reused
      DeclarationInfo filter_declaration{
          acero::Declaration::Sequence({
              std::move(input.declaration),
              {"filter", acero::FilterNodeOptions{std::move(condition)}},
          }),
          input.output_schema};

      return ProcessEmit(filter, filter_declaration, input.output_schema);
    }

    case substrait::Rel::RelTypeCase::kProject: {
      const auto& project = rel.project();
      RETURN_NOT_OK(CheckRelCommon(project, conversion_options));
      if (!project.has_input()) {
        return Status::Invalid("substrait::ProjectRel with no input relation");
      }
      ARROW_ASSIGN_OR_RAISE(auto input,
                            FromProto(project.input(), ext_set, conversion_options));

      // NOTE: Substrait ProjectRels *append* columns, while Acero's project node replaces
      // them. Therefore, we need to prefix all the current columns for compatibility.
      std::vector<compute::Expression> expressions;
      int num_columns = input.output_schema->num_fields();
      expressions.reserve(num_columns + project.expressions().size());
      for (int i = 0; i < num_columns; i++) {
        expressions.emplace_back(compute::field_ref(FieldRef(i)));
      }

      // Index of the current projected expression; its output column is placed at
      // position `num_columns + expr_idx` of the resulting schema.
      int expr_idx = 0;
      auto project_schema = input.output_schema;
      for (const auto& expr : project.expressions()) {
        std::shared_ptr<Field> project_field;
        ARROW_ASSIGN_OR_RAISE(compute::Expression des_expr,
                              FromProto(expr, ext_set, conversion_options));
        ARROW_ASSIGN_OR_RAISE(compute::Expression bound_expr,
                              des_expr.Bind(*input.output_schema));
        if (auto* expr_call = bound_expr.call()) {
          // Calls are named after the function and typed by the bound kernel
          project_field = field(expr_call->function_name,
                                expr_call->kernel->signature->out_type().type());
        } else if (auto* field_ref = des_expr.field_ref()) {
          // References reuse the referenced input field as-is
          ARROW_ASSIGN_OR_RAISE(FieldPath field_path,
                                field_ref->FindOne(*input.output_schema));
          ARROW_ASSIGN_OR_RAISE(project_field, field_path.Get(*input.output_schema));
        } else if (auto* literal = des_expr.literal()) {
          project_field =
              field("field_" + ToChars(num_columns + expr_idx), literal->type());
        } else {
          // Without this guard a null field would be handed to Schema::AddField
          return Status::Invalid(
              "substrait::ProjectRel expression must be a call, a field reference or a "
              "literal but got ",
              des_expr.ToString());
        }
        ARROW_ASSIGN_OR_RAISE(
            project_schema,
            project_schema->AddField(num_columns + expr_idx, std::move(project_field)));
        expr_idx++;
        // `des_expr` is not used past this point, so it can be moved into the list
        expressions.emplace_back(std::move(des_expr));
      }

      DeclarationInfo project_declaration{
          acero::Declaration::Sequence({
              std::move(input.declaration),
              {"project", acero::ProjectNodeOptions{std::move(expressions)}},
          }),
          project_schema};

      return ProcessEmit(project, project_declaration, project_schema);
    }

    case substrait::Rel::RelTypeCase::kJoin: {
      const auto& join = rel.join();
      RETURN_NOT_OK(CheckRelCommon(join, conversion_options));

      if (!join.has_left()) {
        return Status::Invalid("substrait::JoinRel with no left relation");
      }

      if (!join.has_right()) {
        return Status::Invalid("substrait::JoinRel with no right relation");
      }

      acero::JoinType join_type;
      switch (join.type()) {
        case substrait::JoinRel::JOIN_TYPE_UNSPECIFIED:
          return Status::NotImplemented("Unspecified join type is not supported");
        case substrait::JoinRel::JOIN_TYPE_INNER:
          join_type = acero::JoinType::INNER;
          break;
        case substrait::JoinRel::JOIN_TYPE_OUTER:
          join_type = acero::JoinType::FULL_OUTER;
          break;
        case substrait::JoinRel::JOIN_TYPE_LEFT:
          join_type = acero::JoinType::LEFT_OUTER;
          break;
        case substrait::JoinRel::JOIN_TYPE_RIGHT:
          join_type = acero::JoinType::RIGHT_OUTER;
          break;
        case substrait::JoinRel::JOIN_TYPE_SEMI:
          join_type = acero::JoinType::LEFT_SEMI;
          break;
        case substrait::JoinRel::JOIN_TYPE_ANTI:
          join_type = acero::JoinType::LEFT_ANTI;
          break;
        default:
          return Status::Invalid("Unsupported join type");
      }

      ARROW_ASSIGN_OR_RAISE(auto left,
                            FromProto(join.left(), ext_set, conversion_options));
      ARROW_ASSIGN_OR_RAISE(auto right,
                            FromProto(join.right(), ext_set, conversion_options));

      if (!join.has_expression()) {
        return Status::Invalid("substrait::JoinRel with no expression");
      }

      ARROW_ASSIGN_OR_RAISE(auto expression,
                            FromProto(join.expression(), ext_set, conversion_options));

      const auto* callptr = expression.call();
      if (!callptr) {
        return Status::Invalid(
            "A join rel's expression must be a simple equality between keys but got ",
            expression.ToString());
      }

      acero::JoinKeyCmp join_key_cmp;
      if (callptr->function_name == "equal") {
        join_key_cmp = acero::JoinKeyCmp::EQ;
      } else if (callptr->function_name == "is_not_distinct_from") {
        join_key_cmp = acero::JoinKeyCmp::IS;
      } else {
        return Status::Invalid(
            "Only `equal` or `is_not_distinct_from` are supported for join key "
            "comparison but got ",
            callptr->function_name);
      }

      // The two key arguments are indexed below, so make sure both are present
      if (callptr->arguments.size() != 2) {
        return Status::Invalid(
            "A join rel's key comparison must have exactly two arguments but got ",
            callptr->arguments.size());
      }

      // Create output schema from left, right relations and join keys
      FieldVector combined_fields = left.output_schema->fields();
      const FieldVector& right_fields = right.output_schema->fields();
      combined_fields.insert(combined_fields.end(), right_fields.begin(),
                             right_fields.end());
      std::shared_ptr<Schema> join_schema = schema(std::move(combined_fields));

      // adjust the join_keys according to Substrait definition where
      // the join fields are defined by considering the `join_schema` which
      // is the combination of the left and right relation schema.

      // TODO: ARROW-16624 Add Suffix support for Substrait
      const auto* left_keys = callptr->arguments[0].field_ref();
      const auto* right_keys = callptr->arguments[1].field_ref();
      // Validating JoinKeys
      if (!left_keys || !right_keys) {
        return Status::Invalid(
            "join condition must include references to both left and right inputs");
      }
      int num_left_fields = left.output_schema->num_fields();
      const auto* right_field_path = right_keys->field_path();
      // The first index of the path is rewritten below, so the reference must be an
      // index path with at least one component.
      if (!right_field_path || right_field_path->indices().empty()) {
        return Status::Invalid(
            "join condition's right key must be a non-empty field path reference");
      }
      // Rebase the right key from the combined schema onto the right input's schema.
      // NOTE(review): a right key that actually points into the left input produces a
      // negative index here; presumably this is rejected downstream - confirm.
      std::vector<int> adjusted_field_indices(right_field_path->indices());
      adjusted_field_indices[0] -= num_left_fields;
      FieldPath adjusted_right_keys(adjusted_field_indices);
      // `left_keys` points to a const FieldRef owned by `expression`, so it is copied
      acero::HashJoinNodeOptions join_options{{*left_keys},
                                              {std::move(adjusted_right_keys)}};
      join_options.join_type = join_type;
      join_options.key_cmp = {join_key_cmp};
      acero::Declaration join_dec{"hashjoin", std::move(join_options)};
      join_dec.inputs.emplace_back(std::move(left.declaration));
      join_dec.inputs.emplace_back(std::move(right.declaration));

      DeclarationInfo join_declaration{std::move(join_dec), join_schema};

      return ProcessEmit(join, join_declaration, join_schema);
    }
    case substrait::Rel::RelTypeCase::kFetch: {
      const auto& fetch = rel.fetch();
      RETURN_NOT_OK(CheckRelCommon(fetch, conversion_options));

      if (!fetch.has_input()) {
        return Status::Invalid("substrait::FetchRel with no input relation");
      }

      ARROW_ASSIGN_OR_RAISE(auto input,
                            FromProto(fetch.input(), ext_set, conversion_options));

      int64_t offset = fetch.offset();
      int64_t count = fetch.count();

      acero::Declaration fetch_dec{
          "fetch", {input.declaration}, acero::FetchNodeOptions(offset, count)};

      DeclarationInfo fetch_declaration{std::move(fetch_dec), input.output_schema};
      // Pass the schema via `input` rather than reading it back from the object that
      // is being moved in the same argument list.
      return ProcessEmit(fetch, std::move(fetch_declaration), input.output_schema);
    }
    case substrait::Rel::RelTypeCase::kSort: {
      const auto& sort = rel.sort();
      RETURN_NOT_OK(CheckRelCommon(sort, conversion_options));

      if (!sort.has_input()) {
        return Status::Invalid("substrait::SortRel with no input relation");
      }

      ARROW_ASSIGN_OR_RAISE(auto input,
                            FromProto(sort.input(), ext_set, conversion_options));

      if (sort.sorts_size() == 0) {
        return Status::Invalid("substrait::SortRel with no sorts");
      }

      std::vector<compute::SortKey> sort_keys;
      sort_keys.reserve(sort.sorts_size());
      // Substrait allows null placement to differ for each field.  Acero expects it to
      // be consistent across all fields.  So we grab the null placement from the first
      // key and verify all other keys have the same null placement
      std::optional<SortBehavior> sample_sort_behavior;
      for (const auto& sort_field : sort.sorts()) {
        // Check the sort kind before looking at the direction: a field that uses a
        // custom comparison function has no direction set, and should be reported as
        // such rather than as a direction problem.
        if (sort_field.sort_kind_case() !=
            substrait::SortField::SortKindCase::kDirection) {
          return Status::NotImplemented("substrait::SortRel with custom sort function");
        }
        ARROW_ASSIGN_OR_RAISE(SortBehavior sort_behavior,
                              SortBehavior::Make(sort_field.direction()));
        if (sample_sort_behavior) {
          if (sample_sort_behavior->null_placement != sort_behavior.null_placement) {
            return Status::NotImplemented(
                "substrait::SortRel with ordering with mixed null placement");
          }
        } else {
          sample_sort_behavior = sort_behavior;
        }
        ARROW_ASSIGN_OR_RAISE(compute::Expression expr,
                              FromProto(sort_field.expr(), ext_set, conversion_options));
        const FieldRef* field_ref = expr.field_ref();
        if (field_ref) {
          sort_keys.push_back(compute::SortKey(*field_ref, sort_behavior.sort_order));
        } else {
          return Status::Invalid("Sort key expressions must be a direct reference.");
        }
      }

      // At least one sort field was processed (sorts_size() > 0 was checked above)
      DCHECK(sample_sort_behavior.has_value());
      acero::Declaration sort_dec{
          "order_by",
          {input.declaration},
          acero::OrderByNodeOptions(compute::Ordering(
              std::move(sort_keys), sample_sort_behavior->null_placement))};

      DeclarationInfo sort_declaration{std::move(sort_dec), input.output_schema};
      // Pass the schema via `input` rather than reading it back from the object that
      // is being moved in the same argument list.
      return ProcessEmit(sort, std::move(sort_declaration), input.output_schema);
    }
    case substrait::Rel::RelTypeCase::kAggregate: {
      const auto& aggregate = rel.aggregate();
      RETURN_NOT_OK(CheckRelCommon(aggregate, conversion_options));

      if (!aggregate.has_input()) {
        return Status::Invalid("substrait::AggregateRel with no input relation");
      }

      ARROW_ASSIGN_OR_RAISE(auto input,
                            FromProto(aggregate.input(), ext_set, conversion_options));

      if (aggregate.groupings_size() > 1) {
        return Status::NotImplemented(
            "Grouping sets not supported.  AggregateRel::groupings may not have more "
            "than one item");
      }

      // prepare output schema from aggregates
      auto input_schema = input.output_schema;
      std::vector<FieldRef> keys;
      if (aggregate.groupings_size() > 0) {
        const substrait::AggregateRel::Grouping& group = aggregate.groupings(0);
        int grouping_expr_size = group.grouping_expressions_size();
        keys.reserve(grouping_expr_size);
        for (int exp_id = 0; exp_id < grouping_expr_size; exp_id++) {
          ARROW_ASSIGN_OR_RAISE(
              compute::Expression expr,
              FromProto(group.grouping_expressions(exp_id), ext_set, conversion_options));
          const FieldRef* field_ref = expr.field_ref();
          if (field_ref) {
            // `field_ref` points to a const FieldRef owned by `expr`, so it is copied
            keys.emplace_back(*field_ref);
          } else {
            return Status::Invalid(
                "The grouping expression for an aggregate must be a direct reference.");
          }
        }
      }

      const int measure_size = aggregate.measures_size();
      std::vector<compute::Aggregate> aggregates;
      aggregates.reserve(measure_size);
      for (int measure_id = 0; measure_id < measure_size; measure_id++) {
        const auto& agg_measure = aggregate.measures(measure_id);
        // Hash (grouped) aggregation is requested whenever grouping keys are present
        ARROW_ASSIGN_OR_RAISE(
            auto parsed_measure,
            internal::ParseAggregateMeasure(agg_measure, ext_set, conversion_options,
                                            /*is_hash=*/!keys.empty(), input_schema));
        aggregates.push_back(std::move(parsed_measure));
      }

      ARROW_ASSIGN_OR_RAISE(auto aggregate_schema,
                            acero::aggregate::MakeOutputSchema(
                                input_schema, keys, /*segment_keys=*/{}, aggregates));

      ARROW_ASSIGN_OR_RAISE(
          auto aggregate_declaration,
          internal::MakeAggregateDeclaration(std::move(input.declaration),
                                             aggregate_schema, std::move(aggregates),
                                             std::move(keys), /*segment_keys=*/{}));

      return ProcessEmit(aggregate, std::move(aggregate_declaration),
                         std::move(aggregate_schema));
    }

    case substrait::Rel::RelTypeCase::kExtensionLeaf:
    case substrait::Rel::RelTypeCase::kExtensionSingle:
    case substrait::Rel::RelTypeCase::kExtensionMulti: {
      std::vector<DeclarationInfo> ext_rel_inputs;
      ARROW_ASSIGN_OR_RAISE(
          auto ext_decl_info,
          GetExtensionInfo(rel, ext_set, conversion_options, &ext_rel_inputs));
      auto ext_common_opt = GetExtensionRelCommon(rel);
      bool has_emit = ext_common_opt && ext_common_opt->emit_kind_case() ==
                                            substrait::RelCommon::EmitKindCase::kEmit;
      // Set up the emit order - an ordered list of indices that specifies an output
      // mapping as expected by Substrait. This is a sublist of [0..N), where N is the
      // total number of input fields across all inputs of the relation, that selects
      // from these input fields.
      if (has_emit) {
        std::vector<int> emit_order;
        // the emit order is defined in the Substrait plan - pick it up
        const auto& emit_info = ext_common_opt->emit();
        emit_order.reserve(emit_info.output_mapping_size());
        for (const auto& emit_idx : emit_info.output_mapping()) {
          emit_order.push_back(emit_idx);
        }
        return ProcessExtensionEmit(std::move(ext_decl_info), emit_order);
      } else {
        return ext_decl_info;
      }
    }

    case substrait::Rel::RelTypeCase::kSet: {
      const auto& set = rel.set();
      RETURN_NOT_OK(CheckRelCommon(set, conversion_options));

      if (set.inputs_size() < 2) {
        return Status::Invalid(
            "substrait::SetRel with inadequate number of input relations, ",
            set.inputs_size());
      }
      substrait::SetRel_SetOp op = set.op();
      // Note: at the moment Acero only supports UNION_ALL operation
      switch (op) {
        case substrait::SetRel::SET_OP_UNSPECIFIED:
        case substrait::SetRel::SET_OP_MINUS_PRIMARY:
        case substrait::SetRel::SET_OP_MINUS_MULTISET:
        case substrait::SetRel::SET_OP_INTERSECTION_PRIMARY:
        case substrait::SetRel::SET_OP_INTERSECTION_MULTISET:
        case substrait::SetRel::SET_OP_UNION_DISTINCT:
          return Status::NotImplemented(
              "NotImplemented union type : ",
              EnumToString(op, *substrait::SetRel_SetOp_descriptor()));
        case substrait::SetRel::SET_OP_UNION_ALL:
          break;
        default:
          return Status::Invalid("Unknown union type");
      }
      int input_size = set.inputs_size();
      acero::Declaration union_declr{"union", acero::ExecNodeOptions{}};
      std::shared_ptr<Schema> union_schema;
      for (int input_id = 0; input_id < input_size; input_id++) {
        ARROW_ASSIGN_OR_RAISE(
            auto input, FromProto(set.inputs(input_id), ext_set, conversion_options));
        union_declr.inputs.emplace_back(std::move(input.declaration));
        // The output schema of the union is taken from its first input
        if (union_schema == nullptr) {
          union_schema = input.output_schema;
        }
      }

      // `union_declr` is not used past this point, so it can be moved
      auto set_declaration = DeclarationInfo{std::move(union_declr), union_schema};
      return ProcessEmit(set, set_declaration, union_schema);
    }

    default:
      break;
  }

  return Status::NotImplemented(
      "conversion to arrow::acero::Declaration from Substrait relation ",
      rel.DebugString());
}