std::shared_ptr<Array> RandomArrayGenerator::ArrayOf()

in cpp/src/arrow/testing/random.cc [993:1351]


// Generate a random array of `length` elements whose type is `field.type()`.
//
// Generation is tuned through the field's key-value metadata. The keys read by
// this function are: "null_probability", "true_probability", "min", "max",
// "nan_probability", "min_length", "max_length", "unique",
// "max_data_buffer_length", "force_empty_nulls", "values" and
// "extension_allow_random_storage". Which keys apply depends on the type; see the
// individual cases below.
//
// Non-nullable fields always use a null probability of 0; nullable fields default
// to 0.01 unless "null_probability" is given.
//
// `alignment` and `memory_pool` are forwarded to the underlying generators.
//
// Invalid parameters and unsupported types are reported through ABORT_NOT_OK
// rather than through a returned Status.
std::shared_ptr<Array> RandomArrayGenerator::ArrayOf(const Field& field, int64_t length,
                                                     int64_t alignment,
                                                     MemoryPool* memory_pool) {
  // Abort with Status::Invalid unless MIN <= PARAM <= MAX.
#define VALIDATE_RANGE(PARAM, MIN, MAX)                                          \
  if (PARAM < MIN || PARAM > MAX) {                                              \
    ABORT_NOT_OK(Status::Invalid(field.ToString(), ": ", ARROW_STRINGIFY(PARAM), \
                                 " must be in [", MIN, ", ", MAX, " ] but got ", \
                                 PARAM));                                        \
  }
  // Abort with Status::Invalid unless MIN <= MAX.
#define VALIDATE_MIN_MAX(MIN, MAX)                                                  \
  if (MIN > MAX) {                                                                  \
    ABORT_NOT_OK(                                                                   \
        Status::Invalid(field.ToString(), ": min ", MIN, " must be <= max ", MAX)); \
  }
  // Generate values of BASE_TYPE within the "min"/"max" metadata bounds (defaulting
  // to the full range of BASE_TYPE's C type) and view the result as field.type().
#define GENERATE_INTEGRAL_CASE_VIEW(BASE_TYPE, VIEW_TYPE)                              \
  case VIEW_TYPE::type_id: {                                                           \
    const BASE_TYPE::c_type min_value = GetMetadata<BASE_TYPE::c_type>(                \
        field.metadata().get(), "min", std::numeric_limits<BASE_TYPE::c_type>::min()); \
    const BASE_TYPE::c_type max_value = GetMetadata<BASE_TYPE::c_type>(                \
        field.metadata().get(), "max", std::numeric_limits<BASE_TYPE::c_type>::max()); \
    VALIDATE_MIN_MAX(min_value, max_value);                                            \
    return *Numeric<BASE_TYPE>(length, min_value, max_value, null_probability,         \
                               alignment, memory_pool)                                 \
                ->View(field.type());                                                  \
  }
#define GENERATE_INTEGRAL_CASE(ARROW_TYPE) \
  GENERATE_INTEGRAL_CASE_VIEW(ARROW_TYPE, ARROW_TYPE)
  // Generate floating-point values through GENERATOR_FUNC. Note that for
  // floating-point types std::numeric_limits<T>::min() is the smallest positive
  // normalized value (not the most negative one), so the default range is positive.
#define GENERATE_FLOATING_CASE(ARROW_TYPE, GENERATOR_FUNC)                              \
  case ARROW_TYPE::type_id: {                                                           \
    const ARROW_TYPE::c_type min_value = GetMetadata<ARROW_TYPE::c_type>(               \
        field.metadata().get(), "min", std::numeric_limits<ARROW_TYPE::c_type>::min()); \
    const ARROW_TYPE::c_type max_value = GetMetadata<ARROW_TYPE::c_type>(               \
        field.metadata().get(), "max", std::numeric_limits<ARROW_TYPE::c_type>::max()); \
    const double nan_probability =                                                      \
        GetMetadata<double>(field.metadata().get(), "nan_probability", 0);              \
    VALIDATE_MIN_MAX(min_value, max_value);                                             \
    VALIDATE_RANGE(nan_probability, 0.0, 1.0);                                          \
    return GENERATOR_FUNC(length, min_value, max_value, null_probability,               \
                          nan_probability, alignment, memory_pool);                     \
  }

  // Generate a variable-size list array: draw one random length per list slot, sum
  // the non-null lengths to size the child array, then derive offsets from the
  // lengths.
  // Don't use compute::Sum since that may not get built
#define GENERATE_LIST_CASE(ARRAY_TYPE)                                               \
  case ARRAY_TYPE::TypeClass::type_id: {                                             \
    const auto min_length = GetMetadata<ARRAY_TYPE::TypeClass::offset_type>(         \
        field.metadata().get(), "min_length", 0);                                    \
    const auto max_length = GetMetadata<ARRAY_TYPE::TypeClass::offset_type>(         \
        field.metadata().get(), "max_length", 20);                                   \
    const auto lengths = internal::checked_pointer_cast<                             \
        CTypeTraits<ARRAY_TYPE::TypeClass::offset_type>::ArrayType>(                 \
        Numeric<CTypeTraits<ARRAY_TYPE::TypeClass::offset_type>::ArrowType>(         \
            length, min_length, max_length, null_probability, alignment,             \
            memory_pool));                                                           \
    int64_t values_length = 0;                                                       \
    for (const auto& list_length : *lengths) {                                       \
      if (list_length.has_value()) values_length += *list_length;                    \
    }                                                                                \
    const auto force_empty_nulls =                                                   \
        GetMetadata<bool>(field.metadata().get(), "force_empty_nulls", false);       \
    const auto values =                                                              \
        ArrayOf(*internal::checked_pointer_cast<ARRAY_TYPE::TypeClass>(field.type()) \
                     ->value_field(),                                                \
                values_length, alignment, memory_pool);                              \
    const auto offsets = OffsetsFromLengthsArray(lengths.get(), force_empty_nulls,   \
                                                 alignment, memory_pool);            \
    return *ARRAY_TYPE::FromArrays(field.type(), *offsets, *values);                 \
  }

  // List-view arrays are delegated to the ArrayOfListView helper.
#define GENERATE_LIST_VIEW_CASE(ARRAY_TYPE)                                           \
  case ARRAY_TYPE::TypeClass::type_id: {                                              \
    return *ArrayOfListView<ARRAY_TYPE>(*this, field, length, alignment, memory_pool, \
                                        null_probability);                            \
  }

  // Non-nullable fields never get nulls, regardless of metadata.
  const double null_probability =
      field.nullable()
          ? GetMetadata<double>(field.metadata().get(), "null_probability", 0.01)
          : 0.0;
  VALIDATE_RANGE(null_probability, 0.0, 1.0);
  switch (field.type()->id()) {
    case Type::type::NA: {
      return std::make_shared<NullArray>(length);
    }

    case Type::type::BOOL: {
      const double true_probability =
          GetMetadata<double>(field.metadata().get(), "true_probability", 0.5);
      return Boolean(length, true_probability, null_probability, alignment, memory_pool);
    }

      GENERATE_INTEGRAL_CASE(UInt8Type);
      GENERATE_INTEGRAL_CASE(Int8Type);
      GENERATE_INTEGRAL_CASE(UInt16Type);
      GENERATE_INTEGRAL_CASE(Int16Type);
      GENERATE_INTEGRAL_CASE(UInt32Type);
      GENERATE_INTEGRAL_CASE(Int32Type);
      GENERATE_INTEGRAL_CASE(UInt64Type);
      GENERATE_INTEGRAL_CASE(Int64Type);
      // Half floats are generated as Int16 values and viewed as HalfFloat.
      GENERATE_INTEGRAL_CASE_VIEW(Int16Type, HalfFloatType);
      GENERATE_FLOATING_CASE(FloatType, Float32);
      GENERATE_FLOATING_CASE(DoubleType, Float64);

    case Type::type::STRING:
    case Type::type::BINARY: {
      const auto min_length =
          GetMetadata<int32_t>(field.metadata().get(), "min_length", 0);
      const auto max_length =
          GetMetadata<int32_t>(field.metadata().get(), "max_length", 20);
      // A positive "unique" requests a fixed number of distinct values.
      const auto unique_values =
          GetMetadata<int32_t>(field.metadata().get(), "unique", -1);
      if (unique_values > 0) {
        return *StringWithRepeats(length, unique_values, min_length, max_length,
                                  null_probability, alignment, memory_pool)
                    ->View(field.type());
      }
      return *String(length, min_length, max_length, null_probability, alignment,
                     memory_pool)
                  ->View(field.type());
    }

    case Type::type::STRING_VIEW:
    case Type::type::BINARY_VIEW: {
      const auto min_length =
          GetMetadata<int32_t>(field.metadata().get(), "min_length", 0);
      const auto max_length =
          GetMetadata<int32_t>(field.metadata().get(), "max_length", 20);
      // A value of 0 (also the default when the key is absent) means "unspecified":
      // leave the optional empty so that StringView() is not handed an explicit
      // buffer length of 0.
      const auto requested_data_buffer_length =
          GetMetadata<int64_t>(field.metadata().get(), "max_data_buffer_length", 0);
      std::optional<int64_t> max_data_buffer_length;
      if (requested_data_buffer_length != 0) {
        max_data_buffer_length = requested_data_buffer_length;
      }

      // NOTE(review): memory_pool is not forwarded here, unlike the sibling string
      // cases -- confirm whether StringView() accepts it.
      return StringView(length, min_length, max_length, null_probability,
                        max_data_buffer_length, alignment)
          ->View(field.type())
          .ValueOrDie();
    }

    case Type::type::DECIMAL32:
      return Decimal32(field.type(), length, null_probability, alignment, memory_pool);

    case Type::type::DECIMAL64:
      return Decimal64(field.type(), length, null_probability, alignment, memory_pool);

    case Type::type::DECIMAL128:
      return Decimal128(field.type(), length, null_probability, alignment, memory_pool);

    case Type::type::DECIMAL256:
      return Decimal256(field.type(), length, null_probability, alignment, memory_pool);

    case Type::type::FIXED_SIZE_BINARY: {
      auto byte_width =
          internal::checked_pointer_cast<FixedSizeBinaryType>(field.type())->byte_width();
      return *FixedSizeBinary(length, byte_width, null_probability,
                              /*min_byte=*/static_cast<uint8_t>('A'),
                              /*max_byte=*/static_cast<uint8_t>('z'), alignment,
                              memory_pool)
                  ->View(field.type());
    }

      GENERATE_INTEGRAL_CASE_VIEW(Int32Type, Date32Type);
      GENERATE_INTEGRAL_CASE_VIEW(Int64Type, TimestampType);
      GENERATE_INTEGRAL_CASE_VIEW(Int32Type, MonthIntervalType);

    case Type::type::DATE64: {
      using c_type = typename Date64Type::c_type;
      constexpr c_type kFullDayMillis = 1000 * 60 * 60 * 24;
      // The default bounds are the c_type limits divided by the number of
      // milliseconds in a day (presumably the Date64 generator scales values back
      // up to whole days -- confirm against Numeric<Date64Type>).
      constexpr c_type kDefaultMin = std::numeric_limits<c_type>::min() / kFullDayMillis;
      constexpr c_type kDefaultMax = std::numeric_limits<c_type>::max() / kFullDayMillis;

      const c_type min_value =
          GetMetadata<c_type>(field.metadata().get(), "min", kDefaultMin);
      const c_type max_value =
          GetMetadata<c_type>(field.metadata().get(), "max", kDefaultMax);
      // Same sanity check as the other "min"/"max" driven cases.
      VALIDATE_MIN_MAX(min_value, max_value);

      return *Numeric<Date64Type>(length, min_value, max_value, null_probability,
                                  alignment, memory_pool)
                  ->View(field.type());
    }

    case Type::type::TIME32: {
      TimeUnit::type unit =
          internal::checked_pointer_cast<Time32Type>(field.type())->unit();
      using c_type = typename Time32Type::c_type;
      // Values stay within a single day: [0, seconds-per-day) for SECOND,
      // [0, milliseconds-per-day) otherwise.
      const c_type min_value = 0;
      const c_type max_value =
          (unit == TimeUnit::SECOND) ? (60 * 60 * 24 - 1) : (1000 * 60 * 60 * 24 - 1);

      return *Numeric<Int32Type>(length, min_value, max_value, null_probability,
                                 alignment, memory_pool)
                  ->View(field.type());
    }

    case Type::type::TIME64: {
      TimeUnit::type unit =
          internal::checked_pointer_cast<Time64Type>(field.type())->unit();
      using c_type = typename Time64Type::c_type;
      // Values stay within a single day: [0, microseconds-per-day) for MICRO,
      // [0, nanoseconds-per-day) otherwise.
      const c_type min_value = 0;
      const c_type max_value = (unit == TimeUnit::MICRO)
                                   ? (1000000LL * 60 * 60 * 24 - 1)
                                   : (1000000000LL * 60 * 60 * 24 - 1);

      return *Numeric<Int64Type>(length, min_value, max_value, null_probability,
                                 alignment, memory_pool)
                  ->View(field.type());
    }

      // This isn't as flexible as it could be, but the array-of-structs layout of this
      // type means it's not a (useful) composition of other generators
      GENERATE_INTEGRAL_CASE_VIEW(Int64Type, DayTimeIntervalType);
    case Type::type::INTERVAL_MONTH_DAY_NANO: {
      // Generated as 16-byte fixed-size binary values and viewed as intervals.
      return *FixedSizeBinary(length, /*byte_width=*/16, null_probability,
                              /*min_byte=*/static_cast<uint8_t>('A'),
                              /*max_byte=*/static_cast<uint8_t>('z'), alignment,
                              memory_pool)
                  ->View(month_day_nano_interval());
    }

      GENERATE_LIST_CASE(ListArray);
      GENERATE_LIST_VIEW_CASE(ListViewArray);

    case Type::type::STRUCT: {
      // Every child is generated with the parent's length; the struct gets its own
      // validity bitmap on top.
      ArrayVector child_arrays(field.type()->num_fields());
      FieldVector child_fields(field.type()->num_fields());
      for (int i = 0; i < field.type()->num_fields(); i++) {
        const auto& child_field = field.type()->field(i);
        child_arrays[i] = ArrayOf(*child_field, length, alignment, memory_pool);
        child_fields[i] = child_field;
      }
      return *StructArray::Make(
          child_arrays, child_fields,
          NullBitmap(length, null_probability, alignment, memory_pool));
    }

    case Type::type::RUN_END_ENCODED: {
      auto* ree_type = internal::checked_cast<RunEndEncodedType*>(field.type().get());
      // NOTE(review): alignment and memory_pool are not forwarded here -- confirm
      // whether RunEndEncoded() accepts them.
      return RunEndEncoded(ree_type->value_type(), length, null_probability);
    }

    case Type::type::SPARSE_UNION:
    case Type::type::DENSE_UNION: {
      ArrayVector child_arrays(field.type()->num_fields());
      for (int i = 0; i < field.type()->num_fields(); ++i) {
        const auto& child_field = field.type()->field(i);
        child_arrays[i] = ArrayOf(*child_field, length, alignment, memory_pool);
      }
      auto array = field.type()->id() == Type::type::SPARSE_UNION
                       ? SparseUnion(child_arrays, length, alignment, memory_pool)
                       : DenseUnion(child_arrays, length, alignment, memory_pool);

      const auto& type_codes = checked_cast<const UnionType&>(*field.type()).type_codes();
      const auto& default_type_codes =
          checked_cast<const UnionType&>(*array->type()).type_codes();

      if (type_codes != default_type_codes) {
        // map to the type ids specified by the UnionType
        auto* type_ids =
            reinterpret_cast<int8_t*>(array->data()->buffers[1]->mutable_data());
        for (int64_t i = 0; i != array->length(); ++i) {
          type_ids[i] = type_codes[type_ids[i]];
        }
      }
      return *array->View(field.type());  // view gets the field names right for us
    }

    case Type::type::DICTIONARY: {
      // "values" is the number of dictionary entries to generate.
      const auto values_length =
          GetMetadata<int64_t>(field.metadata().get(), "values", 4);
      auto dict_type = internal::checked_pointer_cast<DictionaryType>(field.type());
      // TODO: no way to control generation of dictionary
      auto values =
          ArrayOf(*arrow::field("temporary", dict_type->value_type(), /*nullable=*/false),
                  values_length, alignment, memory_pool);
      // "min"/"max" are reserved here: they are injected below to keep the generated
      // indices within [0, values_length - 1].
      auto merged = field.metadata() ? field.metadata() : key_value_metadata({}, {});
      if (merged->Contains("min"))
        ABORT_NOT_OK(Status::Invalid(field.ToString(), ": cannot specify min"));
      if (merged->Contains("max"))
        ABORT_NOT_OK(Status::Invalid(field.ToString(), ": cannot specify max"));
      merged = merged->Merge(
          *key_value_metadata({{"min", "0"}, {"max", ToChars(values_length - 1)}}));
      auto indices = ArrayOf(
          *arrow::field("temporary", dict_type->index_type(), field.nullable(), merged),
          length, alignment, memory_pool);
      return *DictionaryArray::FromArrays(field.type(), indices, values);
    }

    case Type::type::MAP: {
      // "values" is the total number of key/item pairs; defaults to `length`.
      const auto values_length = GetMetadata<int32_t>(field.metadata().get(), "values",
                                                      static_cast<int32_t>(length));
      const auto force_empty_nulls =
          GetMetadata<bool>(field.metadata().get(), "force_empty_nulls", false);
      auto map_type = internal::checked_pointer_cast<MapType>(field.type());
      auto keys = ArrayOf(*map_type->key_field(), values_length, alignment, memory_pool);
      auto items =
          ArrayOf(*map_type->item_field(), values_length, alignment, memory_pool);
      // need N + 1 offsets to have N values
      auto offsets = Offsets(length + 1, 0, values_length, null_probability,
                             force_empty_nulls, alignment, memory_pool);
      return *MapArray::FromArrays(map_type, offsets, keys, items);
    }

    case Type::type::EXTENSION:
      if (GetMetadata<bool>(field.metadata().get(), "extension_allow_random_storage",
                            false)) {
        // Generate the storage array for the same field and wrap it.
        const auto& ext_type = checked_cast<const ExtensionType&>(*field.type());
        auto storage = ArrayOf(*field.WithType(ext_type.storage_type()), length,
                               alignment, memory_pool);
        return ExtensionType::WrapArray(field.type(), storage);
      }
      // We don't have explicit permission to generate random storage; bail rather than
      // silently risk breaking extension invariants
      break;

    case Type::type::FIXED_SIZE_LIST: {
      auto list_type = internal::checked_pointer_cast<FixedSizeListType>(field.type());
      // Each of the `length` lists holds exactly list_size() child values.
      const int64_t values_length = list_type->list_size() * length;
      auto values =
          ArrayOf(*list_type->value_field(), values_length, alignment, memory_pool);
      auto null_bitmap = NullBitmap(length, null_probability, alignment, memory_pool);
      return std::make_shared<FixedSizeListArray>(list_type, length, values, null_bitmap);
    }

      GENERATE_INTEGRAL_CASE_VIEW(Int64Type, DurationType);

    case Type::type::LARGE_STRING:
    case Type::type::LARGE_BINARY: {
      const auto min_length =
          GetMetadata<int32_t>(field.metadata().get(), "min_length", 0);
      const auto max_length =
          GetMetadata<int32_t>(field.metadata().get(), "max_length", 20);
      const auto unique_values =
          GetMetadata<int32_t>(field.metadata().get(), "unique", -1);
      if (unique_values > 0) {
        ABORT_NOT_OK(
            Status::NotImplemented("Generating random array with repeated values for "
                                   "large string/large binary types"));
      }
      return *LargeString(length, min_length, max_length, null_probability, alignment,
                          memory_pool)
                  ->View(field.type());
    }

      GENERATE_LIST_CASE(LargeListArray);
      GENERATE_LIST_VIEW_CASE(LargeListViewArray);

    default:
      break;
  }
#undef GENERATE_INTEGRAL_CASE_VIEW
#undef GENERATE_INTEGRAL_CASE
#undef GENERATE_FLOATING_CASE
#undef GENERATE_LIST_CASE
#undef GENERATE_LIST_VIEW_CASE
#undef VALIDATE_RANGE
#undef VALIDATE_MIN_MAX

  // Reached for unsupported types and for extension types without the opt-in flag.
  ABORT_NOT_OK(
      Status::NotImplemented("Generating random array for field ", field.ToString()));
  return nullptr;
}