c/validation/adbc_validation_util.cc (196 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "adbc_validation_util.h"
#include <adbc.h>
#include "adbc_validation.h"
namespace adbc_validation {
// Render an AdbcStatusCode for test diagnostics as
// "<ADBCV_STRINGIFY_VALUE of the constant> (<NAME>)", e.g. for ADBC_STATUS_OK
// the suffix is " (OK)". Codes not listed below yield "(unknown code)".
std::string StatusCodeToString(AdbcStatusCode code) {
// Expands to a `case` label returning the stringified constant followed by its
// short name; the caller supplies the trailing semicolon.
#define ADBCV_STATUS_CASE(NAME) \
  case ADBC_STATUS_##NAME:      \
    return ADBCV_STRINGIFY_VALUE(ADBC_STATUS_##NAME) " (" #NAME ")"

  switch (code) {
    ADBCV_STATUS_CASE(OK);
    ADBCV_STATUS_CASE(UNKNOWN);
    ADBCV_STATUS_CASE(NOT_IMPLEMENTED);
    ADBCV_STATUS_CASE(NOT_FOUND);
    ADBCV_STATUS_CASE(ALREADY_EXISTS);
    ADBCV_STATUS_CASE(INVALID_ARGUMENT);
    ADBCV_STATUS_CASE(INVALID_STATE);
    ADBCV_STATUS_CASE(INVALID_DATA);
    ADBCV_STATUS_CASE(INTEGRITY);
    ADBCV_STATUS_CASE(INTERNAL);
    ADBCV_STATUS_CASE(IO);
    ADBCV_STATUS_CASE(CANCELLED);
    ADBCV_STATUS_CASE(TIMEOUT);
    ADBCV_STATUS_CASE(UNAUTHENTICATED);
    ADBCV_STATUS_CASE(UNAUTHORIZED);
    default:
      break;
  }
#undef ADBCV_STATUS_CASE
  return "(unknown code)";
}
std::string ToString(struct AdbcError* error) {
if (error && error->message) {
std::string result = error->message;
error->release(error);
return result;
}
return "";
}
std::string ToString(struct ArrowError* error) { return error ? error->message : ""; }
// Return the stream's last error message via its get_last_error callback, or
// "" if the stream is null, has no such callback, or reports no message.
std::string ToString(struct ArrowArrayStream* stream) {
  if (stream == nullptr || stream->get_last_error == nullptr) return "";
  const char* message = stream->get_last_error(stream);
  return message != nullptr ? message : "";
}
// Store the expected error code plus two optional sources of detail: an
// ArrowArrayStream and an ArrowError, either of which may be null. When a match
// fails, MatchAndExplain appends the messages from whichever sources are
// non-null.
IsErrno::IsErrno(int expected, struct ArrowArrayStream* stream, struct ArrowError* error)
: expected_(expected), stream_(stream), error_(error) {}
// Succeeds when `errcode` equals the expected code. On a mismatch, if gtest
// supplied an explanation stream, write strerror(errcode) followed by any
// message from the attached stream and/or ArrowError.
bool IsErrno::MatchAndExplain(int errcode, std::ostream* os) const {
  const bool matched = (errcode == expected_);
  if (matched || os == nullptr) return matched;

  std::ostream& out = *os;
  out << std::strerror(errcode);
  if (stream_ != nullptr) out << "\nError message: " << ToString(stream_);
  if (error_ != nullptr) out << "\nError message: " << ToString(error_);
  return false;
}
// gtest description of the positive match, e.g. "is errno 0".
void IsErrno::DescribeTo(std::ostream* os) const {
  std::ostream& out = *os;
  out << "is errno " << expected_;
}

// gtest description of the negated match, e.g. "is not errno 0".
void IsErrno::DescribeNegationTo(std::ostream* os) const {
  std::ostream& out = *os;
  out << "is not errno " << expected_;
}
// Matcher for an error code of 0, with no extra context on failure.
::testing::Matcher<int> IsOkErrno() { return IsErrno(0, nullptr, nullptr); }

// Matcher for an error code of 0. On failure, the wrapped stream's
// get_last_error message is appended. A null `stream` is tolerated and reports
// no stream message; the guard avoids forming `&stream->value` through a null
// pointer.
::testing::Matcher<int> IsOkErrno(Handle<struct ArrowArrayStream>* stream) {
  return IsErrno(0, stream ? &stream->value : nullptr, nullptr);
}

// Matcher for an error code of 0. On failure, the message in `error` is appended.
::testing::Matcher<int> IsOkErrno(struct ArrowError* error) {
  return IsErrno(0, nullptr, error);
}
// Store the expected status code and an optional AdbcError (may be null). When
// a mismatch is explained, MatchAndExplain prints the error's details and then
// releases the error.
IsAdbcStatusCode::IsAdbcStatusCode(AdbcStatusCode expected, struct AdbcError* error)
: expected_(expected), error_(error) {}
// Succeeds when `actual` equals the expected status. On a mismatch, and only
// when gtest asks for an explanation (`os` non-null), print the actual status
// and any details from `error_` (message, SQLSTATE, vendor code), then release
// `error_` so its resources are not leaked.
bool IsAdbcStatusCode::MatchAndExplain(AdbcStatusCode actual, std::ostream* os) const {
  if (actual == expected_) return true;
  if (os == nullptr) return false;

  std::ostream& out = *os;
  out << StatusCodeToString(actual);
  if (error_ == nullptr) return false;

  if (error_->message) out << "\nError message: " << error_->message;
  if (error_->sqlstate[0]) {
    // sqlstate is a fixed-size char array that need not be NUL-terminated, so
    // bound the read by the array size instead of streaming it as a C string.
    size_t len = 0;
    while (len < sizeof(error_->sqlstate) && error_->sqlstate[len] != '\0') ++len;
    out << "\nSQLSTATE: " << std::string(error_->sqlstate, len);
  }
  if (error_->vendor_code) out << "\nVendor code: " << error_->vendor_code;
  if (error_->release) error_->release(error_);
  return false;
}
void IsAdbcStatusCode::DescribeTo(std::ostream* os) const {
*os << "is " << StatusCodeToString(expected_);
}
void IsAdbcStatusCode::DescribeNegationTo(std::ostream* os) const {
*os << "is not " << StatusCodeToString(expected_);
}
// Matcher that a call returned ADBC_STATUS_OK; shorthand for
// IsStatus(ADBC_STATUS_OK, error).
::testing::Matcher<AdbcStatusCode> IsOkStatus(struct AdbcError* error) {
  return IsAdbcStatusCode(ADBC_STATUS_OK, error);
}

// Matcher that a call returned `code`. `error` (may be null) supplies extra
// diagnostics that IsAdbcStatusCode reports when explaining a mismatch.
::testing::Matcher<AdbcStatusCode> IsStatus(AdbcStatusCode code,
                                            struct AdbcError* error) {
  IsAdbcStatusCode matcher(code, error);
  return matcher;
}
#define CHECK_ERRNO(EXPR) \
do { \
if (int adbcv_errno = (EXPR); adbcv_errno != 0) { \
return adbcv_errno; \
} \
} while (false);
int MakeSchema(struct ArrowSchema* schema, const std::vector<SchemaField>& fields) {
ArrowSchemaInit(schema);
CHECK_ERRNO(ArrowSchemaSetTypeStruct(schema, fields.size()));
size_t i = 0;
for (const SchemaField& field : fields) {
CHECK_ERRNO(ArrowSchemaSetType(schema->children[i], field.type));
CHECK_ERRNO(ArrowSchemaSetName(schema->children[i], field.name.c_str()));
if (!field.nullable) {
schema->children[i]->flags &= ~ARROW_FLAG_NULLABLE;
}
i++;
}
return 0;
}
#undef CHECK_ERRNO
// An ArrowArrayStream implementation that yields a fixed list of batches.
//
// Ownership:
// - The constructor takes ownership of `schema`, zeroing the caller's struct,
//   and of `batches`.
// - GetNext moves each batch out exactly once, then signals end of stream by
//   setting `out->release` to nullptr.
// - Release frees the schema and any batches never consumed, deletes this
//   object, and zeroes the stream.
class ConstantArrayStream {
 public:
  explicit ConstantArrayStream(struct ArrowSchema* schema,
                               std::vector<struct ArrowArray> batches)
      : batches_(std::move(batches)), next_index_(0) {
    // Move the schema in: copy the struct, then zero the source so the caller
    // no longer holds a live release callback.
    schema_ = *schema;
    std::memset(schema, 0, sizeof(*schema));
  }

  // Holds C structs with release callbacks; a copy would release them twice.
  ConstantArrayStream(const ConstantArrayStream&) = delete;
  ConstantArrayStream& operator=(const ConstantArrayStream&) = delete;

  static const char* GetLastError(struct ArrowArrayStream* /*stream*/) {
    // This stream never records an error message.
    return nullptr;
  }

  static int GetNext(struct ArrowArrayStream* stream, struct ArrowArray* out) {
    if (!stream || !stream->private_data || !out) return EINVAL;
    auto* self = static_cast<ConstantArrayStream*>(stream->private_data);
    if (self->next_index_ >= self->batches_.size()) {
      // End of stream.
      out->release = nullptr;
      return 0;
    }
    // Hand ownership of the batch to the consumer and zero our copy, so that
    // Release() does not free it a second time.
    *out = self->batches_[self->next_index_];
    std::memset(&self->batches_[self->next_index_], 0, sizeof(struct ArrowArray));
    self->next_index_++;
    return 0;
  }

  static int GetSchema(struct ArrowArrayStream* stream, struct ArrowSchema* out) {
    if (!stream || !stream->private_data || !out) return EINVAL;
    auto* self = static_cast<ConstantArrayStream*>(stream->private_data);
    // The schema passed to the constructor may already have been released
    // (release == nullptr), in which case there is nothing valid to copy.
    if (!self->schema_.release) return EINVAL;
    return ArrowSchemaDeepCopy(&self->schema_, out);
  }

  static void Release(struct ArrowArrayStream* stream) {
    if (!stream || !stream->private_data) return;
    auto* self = static_cast<ConstantArrayStream*>(stream->private_data);
    if (self->schema_.release) self->schema_.release(&self->schema_);
    for (struct ArrowArray& batch : self->batches_) {
      // Batches already handed out by GetNext were zeroed and are skipped here.
      if (batch.release) batch.release(&batch);
    }
    delete self;
    // Zeroing marks the stream as released (release == nullptr).
    std::memset(stream, 0, sizeof(*stream));
  }

 private:
  struct ArrowSchema schema_;
  std::vector<struct ArrowArray> batches_;
  size_t next_index_;
};
// Initialize `stream` as a ConstantArrayStream that yields `batches` in order.
// Ownership of `schema` and of every batch moves into the stream; `schema` is
// zeroed. The stream's Release callback frees whatever remains. The previous
// contents of `stream` are overwritten without being released.
void MakeStream(struct ArrowArrayStream* stream, struct ArrowSchema* schema,
                std::vector<struct ArrowArray> batches) {
  stream->release = &ConstantArrayStream::Release;
  stream->get_schema = &ConstantArrayStream::GetSchema;
  stream->get_next = &ConstantArrayStream::GetNext;
  stream->get_last_error = &ConstantArrayStream::GetLastError;
  auto* impl = new ConstantArrayStream(schema, std::move(batches));
  stream->private_data = impl;
}
void CompareSchema(
struct ArrowSchema* schema,
const std::vector<std::tuple<std::optional<std::string>, ArrowType, bool>>& fields) {
struct ArrowError na_error;
struct ArrowSchemaView view;
ASSERT_THAT(ArrowSchemaViewInit(&view, schema, &na_error), IsOkErrno(&na_error));
ASSERT_THAT(view.type,
::testing::AnyOf(NANOARROW_TYPE_LIST, NANOARROW_TYPE_MAP,
NANOARROW_TYPE_STRUCT, NANOARROW_TYPE_DENSE_UNION));
ASSERT_EQ(fields.size(), schema->n_children);
for (int64_t i = 0; i < schema->n_children; i++) {
SCOPED_TRACE("Field " + std::to_string(i));
struct ArrowSchemaView field_view;
ASSERT_THAT(ArrowSchemaViewInit(&field_view, schema->children[i], &na_error),
IsOkErrno(&na_error));
ASSERT_EQ(std::get<1>(fields[i]), field_view.type);
ASSERT_EQ(std::get<2>(fields[i]),
(schema->children[i]->flags & ARROW_FLAG_NULLABLE) != 0)
<< "Nullability mismatch";
if (std::get<0>(fields[i]).has_value()) {
ASSERT_STRCASEEQ(std::get<0>(fields[i])->c_str(), schema->children[i]->name);
}
}
}
} // namespace adbc_validation