Result<Datum> FromProto()

in cpp/src/arrow/engine/substrait/expression_internal.cc [506:789]


/// \brief Convert a Substrait literal into an Arrow Datum wrapping a scalar.
///
/// \param[in] lit the Substrait literal to convert
/// \param[in] ext_set extension set used to decode user-defined literal types; it is
///            also forwarded to the nested conversions of struct fields, list
///            elements and map keys/values
/// \param[in] conversion_options conversion settings; under
///            ConversionStrictness::EXACT_ROUNDTRIP a literal with `nullable` set is
///            rejected because that flag cannot be round-tripped through Arrow
/// \return the converted Datum on success;
///         Status::Invalid for malformed input (no literal type set, a List/Map
///         literal with no entries (EmptyList/EmptyMap must be used instead),
///         elements or map keys/values of mismatched types, a decimal with the wrong
///         byte width or an invalid precision/scale, a null literal of a
///         non-nullable type);
///         Status::NotImplemented for literal kinds without an Arrow mapping here
///         (e.g. precision timestamps).
Result<Datum> FromProto(const substrait::Expression::Literal& lit,
                        const ExtensionSet& ext_set,
                        const ConversionOptions& conversion_options) {
  if (lit.nullable() &&
      conversion_options.strictness == ConversionStrictness::EXACT_ROUNDTRIP) {
    // FIXME not sure how this field should be interpreted and there's no way to round
    // trip it through arrow
    return Status::Invalid(
        "Nullable Literals - Literal.nullable must be left at the default");
  }

  switch (lit.literal_type_case()) {
    case substrait::Expression::Literal::kBoolean:
      return Datum(lit.boolean());

    case substrait::Expression::Literal::kI8:
      return Datum(static_cast<int8_t>(lit.i8()));
    case substrait::Expression::Literal::kI16:
      return Datum(static_cast<int16_t>(lit.i16()));
    case substrait::Expression::Literal::kI32:
      return Datum(static_cast<int32_t>(lit.i32()));
    case substrait::Expression::Literal::kI64:
      return Datum(static_cast<int64_t>(lit.i64()));

    case substrait::Expression::Literal::kFp32:
      return Datum(lit.fp32());
    case substrait::Expression::Literal::kFp64:
      return Datum(lit.fp64());

    case substrait::Expression::Literal::kString:
      return Datum(lit.string());
    case substrait::Expression::Literal::kBinary:
      return Datum(BinaryScalar(lit.binary()));

      ARROW_SUPPRESS_DEPRECATION_WARNING
    case substrait::Expression::Literal::kTimestamp:
      return Datum(
          TimestampScalar(static_cast<int64_t>(lit.timestamp()), TimeUnit::MICRO));

    case substrait::Expression::Literal::kTimestampTz:
      return Datum(TimestampScalar(static_cast<int64_t>(lit.timestamp_tz()),
                                   TimeUnit::MICRO, TimestampTzTimezoneString()));
      ARROW_UNSUPPRESS_DEPRECATION_WARNING
    case substrait::Expression::Literal::kPrecisionTimestamp: {
      // https://github.com/substrait-io/substrait/issues/611
      // TODO(GH-40741) don't break, return precision timestamp
      break;
    }
    case substrait::Expression::Literal::kPrecisionTimestampTz: {
      // https://github.com/substrait-io/substrait/issues/611
      // TODO(GH-40741) don't break, return precision timestamp
      break;
    }
    case substrait::Expression::Literal::kDate:
      return Datum(Date32Scalar(lit.date()));
    case substrait::Expression::Literal::kTime:
      return Datum(Time64Scalar(lit.time(), TimeUnit::MICRO));

    case substrait::Expression::Literal::kIntervalYearToMonth:
    case substrait::Expression::Literal::kIntervalDayToSecond: {
      // Both interval kinds are stored as a two-element int32 list wrapped in the
      // corresponding extension type.
      Int32Builder builder;
      std::shared_ptr<DataType> type;
      if (lit.has_interval_year_to_month()) {
        RETURN_NOT_OK(builder.Append(lit.interval_year_to_month().years()));
        RETURN_NOT_OK(builder.Append(lit.interval_year_to_month().months()));
        type = interval_year();
      } else {
        // NOTE(review): only days and seconds are read here; any sub-second
        // component carried by the literal is not represented — confirm against
        // the Substrait version in use.
        RETURN_NOT_OK(builder.Append(lit.interval_day_to_second().days()));
        RETURN_NOT_OK(builder.Append(lit.interval_day_to_second().seconds()));
        type = interval_day();
      }
      ARROW_ASSIGN_OR_RAISE(auto array, builder.Finish());
      return Datum(
          ExtensionScalar(FixedSizeListScalar(std::move(array)), std::move(type)));
    }

    case substrait::Expression::Literal::kUuid:
      return Datum(ExtensionScalar(FixedSizeBinaryScalar(lit.uuid()), uuid()));

    case substrait::Expression::Literal::kFixedChar:
      return Datum(
          ExtensionScalar(FixedSizeBinaryScalar(lit.fixed_char()),
                          fixed_char(static_cast<int32_t>(lit.fixed_char().size()))));

    case substrait::Expression::Literal::kVarChar:
      return Datum(
          ExtensionScalar(StringScalar(lit.var_char().value()),
                          varchar(static_cast<int32_t>(lit.var_char().length()))));

    case substrait::Expression::Literal::kFixedBinary:
      return Datum(FixedSizeBinaryScalar(lit.fixed_binary()));

    case substrait::Expression::Literal::kDecimal: {
      const auto& decimal = lit.decimal();
      if (decimal.value().size() != sizeof(Decimal128)) {
        return Status::Invalid("Decimal literal had ", decimal.value().size(),
                               " bytes (expected ", sizeof(Decimal128), ")");
      }

      // The wire format is little-endian; swap on big-endian hosts.
      Decimal128 value;
      std::memcpy(value.mutable_native_endian_bytes(), decimal.value().data(),
                  sizeof(Decimal128));
#if !ARROW_LITTLE_ENDIAN
      std::reverse(value.mutable_native_endian_bytes(),
                   value.mutable_native_endian_bytes() + sizeof(Decimal128));
#endif
      // Use the validating factory: precision/scale come from an external plan and
      // the plain decimal128() constructor hard-asserts on out-of-range precision.
      ARROW_ASSIGN_OR_RAISE(auto type,
                            Decimal128Type::Make(decimal.precision(), decimal.scale()));
      return Datum(Decimal128Scalar(value, std::move(type)));
    }

    case substrait::Expression::Literal::kStruct: {
      const auto& struct_ = lit.struct_();

      ScalarVector fields(struct_.fields_size());
      for (int i = 0; i < struct_.fields_size(); ++i) {
        ARROW_ASSIGN_OR_RAISE(auto field,
                              FromProto(struct_.fields(i), ext_set, conversion_options));
        DCHECK(field.is_scalar());
        fields[i] = field.scalar();
      }

      // Note that Substrait struct types don't have field names, but Arrow does, so we
      // just use empty strings for them.
      std::vector<std::string> field_names(fields.size(), "");

      ARROW_ASSIGN_OR_RAISE(
          auto scalar, StructScalar::Make(std::move(fields), std::move(field_names)));
      return Datum(std::move(scalar));
    }

    case substrait::Expression::Literal::kList: {
      const auto& list = lit.list();
      if (list.values_size() == 0) {
        return Status::Invalid(
            "substrait::Expression::Literal::List had no values; should have been an "
            "substrait::Expression::Literal::EmptyList");
      }

      // The element type is taken from the first value; all others must match it.
      std::shared_ptr<DataType> element_type;

      ScalarVector values(list.values_size());
      for (int i = 0; i < list.values_size(); ++i) {
        ARROW_ASSIGN_OR_RAISE(auto value,
                              FromProto(list.values(i), ext_set, conversion_options));
        DCHECK(value.is_scalar());
        values[i] = value.scalar();
        if (element_type) {
          if (!value.type()->Equals(*element_type)) {
            return Status::Invalid(
                list.DebugString(),
                " has a value whose type doesn't match the other list values");
          }
        } else {
          element_type = value.type();
        }
      }

      ARROW_ASSIGN_OR_RAISE(auto builder, MakeBuilder(element_type));
      RETURN_NOT_OK(builder->AppendScalars(values));
      ARROW_ASSIGN_OR_RAISE(auto arr, builder->Finish());
      return Datum(ListScalar(std::move(arr)));
    }

    case substrait::Expression::Literal::kMap: {
      const auto& map = lit.map();
      if (map.key_values_size() == 0) {
        return Status::Invalid(
            "substrait::Expression::Literal::Map had no values; should have been an "
            "substrait::Expression::Literal::EmptyMap");
      }

      // Key and value types are taken from the first entry; all others must match.
      std::shared_ptr<DataType> key_type, value_type;
      ScalarVector keys(map.key_values_size()), values(map.key_values_size());
      for (int i = 0; i < map.key_values_size(); ++i) {
        const auto& kv = map.key_values(i);

        // Indexed by has_key() + 2 * has_value(): the entry names what is missing,
        // and nullptr means both key and value are present.
        static const std::array<char const*, 4> kMissing = {"key and value", "value",
                                                            "key", nullptr};
        if (auto missing = kMissing[kv.has_key() + kv.has_value() * 2]) {
          return Status::Invalid("While converting to MapScalar encountered missing ",
                                 missing, " in ", map.DebugString());
        }
        ARROW_ASSIGN_OR_RAISE(auto key, FromProto(kv.key(), ext_set, conversion_options));
        ARROW_ASSIGN_OR_RAISE(auto value,
                              FromProto(kv.value(), ext_set, conversion_options));

        DCHECK(key.is_scalar());
        DCHECK(value.is_scalar());

        keys[i] = key.scalar();
        values[i] = value.scalar();

        if (key_type) {
          if (!key.type()->Equals(*key_type)) {
            return Status::Invalid(map.DebugString(),
                                   " has a key whose type doesn't match key_type");
          }
        } else {
          // Record the *key's* type; subsequent keys and the key builder use it.
          key_type = key.type();
        }

        if (value_type) {
          if (!value.type()->Equals(*value_type)) {
            return Status::Invalid(map.DebugString(),
                                   " has a value whose type doesn't match value_type");
          }
        } else {
          value_type = value.type();
        }
      }

      ARROW_ASSIGN_OR_RAISE(auto key_builder, MakeBuilder(key_type));
      ARROW_ASSIGN_OR_RAISE(auto value_builder, MakeBuilder(value_type));
      RETURN_NOT_OK(key_builder->AppendScalars(keys));
      RETURN_NOT_OK(value_builder->AppendScalars(values));
      ARROW_ASSIGN_OR_RAISE(auto key_arr, key_builder->Finish());
      ARROW_ASSIGN_OR_RAISE(auto value_arr, value_builder->Finish());
      ARROW_ASSIGN_OR_RAISE(
          auto kv_arr,
          StructArray::Make(ArrayVector{std::move(key_arr), std::move(value_arr)},
                            std::vector<std::string>{"key", "value"}));
      return Datum(std::make_shared<MapScalar>(std::move(kv_arr)));
    }

    case substrait::Expression::Literal::kEmptyList: {
      ARROW_ASSIGN_OR_RAISE(auto type_nullable, FromProto(lit.empty_list().type(),
                                                          ext_set, conversion_options));
      ARROW_ASSIGN_OR_RAISE(auto values, MakeEmptyArray(type_nullable.first));
      return Datum(ListScalar(std::move(values)));
    }

    case substrait::Expression::Literal::kEmptyMap: {
      ARROW_ASSIGN_OR_RAISE(
          auto key_type_nullable,
          FromProto(lit.empty_map().key(), ext_set, conversion_options));
      ARROW_ASSIGN_OR_RAISE(auto keys,
                            MakeEmptyArray(std::move(key_type_nullable.first)));

      ARROW_ASSIGN_OR_RAISE(
          auto value_type_nullable,
          FromProto(lit.empty_map().value(), ext_set, conversion_options));
      ARROW_ASSIGN_OR_RAISE(auto values,
                            MakeEmptyArray(std::move(value_type_nullable.first)));

      // Build the entries struct with the field definitions MapType uses, so the
      // resulting MapScalar's type matches map(key_type, value_type).
      auto map_type = std::make_shared<MapType>(keys->type(), values->type());
      ARROW_ASSIGN_OR_RAISE(
          auto key_values,
          StructArray::Make(
              {std::move(keys), std::move(values)},
              checked_cast<const ListType&>(*map_type).value_type()->fields()));

      return Datum(MapScalar(std::move(key_values)));
    }

    case substrait::Expression::Literal::kNull: {
      ARROW_ASSIGN_OR_RAISE(auto type_nullable,
                            FromProto(lit.null(), ext_set, conversion_options));
      if (!type_nullable.second) {
        return Status::Invalid("Substrait null literal ", lit.DebugString(),
                               " is of non-nullable type");
      }

      return Datum(MakeNullScalar(std::move(type_nullable.first)));
    }

    case substrait::Expression::Literal::kUserDefined: {
      const auto& user_defined = lit.user_defined();
      ARROW_ASSIGN_OR_RAISE(auto type_record,
                            ext_set.DecodeType(user_defined.type_reference()));
      UserDefinedLiteralToArrow visitor{nullptr, &user_defined, &ext_set,
                                        conversion_options};
      ARROW_RETURN_NOT_OK((visitor)(*type_record.type));
      return Datum(std::move(visitor.scalar_));
    }

    case substrait::Expression::Literal::LITERAL_TYPE_NOT_SET:
      return Status::Invalid("substrait literal did not have any literal type set");

    default:
      break;
  }

  return Status::NotImplemented("conversion to arrow::Datum from Substrait literal `",
                                lit.DebugString(), "`");
}