c/validation/adbc_validation_util.h (384 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// Utilities for testing with Nanoarrow.
#pragma once
#include <cstring>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#include <arrow-adbc/adbc.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <nanoarrow/nanoarrow.h>
#include "common/utils.h"
namespace adbc_validation {
// ------------------------------------------------------------
// ADBC helpers
/// \brief Fetch the option named `option` from `connection` as a string.
///
/// Declaration only; the definition is not in this header.
/// NOTE(review): presumably returns std::nullopt when the option cannot be
/// read and reports details through `error` -- confirm against the definition.
std::optional<std::string> ConnectionGetOption(struct AdbcConnection* connection,
std::string_view option,
struct AdbcError* error);
/// \brief Fetch the option named `option` from `statement` as a string.
///
/// Declaration only; the definition is not in this header.
/// NOTE(review): presumably returns std::nullopt when the option cannot be
/// read and reports details through `error` -- confirm against the definition.
std::optional<std::string> StatementGetOption(struct AdbcStatement* statement,
std::string_view option,
struct AdbcError* error);
// ------------------------------------------------------------
// Helpers to print values
/// \brief Render an ADBC status code as text (used in failure messages below).
std::string StatusCodeToString(AdbcStatusCode code);
/// \brief Render an AdbcError, an ArrowError, or an ArrowArrayStream as text
/// for diagnostics.
///
/// NOTE(review): these overloads are defined elsewhere; whether they release
/// or otherwise consume their argument is not visible here -- confirm before
/// reusing `error`/`stream` after the call.
std::string ToString(struct AdbcError* error);
std::string ToString(struct ArrowError* error);
std::string ToString(struct ArrowArrayStream* stream);
// ------------------------------------------------------------
// Nanoarrow helpers
// Readable spellings for a boolean nullability flag.
// NOTE(review): presumably meant for the `nullable` argument of SchemaField's
// constructor -- confirm at the call sites.
#define NULLABLE true
#define NOT_NULL false
// ------------------------------------------------------------
// Helper to manage C Data Interface/Nanoarrow resources with RAII
/// \brief Default initialization policy used by Handle: zero-fill the struct.
///
/// Specialize this template for resource types that need a dedicated
/// initialization routine.
template <typename T>
struct Initializer {
  /// \param value Object to overwrite with sizeof(T) zero bytes.
  // <cstring> only guarantees the std::-qualified name, so spell it that way
  // (StreamReader in this header already does).
  static void Initialize(T* value) { std::memset(value, 0, sizeof(T)); }
};
/// \brief Default release policy used by Handle: invoke the struct's own
/// `release` callback when one is set.
template <typename T>
struct Releaser {
  /// \param value Object whose `release` member is called with `value` itself;
  ///   nothing happens when the callback is null.
  static void Release(T* value) {
    if (!value->release) return;
    value->release(value);
  }
};
/// \brief Initialization policy for ArrowBuffer: delegate to ArrowBufferInit
/// instead of zero-filling.
template <>
struct Initializer<struct ArrowBuffer> {
  static void Initialize(struct ArrowBuffer* buffer) {
    ArrowBufferInit(buffer);
  }
};
/// \brief Release policy for ArrowBuffer: delegate to ArrowBufferReset
/// (ArrowBuffer has no `release` callback for the generic policy to call).
template <>
struct Releaser<struct ArrowBuffer> {
  static void Release(struct ArrowBuffer* value) {
    ArrowBufferReset(value);
  }
};
/// \brief Release policy for ArrowArrayView: reset the view unless its storage
/// type is still NANOARROW_TYPE_UNINITIALIZED.
template <>
struct Releaser<struct ArrowArrayView> {
  static void Release(struct ArrowArrayView* value) {
    const bool was_initialized = value->storage_type != NANOARROW_TYPE_UNINITIALIZED;
    if (!was_initialized) return;
    ArrowArrayViewReset(value);
  }
};
template <>
struct Releaser<struct AdbcConnection> {
static void Release(struct AdbcConnection* value) {
if (value->private_data) {
struct AdbcError error = {};
auto status = AdbcConnectionRelease(value, &error);
if (status != ADBC_STATUS_OK) {
FAIL() << StatusCodeToString(status) << ": " << ToString(&error);
}
}
}
};
template <>
struct Releaser<struct AdbcDatabase> {
static void Release(struct AdbcDatabase* value) {
if (value->private_data) {
struct AdbcError error = {};
auto status = AdbcDatabaseRelease(value, &error);
if (status != ADBC_STATUS_OK) {
FAIL() << StatusCodeToString(status) << ": " << ToString(&error);
}
}
}
};
template <>
struct Releaser<struct AdbcStatement> {
static void Release(struct AdbcStatement* value) {
if (value->private_data) {
struct AdbcError error = {};
auto status = AdbcStatementRelease(value, &error);
if (status != ADBC_STATUS_OK) {
FAIL() << StatusCodeToString(status) << ": " << ToString(&error);
}
}
}
};
/// \brief Scope guard around a C resource struct.
///
/// The constructor runs Initializer<Resource>::Initialize on the embedded
/// `value`; the destructor runs Releaser<Resource>::Release on it.
///
/// Copying is disabled: a copy would duplicate the raw struct bit-for-bit and
/// both destructors would then release the same underlying resource.
template <typename Resource>
struct Handle {
  /// The wrapped resource; pass `&handle.value` to C APIs.
  Resource value;

  Handle() { Initializer<Resource>::Initialize(&value); }
  ~Handle() { Releaser<Resource>::Release(&value); }

  Handle(const Handle&) = delete;
  Handle& operator=(const Handle&) = delete;

  /// Member access on the wrapped resource.
  Resource* operator->() { return &value; }
};
// ------------------------------------------------------------
// GTest/GMock helpers
// Evaluate EXPR exactly once; if the result differs from ADBC_STATUS_OK,
// return that result from the enclosing function.
//
// Despite the name of the internal variable, this header also applies the
// macro to int-returning Nanoarrow calls (see MakeArray), so the enclosing
// function must return a type the result converts to.
#define CHECK_OK(EXPR) \
do { \
if (auto adbc_status = (EXPR); adbc_status != ADBC_STATUS_OK) { \
return adbc_status; \
} \
} while (false)
/// \brief A GTest matcher for Nanoarrow/C Data Interface error codes.
///
/// Constructed with the expected errno-style code plus a stream pointer and
/// an ArrowError pointer, both of which are stored.
/// NOTE(review): the member functions are defined out of line; presumably the
/// stored stream/error are only consulted to explain a mismatch -- confirm.
class IsErrno {
public:
// Marks this class as a GoogleTest class-based matcher.
using is_gtest_matcher = void;
explicit IsErrno(int expected, struct ArrowArrayStream* stream,
struct ArrowError* error);
// Matcher protocol: test `errcode`, writing an explanation to `os`.
bool MatchAndExplain(int errcode, std::ostream* os) const;
void DescribeTo(std::ostream* os) const;
void DescribeNegationTo(std::ostream* os) const;
private:
int expected_;
struct ArrowArrayStream* stream_;
struct ArrowError* error_;
};
/// \brief Matchers for an errno-style int result.
///
/// NOTE(review): defined out of line; presumably each expects a code of 0 and
/// the stream/error overloads add diagnostic text on mismatch -- confirm.
::testing::Matcher<int> IsOkErrno();
::testing::Matcher<int> IsOkErrno(Handle<struct ArrowArrayStream>* stream);
::testing::Matcher<int> IsOkErrno(struct ArrowError* error);
/// \brief A GTest matcher for ADBC status codes
///
/// Constructed with the expected status and an AdbcError pointer, both of
/// which are stored.
/// NOTE(review): the member functions are defined out of line; presumably the
/// stored error is only consulted to explain a mismatch -- confirm.
class IsAdbcStatusCode {
public:
// Marks this class as a GoogleTest class-based matcher.
using is_gtest_matcher = void;
explicit IsAdbcStatusCode(AdbcStatusCode expected, struct AdbcError* error);
// Matcher protocol: test `actual`, writing an explanation to `os`.
bool MatchAndExplain(AdbcStatusCode actual, std::ostream* os) const;
void DescribeTo(std::ostream* os) const;
void DescribeNegationTo(std::ostream* os) const;
private:
AdbcStatusCode expected_;
struct AdbcError* error_;
};
/// \brief Matchers for an AdbcStatusCode result; `error` defaults to nullptr.
///
/// NOTE(review): defined out of line; presumably IsOkStatus expects
/// ADBC_STATUS_OK and IsStatus expects `code` -- confirm.
::testing::Matcher<AdbcStatusCode> IsOkStatus(struct AdbcError* error = nullptr);
::testing::Matcher<AdbcStatusCode> IsStatus(AdbcStatusCode code,
struct AdbcError* error = nullptr);
/// \brief Read an ArrowArrayStream with RAII safety
struct StreamReader {
Handle<struct ArrowArrayStream> stream;
Handle<struct ArrowSchema> schema;
Handle<struct ArrowArray> array;
Handle<struct ArrowArrayView> array_view;
std::vector<struct ArrowSchemaView> fields;
struct ArrowError na_error;
int64_t rows_affected = 0;
StreamReader() { std::memset(&na_error, 0, sizeof(na_error)); }
void GetSchema() {
ASSERT_NE(nullptr, stream->release);
ASSERT_THAT(stream->get_schema(&stream.value, &schema.value), IsOkErrno(&stream));
fields.resize(schema->n_children);
for (int64_t i = 0; i < schema->n_children; i++) {
ASSERT_THAT(ArrowSchemaViewInit(&fields[i], schema->children[i], &na_error),
IsOkErrno(&na_error));
}
}
void Next() { ASSERT_THAT(MaybeNext(), IsErrno(0, &stream.value, &na_error)); }
int MaybeNext() {
if (array->release) {
ArrowArrayViewReset(&array_view.value);
array->release(&array.value);
}
int err = stream->get_next(&stream.value, &array.value);
if (err != 0) return err;
if (array->release) {
err = ArrowArrayViewInitFromSchema(&array_view.value, &schema.value, &na_error);
if (err != 0) return err;
err = ArrowArrayViewSetArray(&array_view.value, &array.value, &na_error);
if (err != 0) return err;
}
return 0;
}
};
/// \brief Own an AdbcGetObjectsData built from an array view and delete it on
/// destruction.
///
/// Construction errors are not reported.
/// NOTE(review): AdbcGetObjectsDataInit presumably returns null on failure;
/// callers should check `*reader` before dereferencing -- confirm.
struct GetObjectsReader {
  explicit GetObjectsReader(struct ArrowArrayView* array_view)
      // TODO: this swallows any construction errors
      : get_objects_data_(AdbcGetObjectsDataInit(array_view)) {}

  ~GetObjectsReader() {
    // Skip the delete routine when construction produced nothing.
    if (get_objects_data_ != nullptr) {
      AdbcGetObjectsDataDelete(get_objects_data_);
    }
  }

  // This object owns the raw pointer; a copy would delete it twice.
  GetObjectsReader(const GetObjectsReader&) = delete;
  GetObjectsReader& operator=(const GetObjectsReader&) = delete;

  /// Both accessors hand out the owned pointer; ownership is not transferred.
  struct AdbcGetObjectsData* operator*() { return get_objects_data_; }
  struct AdbcGetObjectsData* operator->() { return get_objects_data_; }

 private:
  struct AdbcGetObjectsData* get_objects_data_;
};
/// \brief Description of one schema field: name, Arrow type, nullability, an
/// optional fixed size, and child fields for nested types.
struct SchemaField {
  std::string name;
  ArrowType type = NANOARROW_TYPE_UNINITIALIZED;
  /// Set only by FixedSize(); 0 otherwise.
  int32_t fixed_size = 0;
  bool nullable = true;
  std::vector<SchemaField> children;

  SchemaField(std::string name, ArrowType type, bool nullable)
      : name(std::move(name)), type(type), nullable(nullable) {}

  /// Nullable field.
  SchemaField(std::string name, ArrowType type)
      : SchemaField(std::move(name), type, /*nullable=*/true) {}

  /// Nullable field of a nested type with the given children.
  static SchemaField Nested(std::string name, ArrowType type,
                            std::vector<SchemaField> children) {
    // `name` is a by-value sink parameter: move it instead of copying it again.
    SchemaField out(std::move(name), type);
    out.children = std::move(children);
    return out;
  }

  /// Nullable field with `fixed_size` recorded, optionally with children.
  static SchemaField FixedSize(std::string name, ArrowType type, int32_t fixed_size,
                               std::vector<SchemaField> children = {}) {
    SchemaField out = Nested(std::move(name), type, std::move(children));
    out.fixed_size = fixed_size;
    return out;
  }
};
/// \brief Make a schema from a vector of SchemaField descriptions (name,
/// type, nullability, and optionally fixed size and children).
///
/// Declaration only. NOTE(review): the int result is presumably an
/// errno-style code with 0 meaning success -- confirm against the definition.
int MakeSchema(struct ArrowSchema* schema, const std::vector<SchemaField>& fields);
/// \brief Make an array from a column of C types.
template <typename T>
int MakeArray(struct ArrowArray* parent, struct ArrowArray* array,
const std::vector<std::optional<T>>& values) {
for (const auto& v : values) {
if (v.has_value()) {
if constexpr (std::is_same<T, bool>::value || std::is_same<T, int8_t>::value ||
std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value ||
std::is_same<T, int64_t>::value) {
CHECK_OK(ArrowArrayAppendInt(array, *v));
// XXX: cpplint gets weird here and thinks this is an unbraced if
} else if constexpr (std::is_same<T, // NOLINT(readability/braces)
uint8_t>::value ||
std::is_same<T, uint16_t>::value ||
std::is_same<T, uint32_t>::value ||
std::is_same<T, uint64_t>::value) {
CHECK_OK(ArrowArrayAppendUInt(array, *v));
} else if constexpr (std::is_same<T, float>::value || // NOLINT(readability/braces)
std::is_same<T, double>::value) {
CHECK_OK(ArrowArrayAppendDouble(array, *v));
} else if constexpr (std::is_same<T, std::string>::value) {
struct ArrowBufferView view;
view.data.as_char = v->c_str();
view.size_bytes = v->size();
CHECK_OK(ArrowArrayAppendBytes(array, view));
} else if constexpr (std::is_same<T, std::vector<std::byte>>::value) {
static_assert(std::is_same_v<uint8_t, unsigned char>);
struct ArrowBufferView view;
view.data.as_uint8 = reinterpret_cast<const uint8_t*>(v->data());
view.size_bytes = v->size();
CHECK_OK(ArrowArrayAppendBytes(array, view));
} else if constexpr (std::is_same<T, ArrowInterval*>::value) {
CHECK_OK(ArrowArrayAppendInterval(array, *v));
} else if constexpr (std::is_same<T, ArrowDecimal*>::value) {
CHECK_OK(ArrowArrayAppendDecimal(array, *v));
} else if constexpr (
// Possibly a more effective way to do this using template magic
// Not included but possible are the std::optional<> variants of this
std::is_same<T, std::vector<bool>>::value ||
std::is_same<T, std::vector<int8_t>>::value ||
std::is_same<T, std::vector<int16_t>>::value ||
std::is_same<T, std::vector<int32_t>>::value ||
std::is_same<T, std::vector<int64_t>>::value ||
std::is_same<T, std::vector<uint8_t>>::value ||
std::is_same<T, std::vector<uint16_t>>::value ||
std::is_same<T, std::vector<uint32_t>>::value ||
std::is_same<T, std::vector<uint64_t>>::value ||
std::is_same<T, std::vector<double>>::value ||
std::is_same<T, std::vector<float>>::value ||
std::is_same<T, std::vector<std::string>>::value ||
std::is_same<T, std::vector<std::vector<std::byte>>>::value) {
using child_t = typename T::value_type;
std::vector<std::optional<child_t>> value_nullable;
for (const auto& child_value : *v) {
value_nullable.push_back(child_value);
}
CHECK_OK(MakeArray(array, array->children[0], value_nullable));
CHECK_OK(ArrowArrayFinishElement(array));
} else {
static_assert(!sizeof(T), "Not yet implemented");
return ENOTSUP;
}
} else {
CHECK_OK(ArrowArrayAppendNull(array, 1));
}
}
return 0;
}
template <typename First>
int MakeBatchImpl(struct ArrowArray* batch, size_t i, struct ArrowError* error,
const std::vector<std::optional<First>>& first) {
return MakeArray<First>(batch, batch->children[i], first);
}
template <typename First, typename... Rest>
int MakeBatchImpl(struct ArrowArray* batch, size_t i, struct ArrowError* error,
const std::vector<std::optional<First>>& first,
const std::vector<std::optional<Rest>>&... rest) {
CHECK_OK(MakeArray<First>(batch, batch->children[i], first));
return MakeBatchImpl(batch, i + 1, error, rest...);
}
/// \brief Make a batch from columns of C types.
///
/// Starts appending on `batch`, fills its children from `columns` (one column
/// per child, in order), checks that all children have the same length, sets
/// `batch->length` to that length, and finishes building.
///
/// \param batch Array to populate.
/// \param error Receives details if finishing the array fails.
/// \param columns One vector of optional values per child.
/// \return 0 on success; EINVAL (after ADD_FAILURE) if the child lengths
///   disagree; otherwise the first non-zero code from Nanoarrow.
template <typename... T>
int MakeBatch(struct ArrowArray* batch, struct ArrowError* error,
              const std::vector<std::optional<T>>&... columns) {
  CHECK_OK(ArrowArrayStartAppending(batch));
  CHECK_OK(MakeBatchImpl(batch, 0, error, columns...));
  for (size_t i = 0; i < static_cast<size_t>(batch->n_children); i++) {
    const int64_t column_length = batch->children[i]->length;
    // After the first column, batch->length always holds a reference length,
    // even when that length is 0. Testing only `batch->length > 0` would let
    // an empty first column followed by a non-empty one slip through.
    const bool have_reference = i > 0 || batch->length > 0;
    if (have_reference && column_length != batch->length) {
      ADD_FAILURE() << "Column lengths are inconsistent: column " << i << " has length "
                    << column_length;
      return EINVAL;
    }
    batch->length = column_length;
  }
  return ArrowArrayFinishBuildingDefault(batch, error);
}
template <typename... T>
int MakeBatch(struct ArrowSchema* schema, struct ArrowArray* batch,
struct ArrowError* error, const std::vector<std::optional<T>>&... columns) {
CHECK_OK(ArrowArrayInitFromSchema(batch, schema, error));
return MakeBatch(batch, error, columns...);
}
/// \brief Make a stream from a list of batches.
///
/// Declaration only; `batches` is taken by value.
/// NOTE(review): ownership of `schema` and of the batches presumably moves
/// into `stream` -- confirm against the definition before reusing them.
void MakeStream(struct ArrowArrayStream* stream, struct ArrowSchema* schema,
std::vector<struct ArrowArray> batches);
/// \brief Compare an array for equality against a vector of values.
template <typename T>
void CompareArray(struct ArrowArrayView* array,
const std::vector<std::optional<T>>& values, int64_t offset = 0,
int64_t length = -1) {
if (length == -1) {
length = array->length;
}
ASSERT_EQ(static_cast<int64_t>(values.size()), length);
int64_t i = offset;
for (const auto& v : values) {
SCOPED_TRACE("Array index " + std::to_string(i));
if (v.has_value()) {
ASSERT_FALSE(ArrowArrayViewIsNull(array, i));
if constexpr (std::is_same<T, float>::value || std::is_same<T, double>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
ASSERT_EQ(ArrowArrayViewGetDoubleUnsafe(array, i), *v);
} else if constexpr (std::is_same<T, bool>::value ||
std::is_same<T, int8_t>::value ||
std::is_same<T, int16_t>::value ||
std::is_same<T, int32_t>::value ||
std::is_same<T, int64_t>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
ASSERT_EQ(ArrowArrayViewGetIntUnsafe(array, i), *v);
} else if constexpr (std::is_same<T, uint8_t>::value ||
std::is_same<T, uint16_t>::value ||
std::is_same<T, uint32_t>::value ||
std::is_same<T, uint64_t>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
ASSERT_EQ(ArrowArrayViewGetUIntUnsafe(array, i), *v);
} else if constexpr (std::is_same<T, std::string>::value) {
struct ArrowStringView view = ArrowArrayViewGetStringUnsafe(array, i);
std::string str(view.data, view.size_bytes);
ASSERT_EQ(*v, str);
} else if constexpr (std::is_same<T, std::vector<std::byte>>::value) {
struct ArrowBufferView view = ArrowArrayViewGetBytesUnsafe(array, i);
ASSERT_EQ(v->size(), view.size_bytes);
for (int64_t i = 0; i < view.size_bytes; i++) {
ASSERT_EQ((*v)[i], std::byte{view.data.as_uint8[i]});
}
} else if constexpr (std::is_same<T, ArrowInterval*>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
struct ArrowInterval interval;
ArrowIntervalInit(&interval, ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO);
ArrowArrayViewGetIntervalUnsafe(array, i, &interval);
ASSERT_EQ(interval.months, (*v)->months);
ASSERT_EQ(interval.days, (*v)->days);
ASSERT_EQ(interval.ns, (*v)->ns);
} else if constexpr (
// Possibly a more effective way to do this using template magic
// Not included but possible are the std::optional<> variants of this
std::is_same<T, std::vector<bool>>::value ||
std::is_same<T, std::vector<int8_t>>::value ||
std::is_same<T, std::vector<int16_t>>::value ||
std::is_same<T, std::vector<int32_t>>::value ||
std::is_same<T, std::vector<int64_t>>::value ||
std::is_same<T, std::vector<uint8_t>>::value ||
std::is_same<T, std::vector<uint16_t>>::value ||
std::is_same<T, std::vector<uint32_t>>::value ||
std::is_same<T, std::vector<uint64_t>>::value ||
std::is_same<T, std::vector<double>>::value ||
std::is_same<T, std::vector<float>>::value ||
std::is_same<T, std::vector<std::string>>::value ||
std::is_same<T, std::vector<std::vector<std::byte>>>::value) {
using child_t = typename T::value_type;
std::vector<std::optional<child_t>> value_nullable;
for (const auto& child_value : *v) {
value_nullable.push_back(child_value);
}
SCOPED_TRACE("List item");
int64_t child_offset = ArrowArrayViewListChildOffset(array, i);
int64_t child_length = ArrowArrayViewListChildOffset(array, i + 1) - child_offset;
CompareArray<child_t>(array->children[0], value_nullable, child_offset,
child_length);
} else {
static_assert(!sizeof(T), "Not yet implemented");
}
} else {
ASSERT_TRUE(ArrowArrayViewIsNull(array, i));
}
i++;
}
}
/// \brief Compare a schema for equality against a vector of SchemaField
/// descriptions.
///
/// Declaration only. NOTE(review): returns void, so mismatches are presumably
/// reported through GTest assertions -- confirm against the definition.
void CompareSchema(struct ArrowSchema* schema, const std::vector<SchemaField>& fields);
/// \brief Helper to get the vendor version of a driver, as a string, through
/// an open connection.
///
/// Declaration only. NOTE(review): what is returned when the driver does not
/// report a vendor version is not visible here -- confirm before relying on
/// an empty result.
std::string GetDriverVendorVersion(struct AdbcConnection* connection);
} // namespace adbc_validation