c/driver/framework/base_driver.h (879 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <charconv>
#include <cstring>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>
#include <arrow-adbc/adbc.h>
#include "driver/framework/status.h"
/// \file base.h ADBC Driver Framework
///
/// A base implementation of an ADBC driver that allows easier driver
/// development by overriding functions. Databases, connections, and
/// statements can be defined by subclassing the [CRTP][crtp] base classes.
///
/// Generally, base classes provide a set of functions that correspond to the
/// ADBC functions. These should not be directly overridden, as they provide
/// the core logic and argument checking/error handling. Instead, override
/// the -Impl functions that are also exposed by base classes.
///
/// [crtp]: https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
namespace adbc::driver {
/// \brief Tracks how far a database, connection, or statement has progressed
/// through its lifecycle.
enum class LifecycleState {
  /// \brief The object exists (New was called) but Init has not run yet.
  kUninitialized = 0,
  /// \brief Init has been called on the object.
  kInitialized = 1,
};
/// \brief A typed option value wrapper.
///
/// An Option holds exactly one of: nothing (Unset), a string, a bytestring,
/// an int64_t, or a double. It currently does not attempt general conversion
/// (i.e., getting a double option as a string); the few conversions that are
/// performed are documented on the individual accessors.
class Option {
 public:
  /// \brief The option is unset.
  struct Unset {};
  /// \brief The possible values of an option.
  using Value = std::variant<Unset, std::string, std::vector<uint8_t>, int64_t, double>;

  /// \brief Construct an unset option.
  Option() : value_(Unset{}) {}
  /// \brief Construct an option from a C string.
  ///
  /// NULL strings are treated as unset.
  explicit Option(const char* value)
      : value_(value ? Value(std::string(value)) : Value{Unset{}}) {}
  /// \brief Construct a string option.
  explicit Option(std::string value) : value_(std::move(value)) {}
  /// \brief Construct a bytestring option.
  explicit Option(std::vector<uint8_t> value) : value_(std::move(value)) {}
  /// \brief Construct a double option.
  explicit Option(double value) : value_(value) {}
  /// \brief Construct an integer option.
  explicit Option(int64_t value) : value_(value) {}

  /// \brief Access the underlying variant.
  const Value& value() const& { return value_; }
  /// \brief Access the underlying variant of an rvalue Option.
  ///
  /// NOTE(review): this returns an lvalue reference into the rvalue Option;
  /// it must not be kept beyond the lifetime of that Option.
  Value& value() && { return value_; }

  /// \brief Check whether this option is set.
  bool has_value() const { return !std::holds_alternative<Unset>(value_); }

  /// \brief Try to parse a string value as a boolean.
  ///
  /// Only string values equal to ADBC_OPTION_VALUE_ENABLED (true) or
  /// ADBC_OPTION_VALUE_DISABLED (false) are accepted. Anything else,
  /// including non-string values, is an InvalidArgument error.
  Result<bool> AsBool() const {
    return std::visit(
        [&](auto&& value) -> Result<bool> {
          using T = std::decay_t<decltype(value)>;
          if constexpr (std::is_same_v<T, std::string>) {
            if (value == ADBC_OPTION_VALUE_ENABLED) {
              return true;
            } else if (value == ADBC_OPTION_VALUE_DISABLED) {
              return false;
            }
          }
          return status::InvalidArgument("Invalid boolean value ", this->Format());
        },
        value_);
  }

  /// \brief Try to parse a string or integer value as an integer.
  ///
  /// Integer values are returned as-is. A string value must consist entirely
  /// of an integer that std::from_chars can parse into an int64_t; parse
  /// failures and trailing characters are InvalidArgument errors. All other
  /// value types are InvalidArgument errors.
  Result<int64_t> AsInt() const {
    return std::visit(
        [&](auto&& value) -> Result<int64_t> {
          using T = std::decay_t<decltype(value)>;
          if constexpr (std::is_same_v<T, int64_t>) {
            return value;
          } else if constexpr (std::is_same_v<T, std::string>) {
            int64_t parsed = 0;
            const char* begin = value.data();
            const char* end = value.data() + value.size();
            auto result = std::from_chars(begin, end, parsed);
            if (result.ec != std::errc()) {
              return status::InvalidArgument("Invalid integer value '", value,
                                             "': not an integer");
            } else if (result.ptr != end) {
              // from_chars stops at the first non-digit; reject partial parses.
              return status::InvalidArgument("Invalid integer value '", value,
                                             "': trailing data");
            }
            return parsed;
          } else {
            return status::InvalidArgument("Invalid integer value ", this->Format());
          }
        },
        value_);
  }

  /// \brief Get the value if it is a string.
  ///
  /// The returned view refers to this Option's storage and is only valid
  /// while this Option is alive and unmodified.
  Result<std::string_view> AsString() const {
    return std::visit(
        [&](auto&& value) -> Result<std::string_view> {
          using T = std::decay_t<decltype(value)>;
          if constexpr (std::is_same_v<T, std::string>) {
            return value;
          } else {
            return status::InvalidArgument("Invalid string value ", this->Format());
          }
        },
        value_);
  }

  /// \brief Provide a human-readable summary of the value.
  ///
  /// Unset prints as "(NULL)", strings are single-quoted, bytestrings print
  /// only their size, and numbers use std::to_string.
  std::string Format() const {
    return std::visit(
        [&](auto&& value) -> std::string {
          using T = std::decay_t<decltype(value)>;
          if constexpr (std::is_same_v<T, Unset>) {
            return "(NULL)";
          } else if constexpr (std::is_same_v<T, std::string>) {
            return std::string("'") + value + "'";
          } else if constexpr (std::is_same_v<T, std::vector<uint8_t>>) {
            return std::string("(") + std::to_string(value.size()) + " bytes)";
          } else {
            return std::to_string(value);
          }
        },
        value_);
  }

 private:
  Value value_;

  // Methods used by trampolines to export option values in C below.
  friend class ObjectBase;

  /// \brief Export a string option through the C GetOption convention.
  ///
  /// On input *length is the capacity of out. On success *length is set to
  /// the value's size including the NUL terminator. The value is copied only
  /// when it fits, so a caller can pass a too-small buffer to learn the
  /// required size.
  AdbcStatusCode CGet(char* out, size_t* length, AdbcError* error) const {
    if (!length || (!out && *length > 0)) {
      return status::InvalidArgument("Must provide both out and length to GetOption")
          .ToAdbc(error);
    }
    return std::visit(
        [&](auto&& value) -> AdbcStatusCode {
          using T = std::decay_t<decltype(value)>;
          if constexpr (std::is_same_v<T, std::string>) {
            size_t value_size_with_terminator = value.size() + 1;
            if (*length >= value_size_with_terminator) {
              std::memcpy(out, value.data(), value.size());
              out[value.size()] = 0;
            }
            *length = value_size_with_terminator;
            return ADBC_STATUS_OK;
          } else if constexpr (std::is_same_v<T, Unset>) {
            return status::NotFound("Unknown option").ToAdbc(error);
          } else {
            return status::NotFound("Option value is not a string").ToAdbc(error);
          }
        },
        value_);
  }

  /// \brief Export a string or bytestring option through the C
  /// GetOptionBytes convention.
  ///
  /// Same buffer protocol as the char overload, but no NUL terminator is
  /// written or counted.
  AdbcStatusCode CGet(uint8_t* out, size_t* length, AdbcError* error) const {
    if (!length || (!out && *length > 0)) {
      return status::InvalidArgument("Must provide both out and length to GetOption")
          .ToAdbc(error);
    }
    return std::visit(
        [&](auto&& value) -> AdbcStatusCode {
          using T = std::decay_t<decltype(value)>;
          if constexpr (std::is_same_v<T, std::string> ||
                        std::is_same_v<T, std::vector<uint8_t>>) {
            if (*length >= value.size()) {
              std::memcpy(out, value.data(), value.size());
            }
            *length = value.size();
            return ADBC_STATUS_OK;
          } else if constexpr (std::is_same_v<T, Unset>) {
            return status::NotFound("Unknown option").ToAdbc(error);
          } else {
            return status::NotFound("Option value is not a bytestring").ToAdbc(error);
          }
        },
        value_);
  }

  /// \brief Export an integer option; only int64_t values are accepted.
  AdbcStatusCode CGet(int64_t* out, AdbcError* error) const {
    if (!out) {
      return status::InvalidArgument("Must provide out to GetOption").ToAdbc(error);
    }
    return std::visit(
        [&](auto&& value) -> AdbcStatusCode {
          using T = std::decay_t<decltype(value)>;
          if constexpr (std::is_same_v<T, int64_t>) {
            *out = value;
            return ADBC_STATUS_OK;
          } else if constexpr (std::is_same_v<T, Unset>) {
            return status::NotFound("Unknown option").ToAdbc(error);
          } else {
            return status::NotFound("Option value is not an integer").ToAdbc(error);
          }
        },
        value_);
  }

  /// \brief Export a double option; int64_t values are widened to double.
  AdbcStatusCode CGet(double* out, AdbcError* error) const {
    if (!out) {
      return status::InvalidArgument("Must provide out to GetOption").ToAdbc(error);
    }
    return std::visit(
        [&](auto&& value) -> AdbcStatusCode {
          using T = std::decay_t<decltype(value)>;
          if constexpr (std::is_same_v<T, double> || std::is_same_v<T, int64_t>) {
            *out = static_cast<double>(value);
            return ADBC_STATUS_OK;
          } else if constexpr (std::is_same_v<T, Unset>) {
            return status::NotFound("Unknown option").ToAdbc(error);
          } else {
            return status::NotFound("Option value is not a double").ToAdbc(error);
          }
        },
        value_);
  }
};
/// \brief Base class for private_data of AdbcDatabase, AdbcConnection, and
/// AdbcStatement.
///
/// This class handles option setting and getting and records the object's
/// LifecycleState. The private C* helpers adapt the C option entry points to
/// the virtual SetOption/GetOption hooks; they are used by Driver.
class ObjectBase {
 public:
  ObjectBase() = default;
  virtual ~ObjectBase() = default;

  /// \brief Initialize the object.
  ///
  /// Called after 0 or more SetOption calls. Generally, you won't need to
  /// override this directly. Instead, use the typed InitImpl provided by
  /// Database/Connection/Statement. The base implementation only marks the
  /// object as LifecycleState::kInitialized.
  ///
  /// \param[in] parent The private_data of the parent object, or nullptr.
  ///   Driver passes nullptr for a database, the AdbcDatabase's private_data
  ///   for a connection, and the AdbcConnection's private_data for a
  ///   statement. For example, if you have defined
  ///   Driver<MyDatabase, MyConnection, MyStatement>, you can
  ///   reinterpret_cast<MyDatabase*>(parent) in MyConnection::Init().
  virtual AdbcStatusCode Init(void* parent, AdbcError* error) {
    lifecycle_state_ = LifecycleState::kInitialized;
    return ADBC_STATUS_OK;
  }

  /// \brief Finalize the object.
  ///
  /// This can be used to return an error if the object is not in a valid
  /// state (e.g. prevent closing a connection with open statements) or to
  /// clean up resources when resource cleanup could fail. Infallible
  /// resource cleanup (e.g. releasing memory) should generally be handled in
  /// the destructor.
  ///
  /// Generally, you won't need to override this directly. Instead, use the
  /// typed ReleaseImpl provided by Database/Connection/Statement.
  virtual AdbcStatusCode Release(AdbcError* error) { return ADBC_STATUS_OK; }

  /// \brief Get an option value.
  ///
  /// The default implementation reports every key as unset, which the C
  /// getters report as a NotFound ("Unknown option") error.
  virtual Result<Option> GetOption(std::string_view key) { return Option(); }

  /// \brief Set an option value.
  ///
  /// The default implementation returns ADBC_STATUS_NOT_IMPLEMENTED without
  /// filling in error.
  virtual AdbcStatusCode SetOption(std::string_view key, Option value, AdbcError* error) {
    return ADBC_STATUS_NOT_IMPLEMENTED;
  }

 protected:
  /// \brief Current lifecycle state. It starts uninitialized (explicitly, so
  /// it is never read indeterminate) and is advanced by Init.
  LifecycleState lifecycle_state_ = LifecycleState::kUninitialized;

 private:
  // Let the Driver use these to expose C callables wrapping option setters/getters
  template <typename DatabaseT, typename ConnectionT, typename StatementT>
  friend class Driver;

  /// Wrap a C string/int/double option value in an Option and forward it to
  /// SetOption. A NULL key is rejected because std::string_view cannot be
  /// built from a null pointer.
  template <typename T>
  AdbcStatusCode CSetOption(const char* key, T value, AdbcError* error) {
    if (!key) {
      return status::InvalidArgument("Must provide key to SetOption").ToAdbc(error);
    }
    Option option(value);
    return SetOption(key, std::move(option), error);
  }

  /// Copy a C byte buffer into a bytestring Option and forward it to
  /// SetOption. A NULL buffer is only accepted when length is 0.
  AdbcStatusCode CSetOptionBytes(const char* key, const uint8_t* value, size_t length,
                                 AdbcError* error) {
    if (!key) {
      return status::InvalidArgument("Must provide key to SetOption").ToAdbc(error);
    }
    if (!value && length > 0) {
      return status::InvalidArgument("Must provide value to SetOption").ToAdbc(error);
    }
    std::vector<uint8_t> cppvalue(value, value + length);
    Option option(std::move(cppvalue));
    return SetOption(key, std::move(option), error);
  }

  /// Look up an option and export it into a caller-provided char/uint8_t
  /// buffer (see Option::CGet for the buffer-size protocol).
  template <typename T>
  AdbcStatusCode CGetOptionStringLike(const char* key, T* value, size_t* length,
                                      AdbcError* error) {
    if (!key) {
      return status::InvalidArgument("Must provide key to GetOption").ToAdbc(error);
    }
    RAISE_RESULT(error, auto option, GetOption(key));
    return option.CGet(value, length, error);
  }

  /// Look up an option and export it as an int64_t or double.
  template <typename T>
  AdbcStatusCode CGetOptionNumeric(const char* key, T* value, AdbcError* error) {
    if (!key) {
      return status::InvalidArgument("Must provide key to GetOption").ToAdbc(error);
    }
    RAISE_RESULT(error, auto option, GetOption(key));
    return option.CGet(value, error);
  }
};
/// \brief Type-level map from an ADBC C handle type to the driver's
/// implementation type for that handle.
///
/// The primary template deliberately has no `type` member, so only the three
/// specializations below (AdbcDatabase, AdbcConnection, AdbcStatement)
/// resolve.
template <typename DatabaseImpl, typename ConnectionImpl, typename StatementImpl,
          typename Handle>
struct ResolveObjectTImpl {};

/// AdbcDatabase handles are backed by the database implementation.
template <typename DatabaseImpl, typename ConnectionImpl, typename StatementImpl>
struct ResolveObjectTImpl<DatabaseImpl, ConnectionImpl, StatementImpl,
                          struct AdbcDatabase> {
  using type = DatabaseImpl;
};

/// AdbcConnection handles are backed by the connection implementation.
template <typename DatabaseImpl, typename ConnectionImpl, typename StatementImpl>
struct ResolveObjectTImpl<DatabaseImpl, ConnectionImpl, StatementImpl,
                          struct AdbcConnection> {
  using type = ConnectionImpl;
};

/// AdbcStatement handles are backed by the statement implementation.
template <typename DatabaseImpl, typename ConnectionImpl, typename StatementImpl>
struct ResolveObjectTImpl<DatabaseImpl, ConnectionImpl, StatementImpl,
                          struct AdbcStatement> {
  using type = StatementImpl;
};

/// \brief Shorthand for `ResolveObjectTImpl<...>::type`.
template <typename DatabaseImpl, typename ConnectionImpl, typename StatementImpl,
          typename Handle>
using ResolveObjectT =
    typename ResolveObjectTImpl<DatabaseImpl, ConnectionImpl, StatementImpl,
                                Handle>::type;
// Driver authors can declare a template specialization of the Driver class
// and use it to provide their driver init function. It is possible, but
// rarely useful, to subclass a driver.
template <typename DatabaseT, typename ConnectionT, typename StatementT>
class Driver {
public:
static AdbcStatusCode Init(int version, void* raw_driver, AdbcError* error) {
if (version != ADBC_VERSION_1_0_0 && version != ADBC_VERSION_1_1_0) {
return ADBC_STATUS_NOT_IMPLEMENTED;
}
auto* driver = reinterpret_cast<AdbcDriver*>(raw_driver);
if (version >= ADBC_VERSION_1_1_0) {
std::memset(driver, 0, ADBC_DRIVER_1_1_0_SIZE);
driver->ErrorGetDetailCount = &CErrorGetDetailCount;
driver->ErrorGetDetail = &CErrorGetDetail;
driver->DatabaseGetOption = &CGetOption<AdbcDatabase>;
driver->DatabaseGetOptionBytes = &CGetOptionBytes<AdbcDatabase>;
driver->DatabaseGetOptionInt = &CGetOptionInt<AdbcDatabase>;
driver->DatabaseGetOptionDouble = &CGetOptionDouble<AdbcDatabase>;
driver->DatabaseSetOptionBytes = &CSetOptionBytes<AdbcDatabase>;
driver->DatabaseSetOptionInt = &CSetOptionInt<AdbcDatabase>;
driver->DatabaseSetOptionDouble = &CSetOptionDouble<AdbcDatabase>;
driver->ConnectionCancel = &CConnectionCancel;
driver->ConnectionGetOption = &CGetOption<AdbcConnection>;
driver->ConnectionGetOptionBytes = &CGetOptionBytes<AdbcConnection>;
driver->ConnectionGetOptionInt = &CGetOptionInt<AdbcConnection>;
driver->ConnectionGetOptionDouble = &CGetOptionDouble<AdbcConnection>;
driver->ConnectionGetStatistics = &CConnectionGetStatistics;
driver->ConnectionGetStatisticNames = &CConnectionGetStatisticNames;
driver->ConnectionSetOptionBytes = &CSetOptionBytes<AdbcConnection>;
driver->ConnectionSetOptionInt = &CSetOptionInt<AdbcConnection>;
driver->ConnectionSetOptionDouble = &CSetOptionDouble<AdbcConnection>;
driver->StatementCancel = &CStatementCancel;
driver->StatementExecuteSchema = &CStatementExecuteSchema;
driver->StatementGetOption = &CGetOption<AdbcStatement>;
driver->StatementGetOptionBytes = &CGetOptionBytes<AdbcStatement>;
driver->StatementGetOptionInt = &CGetOptionInt<AdbcStatement>;
driver->StatementGetOptionDouble = &CGetOptionDouble<AdbcStatement>;
driver->StatementSetOptionBytes = &CSetOptionBytes<AdbcStatement>;
driver->StatementSetOptionInt = &CSetOptionInt<AdbcStatement>;
driver->StatementSetOptionDouble = &CSetOptionDouble<AdbcStatement>;
} else {
std::memset(driver, 0, ADBC_DRIVER_1_0_0_SIZE);
}
driver->private_data = new Driver();
driver->release = &CDriverRelease;
driver->DatabaseInit = &CDatabaseInit;
driver->DatabaseNew = &CNew<AdbcDatabase>;
driver->DatabaseRelease = &CRelease<AdbcDatabase>;
driver->DatabaseSetOption = &CSetOption<AdbcDatabase>;
driver->ConnectionCommit = &CConnectionCommit;
driver->ConnectionGetInfo = &CConnectionGetInfo;
driver->ConnectionGetObjects = &CConnectionGetObjects;
driver->ConnectionGetTableSchema = &CConnectionGetTableSchema;
driver->ConnectionGetTableTypes = &CConnectionGetTableTypes;
driver->ConnectionInit = &CConnectionInit;
driver->ConnectionNew = &CNew<AdbcConnection>;
driver->ConnectionRelease = &CRelease<AdbcConnection>;
driver->ConnectionReadPartition = &CConnectionReadPartition;
driver->ConnectionRollback = &CConnectionRollback;
driver->ConnectionSetOption = &CSetOption<AdbcConnection>;
driver->StatementBind = &CStatementBind;
driver->StatementBindStream = &CStatementBindStream;
driver->StatementExecutePartitions = &CStatementExecutePartitions;
driver->StatementExecuteQuery = &CStatementExecuteQuery;
driver->StatementGetParameterSchema = &CStatementGetParameterSchema;
driver->StatementNew = &CStatementNew;
driver->StatementPrepare = &CStatementPrepare;
driver->StatementRelease = &CRelease<AdbcStatement>;
driver->StatementSetOption = &CSetOption<AdbcStatement>;
driver->StatementSetSqlQuery = &CStatementSetSqlQuery;
driver->StatementSetSubstraitPlan = &CStatementSetSubstraitPlan;
return ADBC_STATUS_OK;
}
// Driver trampolines
static AdbcStatusCode CDriverRelease(AdbcDriver* driver, AdbcError* error) {
auto driver_private = reinterpret_cast<Driver*>(driver->private_data);
delete driver_private;
driver->private_data = nullptr;
return ADBC_STATUS_OK;
}
static int CErrorGetDetailCount(const AdbcError* error) {
if (error->vendor_code != ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) {
return 0;
}
auto error_obj = reinterpret_cast<Status*>(error->private_data);
if (!error_obj) {
return 0;
}
return error_obj->CDetailCount();
}
static AdbcErrorDetail CErrorGetDetail(const AdbcError* error, int index) {
if (error->vendor_code != ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) {
return {nullptr, nullptr, 0};
}
auto error_obj = reinterpret_cast<Status*>(error->private_data);
if (!error_obj) {
return {nullptr, nullptr, 0};
}
return error_obj->CDetail(index);
}
// Templatable trampolines
template <typename T>
static AdbcStatusCode CNew(T* obj, AdbcError* error) {
using ObjectT = ResolveObjectT<DatabaseT, ConnectionT, StatementT, T>;
auto private_data = new ObjectT();
obj->private_data = private_data;
return ADBC_STATUS_OK;
}
template <typename T>
static AdbcStatusCode CRelease(T* obj, AdbcError* error) {
using ObjectT = ResolveObjectT<DatabaseT, ConnectionT, StatementT, T>;
if (obj == nullptr) return ADBC_STATUS_INVALID_STATE;
auto private_data = reinterpret_cast<ObjectT*>(obj->private_data);
if (private_data == nullptr) return ADBC_STATUS_INVALID_STATE;
AdbcStatusCode result = private_data->Release(error);
if (result != ADBC_STATUS_OK) {
return result;
}
delete private_data;
obj->private_data = nullptr;
return ADBC_STATUS_OK;
}
template <typename T>
static AdbcStatusCode CSetOption(T* obj, const char* key, const char* value,
AdbcError* error) {
using ObjectT = ResolveObjectT<DatabaseT, ConnectionT, StatementT, T>;
auto private_data = reinterpret_cast<ObjectT*>(obj->private_data);
return private_data->template CSetOption<>(key, value, error);
}
template <typename T>
static AdbcStatusCode CSetOptionBytes(T* obj, const char* key, const uint8_t* value,
size_t length, AdbcError* error) {
using ObjectT = ResolveObjectT<DatabaseT, ConnectionT, StatementT, T>;
auto private_data = reinterpret_cast<ObjectT*>(obj->private_data);
return private_data->CSetOptionBytes(key, value, length, error);
}
template <typename T>
static AdbcStatusCode CSetOptionInt(T* obj, const char* key, int64_t value,
AdbcError* error) {
using ObjectT = ResolveObjectT<DatabaseT, ConnectionT, StatementT, T>;
auto private_data = reinterpret_cast<ObjectT*>(obj->private_data);
return private_data->template CSetOption<>(key, value, error);
}
template <typename T>
static AdbcStatusCode CSetOptionDouble(T* obj, const char* key, double value,
AdbcError* error) {
using ObjectT = ResolveObjectT<DatabaseT, ConnectionT, StatementT, T>;
auto private_data = reinterpret_cast<ObjectT*>(obj->private_data);
return private_data->template CSetOption<>(key, value, error);
}
template <typename T>
static AdbcStatusCode CGetOption(T* obj, const char* key, char* value, size_t* length,
AdbcError* error) {
using ObjectT = ResolveObjectT<DatabaseT, ConnectionT, StatementT, T>;
auto private_data = reinterpret_cast<ObjectT*>(obj->private_data);
return private_data->template CGetOptionStringLike<>(key, value, length, error);
}
template <typename T>
static AdbcStatusCode CGetOptionBytes(T* obj, const char* key, uint8_t* value,
size_t* length, AdbcError* error) {
using ObjectT = ResolveObjectT<DatabaseT, ConnectionT, StatementT, T>;
auto private_data = reinterpret_cast<ObjectT*>(obj->private_data);
return private_data->template CGetOptionStringLike<>(key, value, length, error);
}
template <typename T>
static AdbcStatusCode CGetOptionInt(T* obj, const char* key, int64_t* value,
AdbcError* error) {
using ObjectT = ResolveObjectT<DatabaseT, ConnectionT, StatementT, T>;
auto private_data = reinterpret_cast<ObjectT*>(obj->private_data);
return private_data->template CGetOptionNumeric<>(key, value, error);
}
template <typename T>
static AdbcStatusCode CGetOptionDouble(T* obj, const char* key, double* value,
AdbcError* error) {
using ObjectT = ResolveObjectT<DatabaseT, ConnectionT, StatementT, T>;
auto private_data = reinterpret_cast<ObjectT*>(obj->private_data);
return private_data->template CGetOptionNumeric<>(key, value, error);
}
#define CHECK_INIT(DATABASE, ERROR) \
if (!(DATABASE) || !(DATABASE)->private_data) { \
return status::InvalidState("Database is uninitialized").ToAdbc(ERROR); \
}
// Database trampolines
static AdbcStatusCode CDatabaseInit(AdbcDatabase* database, AdbcError* error) {
CHECK_INIT(database, error);
auto private_data = reinterpret_cast<DatabaseT*>(database->private_data);
return private_data->Init(nullptr, error);
}
#undef CHECK_INIT
#define CHECK_INIT(CONNECTION, ERROR) \
if (!(CONNECTION) || !(CONNECTION)->private_data) { \
return status::InvalidState("Connection is uninitialized").ToAdbc(ERROR); \
}
// Connection trampolines
static AdbcStatusCode CConnectionInit(AdbcConnection* connection,
AdbcDatabase* database, AdbcError* error) {
CHECK_INIT(connection, error);
if (!database || !database->private_data) {
return status::InvalidState("Database is uninitialized").ToAdbc(error);
}
auto private_data = reinterpret_cast<ConnectionT*>(connection->private_data);
return private_data->Init(database->private_data, error);
}
static AdbcStatusCode CConnectionCancel(AdbcConnection* connection, AdbcError* error) {
CHECK_INIT(connection, error);
auto private_data = reinterpret_cast<ConnectionT*>(connection->private_data);
return private_data->Cancel(error);
}
static AdbcStatusCode CConnectionGetInfo(AdbcConnection* connection,
const uint32_t* info_codes,
size_t info_codes_length,
ArrowArrayStream* out, AdbcError* error) {
CHECK_INIT(connection, error);
auto private_data = reinterpret_cast<ConnectionT*>(connection->private_data);
return private_data->GetInfo(info_codes, info_codes_length, out, error);
}
static AdbcStatusCode CConnectionGetObjects(AdbcConnection* connection, int depth,
const char* catalog, const char* db_schema,
const char* table_name,
const char** table_type,
const char* column_name,
ArrowArrayStream* out, AdbcError* error) {
CHECK_INIT(connection, error);
auto private_data = reinterpret_cast<ConnectionT*>(connection->private_data);
return private_data->GetObjects(depth, catalog, db_schema, table_name, table_type,
column_name, out, error);
}
static AdbcStatusCode CConnectionGetStatistics(
AdbcConnection* connection, const char* catalog, const char* db_schema,
const char* table_name, char approximate, ArrowArrayStream* out, AdbcError* error) {
CHECK_INIT(connection, error);
auto private_data = reinterpret_cast<ConnectionT*>(connection->private_data);
return private_data->GetStatistics(catalog, db_schema, table_name, approximate, out,
error);
}
static AdbcStatusCode CConnectionGetStatisticNames(AdbcConnection* connection,
ArrowArrayStream* out,
AdbcError* error) {
CHECK_INIT(connection, error);
auto private_data = reinterpret_cast<ConnectionT*>(connection->private_data);
return private_data->GetStatisticNames(out, error);
}
static AdbcStatusCode CConnectionGetTableSchema(AdbcConnection* connection,
const char* catalog,
const char* db_schema,
const char* table_name,
ArrowSchema* schema, AdbcError* error) {
CHECK_INIT(connection, error);
auto private_data = reinterpret_cast<ConnectionT*>(connection->private_data);
return private_data->GetTableSchema(catalog, db_schema, table_name, schema, error);
}
static AdbcStatusCode CConnectionGetTableTypes(AdbcConnection* connection,
ArrowArrayStream* out,
AdbcError* error) {
CHECK_INIT(connection, error);
auto private_data = reinterpret_cast<ConnectionT*>(connection->private_data);
return private_data->GetTableTypes(out, error);
}
static AdbcStatusCode CConnectionReadPartition(AdbcConnection* connection,
const uint8_t* serialized_partition,
size_t serialized_length,
ArrowArrayStream* out,
AdbcError* error) {
CHECK_INIT(connection, error);
auto private_data = reinterpret_cast<ConnectionT*>(connection->private_data);
return private_data->ReadPartition(serialized_partition, serialized_length, out,
error);
}
static AdbcStatusCode CConnectionCommit(AdbcConnection* connection, AdbcError* error) {
CHECK_INIT(connection, error);
auto private_data = reinterpret_cast<ConnectionT*>(connection->private_data);
return private_data->Commit(error);
}
static AdbcStatusCode CConnectionRollback(AdbcConnection* connection,
AdbcError* error) {
CHECK_INIT(connection, error);
auto private_data = reinterpret_cast<ConnectionT*>(connection->private_data);
return private_data->Rollback(error);
}
#undef CHECK_INIT
#define CHECK_INIT(STATEMENT, ERROR) \
if (!(STATEMENT) || !(STATEMENT)->private_data) { \
return status::InvalidState("Statement is uninitialized").ToAdbc(ERROR); \
}
// Statement trampolines
static AdbcStatusCode CStatementNew(AdbcConnection* connection,
AdbcStatement* statement, AdbcError* error) {
if (!connection || !connection->private_data) {
return status::InvalidState("Connection is uninitialized").ToAdbc(error);
}
auto private_data = new StatementT();
AdbcStatusCode status = private_data->Init(connection->private_data, error);
if (status != ADBC_STATUS_OK) {
delete private_data;
}
statement->private_data = private_data;
return ADBC_STATUS_OK;
}
static AdbcStatusCode CStatementBind(AdbcStatement* statement, ArrowArray* values,
ArrowSchema* schema, AdbcError* error) {
CHECK_INIT(statement, error);
auto private_data = reinterpret_cast<StatementT*>(statement->private_data);
return private_data->Bind(values, schema, error);
}
static AdbcStatusCode CStatementBindStream(AdbcStatement* statement,
ArrowArrayStream* stream, AdbcError* error) {
CHECK_INIT(statement, error);
auto private_data = reinterpret_cast<StatementT*>(statement->private_data);
return private_data->BindStream(stream, error);
}
static AdbcStatusCode CStatementCancel(AdbcStatement* statement, AdbcError* error) {
CHECK_INIT(statement, error);
auto private_data = reinterpret_cast<StatementT*>(statement->private_data);
return private_data->Cancel(error);
}
static AdbcStatusCode CStatementExecutePartitions(AdbcStatement* statement,
struct ArrowSchema* schema,
struct AdbcPartitions* partitions,
int64_t* rows_affected,
AdbcError* error) {
CHECK_INIT(statement, error);
auto private_data = reinterpret_cast<StatementT*>(statement->private_data);
return private_data->ExecutePartitions(schema, partitions, rows_affected, error);
}
static AdbcStatusCode CStatementExecuteQuery(AdbcStatement* statement,
ArrowArrayStream* stream,
int64_t* rows_affected, AdbcError* error) {
CHECK_INIT(statement, error);
auto private_data = reinterpret_cast<StatementT*>(statement->private_data);
return private_data->ExecuteQuery(stream, rows_affected, error);
}
static AdbcStatusCode CStatementExecuteSchema(AdbcStatement* statement,
ArrowSchema* schema, AdbcError* error) {
CHECK_INIT(statement, error);
auto private_data = reinterpret_cast<StatementT*>(statement->private_data);
return private_data->ExecuteSchema(schema, error);
}
static AdbcStatusCode CStatementGetParameterSchema(AdbcStatement* statement,
ArrowSchema* schema,
AdbcError* error) {
CHECK_INIT(statement, error);
auto private_data = reinterpret_cast<StatementT*>(statement->private_data);
return private_data->GetParameterSchema(schema, error);
}
static AdbcStatusCode CStatementPrepare(AdbcStatement* statement, AdbcError* error) {
CHECK_INIT(statement, error);
auto private_data = reinterpret_cast<StatementT*>(statement->private_data);
return private_data->Prepare(error);
}
static AdbcStatusCode CStatementSetSqlQuery(AdbcStatement* statement, const char* query,
AdbcError* error) {
CHECK_INIT(statement, error);
auto private_data = reinterpret_cast<StatementT*>(statement->private_data);
return private_data->SetSqlQuery(query, error);
}
static AdbcStatusCode CStatementSetSubstraitPlan(AdbcStatement* statement,
const uint8_t* plan, size_t length,
AdbcError* error) {
CHECK_INIT(statement, error);
auto private_data = reinterpret_cast<StatementT*>(statement->private_data);
return private_data->SetSubstraitPlan(plan, length, error);
}
#undef CHECK_INIT
};
/// \brief CRTP base for a driver's AdbcDatabase implementation.
///
/// Derived classes customize behavior by overriding InitImpl, ReleaseImpl,
/// and SetOptionImpl. The default SetOptionImpl refers to
/// `Derived::kErrorPrefix` when building its error message.
template <typename Derived>
class BaseDatabase : public ObjectBase {
 public:
  using Base = BaseDatabase<Derived>;

  BaseDatabase() : ObjectBase() {}
  ~BaseDatabase() override = default;

  /// \internal
  /// Runs the derived InitImpl and, only if it succeeds, the base Init.
  AdbcStatusCode Init(void* parent, AdbcError* error) override {
    RAISE_STATUS(error, self().InitImpl());
    return ObjectBase::Init(parent, error);
  }

  /// \internal
  /// Runs the derived ReleaseImpl and reports its status.
  AdbcStatusCode Release(AdbcError* error) override {
    RAISE_STATUS(error, self().ReleaseImpl());
    return ADBC_STATUS_OK;
  }

  /// \internal
  /// Forwards the option to the derived SetOptionImpl.
  AdbcStatusCode SetOption(std::string_view key, Option value,
                           AdbcError* error) override {
    RAISE_STATUS(error, self().SetOptionImpl(key, std::move(value)));
    return ADBC_STATUS_OK;
  }

  /// \brief Initialize the database. The default does nothing.
  virtual Status InitImpl() { return status::Ok(); }

  /// \brief Release the database. The default does nothing.
  virtual Status ReleaseImpl() { return status::Ok(); }

  /// \brief Set an option. May be called prior to InitImpl.
  ///
  /// The default rejects every key with a NotImplemented status that names
  /// the key and the formatted value.
  virtual Status SetOptionImpl(std::string_view key, Option value) {
    return status::NotImplemented(Derived::kErrorPrefix, " Unknown database option ", key,
                                  "=", value.Format());
  }

 private:
  /// \brief View this object as the CRTP-derived implementation type.
  Derived& self() { return static_cast<Derived&>(*this); }
};
template <typename Derived>
class BaseConnection : public ObjectBase {
public:
using Base = BaseConnection<Derived>;
/// \brief Whether autocommit is enabled or not (by default: enabled).
enum class AutocommitState {
kAutocommit,
kTransaction,
};
BaseConnection() : ObjectBase() {}
~BaseConnection() = default;
/// \internal
AdbcStatusCode Init(void* parent, AdbcError* error) override {
RAISE_STATUS(error, impl().InitImpl(parent));
return ObjectBase::Init(parent, error);
}
/// \brief Initialize the database.
virtual Status InitImpl(void* parent) { return status::Ok(); }
/// \internal
AdbcStatusCode Cancel(AdbcError* error) { return impl().CancelImpl().ToAdbc(error); }
Status CancelImpl() { return status::NotImplemented("Cancel"); }
/// \internal
AdbcStatusCode Commit(AdbcError* error) { return impl().CommitImpl().ToAdbc(error); }
Status CommitImpl() { return status::NotImplemented("Commit"); }
/// \internal
AdbcStatusCode GetInfo(const uint32_t* info_codes, size_t info_codes_length,
ArrowArrayStream* out, AdbcError* error) {
std::vector<uint32_t> codes(info_codes, info_codes + info_codes_length);
RAISE_STATUS(error, impl().GetInfoImpl(codes, out));
return ADBC_STATUS_OK;
}
Status GetInfoImpl(const std::vector<uint32_t> info_codes, ArrowArrayStream* out) {
return status::NotImplemented("GetInfo");
}
/// \internal
AdbcStatusCode GetObjects(int c_depth, const char* catalog, const char* db_schema,
const char* table_name, const char** table_type,
const char* column_name, ArrowArrayStream* out,
AdbcError* error) {
const auto catalog_filter =
catalog ? std::make_optional(std::string_view(catalog)) : std::nullopt;
const auto schema_filter =
db_schema ? std::make_optional(std::string_view(db_schema)) : std::nullopt;
const auto table_filter =
table_name ? std::make_optional(std::string_view(table_name)) : std::nullopt;
const auto column_filter =
column_name ? std::make_optional(std::string_view(column_name)) : std::nullopt;
std::vector<std::string_view> table_type_filter;
while (table_type && *table_type) {
if (*table_type) {
table_type_filter.push_back(std::string_view(*table_type));
}
table_type++;
}
RAISE_STATUS(
error, impl().GetObjectsImpl(c_depth, catalog_filter, schema_filter, table_filter,
column_filter, table_type_filter, out));
return ADBC_STATUS_OK;
}
Status GetObjectsImpl(int c_depth, std::optional<std::string_view> catalog_filter,
std::optional<std::string_view> schema_filter,
std::optional<std::string_view> table_filter,
std::optional<std::string_view> column_filter,
const std::vector<std::string_view>& table_types,
struct ArrowArrayStream* out) {
return status::NotImplemented("GetObjects");
}
/// \internal
/// \brief Entry point for AdbcConnectionGetStatistics.
///
/// Each nullable C-string filter becomes std::nullopt when null. The C
/// `char approximate` flag becomes a bool (any non-zero value is true). The
/// result is forwarded to GetStatisticsImpl.
AdbcStatusCode GetStatistics(const char* catalog, const char* db_schema,
                             const char* table_name, char approximate,
                             ArrowArrayStream* out, AdbcError* error) {
  auto as_filter = [](const char* text) -> std::optional<std::string_view> {
    if (text == nullptr) {
      return std::nullopt;
    }
    return std::string_view(text);
  };
  const bool allow_approximate = (approximate != 0);
  RAISE_STATUS(error, impl().GetStatisticsImpl(as_filter(catalog), as_filter(db_schema),
                                               as_filter(table_name), allow_approximate,
                                               out));
  return ADBC_STATUS_OK;
}
/// \brief Produce the GetStatistics result stream.
///
/// \param catalog Catalog filter, or std::nullopt for no filter.
/// \param db_schema Schema filter, or std::nullopt for no filter.
/// \param table_name Table filter, or std::nullopt for no filter.
/// \param approximate Whether approximate statistics are acceptable.
/// \param out Stream to populate with the result.
/// \return NotImplemented by default; drivers override this hook.
Status GetStatisticsImpl(std::optional<std::string_view> catalog,
                         std::optional<std::string_view> db_schema,
                         std::optional<std::string_view> table_name,
                         bool approximate,
                         ArrowArrayStream* out) {
  return status::NotImplemented("GetStatistics");
}
/// \internal
AdbcStatusCode GetStatisticNames(ArrowArrayStream* out, AdbcError* error) {
RAISE_STATUS(error, impl().GetStatisticNames(out));
return ADBC_STATUS_OK;
}
/// \brief Legacy hook for listing statistic names; reports NotImplemented.
///
/// NOTE: redeclaring this name in a derived class hides the two-argument
/// AdbcStatusCode GetStatisticNames entry point unless the derived class adds
/// `using Base::GetStatisticNames;`.
Status GetStatisticNames([[maybe_unused]] ArrowArrayStream* out) {
  return status::NotImplemented("GetStatisticNames");
}
/// \internal
/// \brief Entry point for AdbcConnectionGetTableSchema.
///
/// Rejects a null \p table_name with InvalidArgument. Otherwise it turns the
/// nullable catalog/schema pointers into optional views (std::nullopt when
/// null) and forwards to GetTableSchemaImpl.
AdbcStatusCode GetTableSchema(const char* catalog, const char* db_schema,
                              const char* table_name, ArrowSchema* schema,
                              AdbcError* error) {
  if (table_name == nullptr) {
    return status::InvalidArgument(Derived::kErrorPrefix,
                                   " GetTableSchema: must provide table_name")
        .ToAdbc(error);
  }
  auto optional_view = [](const char* text) -> std::optional<std::string_view> {
    if (text == nullptr) {
      return std::nullopt;
    }
    return std::string_view(text);
  };
  RAISE_STATUS(error, impl().GetTableSchemaImpl(optional_view(catalog),
                                                optional_view(db_schema),
                                                std::string_view(table_name), schema));
  return ADBC_STATUS_OK;
}
/// \brief Fill \p out with the Arrow schema of a table.
///
/// \param catalog Catalog, or std::nullopt when not given.
/// \param db_schema Schema, or std::nullopt when not given.
/// \param table_name Table name (GetTableSchema guarantees it was non-null).
/// \param out Schema to populate.
/// \return NotImplemented by default; drivers override this hook.
Status GetTableSchemaImpl(std::optional<std::string_view> catalog,
                          std::optional<std::string_view> db_schema,
                          std::string_view table_name,
                          ArrowSchema* out) {
  return status::NotImplemented("GetTableSchema");
}
/// \internal
/// \brief Entry point for AdbcConnectionGetTableTypes; forwards to GetTableTypesImpl.
AdbcStatusCode GetTableTypes(ArrowArrayStream* out, AdbcError* error) {
  Status outcome = impl().GetTableTypesImpl(out);
  RAISE_STATUS(error, std::move(outcome));
  return ADBC_STATUS_OK;
}
/// \brief Produce the stream of table types; reports NotImplemented by default.
Status GetTableTypesImpl([[maybe_unused]] ArrowArrayStream* out) {
  return status::NotImplemented("GetTableTypes");
}
/// \internal
/// \brief Entry point for AdbcConnectionReadPartition.
///
/// Views the serialized partition bytes as a std::string_view of exactly
/// \p serialized_length bytes (no copy) and forwards to ReadPartitionImpl.
AdbcStatusCode ReadPartition(const uint8_t* serialized_partition,
                             size_t serialized_length, ArrowArrayStream* out,
                             AdbcError* error) {
  const char* raw_bytes = reinterpret_cast<const char*>(serialized_partition);
  RAISE_STATUS(error, impl().ReadPartitionImpl(
                          std::string_view(raw_bytes, serialized_length), out));
  return ADBC_STATUS_OK;
}
/// \brief Open a result stream for the serialized partition bytes passed to
///   ReadPartition. Reports NotImplemented by default.
Status ReadPartitionImpl(std::string_view serialized_partition,
                         ArrowArrayStream* out) {
  return status::NotImplemented("ReadPartition");
}
/// \internal
/// \brief Release the connection by running ReleaseImpl.
AdbcStatusCode Release(AdbcError* error) override {
  Status outcome = impl().ReleaseImpl();
  RAISE_STATUS(error, std::move(outcome));
  return ADBC_STATUS_OK;
}
/// \brief Release driver-specific resources. The default does nothing and succeeds.
Status ReleaseImpl() {
  return status::Ok();
}
/// \internal
/// \brief Entry point for AdbcConnectionRollback; forwards to RollbackImpl.
AdbcStatusCode Rollback(AdbcError* error) {
  Status outcome = impl().RollbackImpl();
  RAISE_STATUS(error, std::move(outcome));
  return ADBC_STATUS_OK;
}
/// \brief Roll back the current transaction; reports NotImplemented by default.
Status RollbackImpl() {
  return status::NotImplemented("Rollback");
}
/// \internal
/// \brief Set a connection option by forwarding it to SetOptionImpl.
AdbcStatusCode SetOption(std::string_view key, Option value,
                         AdbcError* error) override {
  // `value` is owned by this frame and unused after the call, so move it into
  // SetOptionImpl's by-value parameter instead of making a second copy.
  RAISE_STATUS(error, impl().SetOptionImpl(key, std::move(value)));
  return ADBC_STATUS_OK;
}
/// \brief Set an option. May be called prior to InitImpl.
///
/// The default rejects every key with NotImplemented. The message carries the
/// driver's error prefix, the key and the formatted value.
virtual Status SetOptionImpl(std::string_view key, Option value) {
  const auto rendered_value = value.Format();
  return status::NotImplemented(Derived::kErrorPrefix, " Unknown connection option ",
                                key, "=", rendered_value);
}
private:
  /// \brief CRTP downcast: view this base object as the concrete Derived driver type.
  Derived& impl() { return *static_cast<Derived*>(this); }
};
/// \brief CRTP base class for ADBC statements.
///
/// Each ADBC statement entry point is implemented here and forwards to a
/// matching -Impl hook on \p Derived. Override the hooks, not the entry
/// points. \p Derived must expose a `kErrorPrefix` used in error messages.
template <typename Derived>
class BaseStatement : public ObjectBase {
 public:
  using Base = BaseStatement<Derived>;

  /// \internal
  AdbcStatusCode Init(void* parent, AdbcError* error) override {
    RAISE_STATUS(error, impl().InitImpl(parent));
    return ObjectBase::Init(parent, error);
  }

  /// \brief Initialize the statement.
  Status InitImpl(void* parent) { return status::Ok(); }

  /// \internal
  AdbcStatusCode Release(AdbcError* error) override {
    RAISE_STATUS(error, impl().ReleaseImpl());
    return ADBC_STATUS_OK;
  }

  /// \brief Release driver-specific resources. The default does nothing.
  Status ReleaseImpl() { return status::Ok(); }

  /// \internal
  AdbcStatusCode SetOption(std::string_view key, Option value,
                           AdbcError* error) override {
    // `value` is unused after this call; move it into SetOptionImpl's
    // by-value parameter rather than copying it.
    RAISE_STATUS(error, impl().SetOptionImpl(key, std::move(value)));
    return ADBC_STATUS_OK;
  }

  /// \brief Set an option. May be called prior to InitImpl.
  virtual Status SetOptionImpl(std::string_view key, Option value) {
    return status::NotImplemented(Derived::kErrorPrefix, " Unknown statement option ",
                                  key, "=", value.Format());
  }

  /// \internal
  /// \brief Run the query. \p rows_affected is written only when non-null.
  AdbcStatusCode ExecuteQuery(ArrowArrayStream* stream, int64_t* rows_affected,
                              AdbcError* error) {
    RAISE_RESULT(error, int64_t rows_affected_result, impl().ExecuteQueryImpl(stream));
    if (rows_affected) {
      *rows_affected = rows_affected_result;
    }
    return ADBC_STATUS_OK;
  }

  /// \brief Run the query and return the affected row count. Default: NotImplemented.
  Result<int64_t> ExecuteQueryImpl(ArrowArrayStream* stream) {
    return status::NotImplemented("ExecuteQuery");
  }

  /// \internal
  AdbcStatusCode ExecuteSchema(ArrowSchema* schema, AdbcError* error) {
    RAISE_STATUS(error, impl().ExecuteSchemaImpl(schema));
    return ADBC_STATUS_OK;
  }

  /// \brief Fill \p schema with the result schema. Default: NotImplemented.
  Status ExecuteSchemaImpl(ArrowSchema* schema) {
    return status::NotImplemented("ExecuteSchema");
  }

  /// \internal
  AdbcStatusCode Prepare(AdbcError* error) {
    RAISE_STATUS(error, impl().PrepareImpl());
    return ADBC_STATUS_OK;
  }

  /// \brief Prepare the statement. Default: NotImplemented.
  Status PrepareImpl() { return status::NotImplemented("Prepare"); }

  /// \internal
  /// \brief Set the SQL query; a null \p query is rejected with InvalidArgument.
  AdbcStatusCode SetSqlQuery(const char* query, AdbcError* error) {
    // Constructing a std::string_view from a null pointer is undefined
    // behavior, so refuse it before converting.
    if (!query) {
      return status::InvalidArgument(Derived::kErrorPrefix,
                                     " SetSqlQuery: must provide query")
          .ToAdbc(error);
    }
    RAISE_STATUS(error, impl().SetSqlQueryImpl(query));
    return ADBC_STATUS_OK;
  }

  /// \brief Store the SQL query. Default: NotImplemented.
  Status SetSqlQueryImpl(std::string_view query) {
    return status::NotImplemented("SetSqlQuery");
  }

  /// \internal
  /// \brief Set a Substrait plan given as \p length raw bytes.
  AdbcStatusCode SetSubstraitPlan(const uint8_t* plan, size_t length, AdbcError* error) {
    RAISE_STATUS(error, impl().SetSubstraitPlanImpl(std::string_view(
                            reinterpret_cast<const char*>(plan), length)));
    return ADBC_STATUS_OK;
  }

  /// \brief Store the serialized Substrait plan. Default: NotImplemented.
  Status SetSubstraitPlanImpl(std::string_view plan) {
    return status::NotImplemented("SetSubstraitPlan");
  }

  /// \internal
  AdbcStatusCode Bind(ArrowArray* values, ArrowSchema* schema, AdbcError* error) {
    RAISE_STATUS(error, impl().BindImpl(values, schema));
    return ADBC_STATUS_OK;
  }

  /// \brief Bind a single batch of parameters. Default: NotImplemented.
  Status BindImpl(ArrowArray* values, ArrowSchema* schema) {
    return status::NotImplemented("Bind");
  }

  /// \internal
  AdbcStatusCode BindStream(ArrowArrayStream* stream, AdbcError* error) {
    RAISE_STATUS(error, impl().BindStreamImpl(stream));
    return ADBC_STATUS_OK;
  }

  /// \brief Bind a stream of parameter batches. Default: NotImplemented.
  Status BindStreamImpl(ArrowArrayStream* stream) {
    return status::NotImplemented("BindStream");
  }

  /// \internal
  AdbcStatusCode GetParameterSchema(ArrowSchema* schema, AdbcError* error) {
    RAISE_STATUS(error, impl().GetParameterSchemaImpl(schema));
    return ADBC_STATUS_OK;
  }

  /// \brief Fill \p schema with the parameter schema. Default: NotImplemented.
  Status GetParameterSchemaImpl(ArrowSchema* schema) {
    return status::NotImplemented("GetParameterSchema");
  }

  /// \internal
  /// \brief Partitioned execution is not supported by this base class.
  ///
  /// Reported through a Status, like every other unimplemented entry point,
  /// so that \p error carries a message instead of being left untouched.
  AdbcStatusCode ExecutePartitions(ArrowSchema* schema, AdbcPartitions* partitions,
                                   int64_t* rows_affected, AdbcError* error) {
    return status::NotImplemented("ExecutePartitions").ToAdbc(error);
  }

  /// \internal
  AdbcStatusCode Cancel(AdbcError* error) {
    RAISE_STATUS(error, impl().CancelImpl());
    return ADBC_STATUS_OK;
  }

  /// \brief Customization hook for Cancel.
  ///
  /// Prefer overriding this over the zero-argument Cancel() overload. Declaring
  /// any member named Cancel in a derived class hides the AdbcStatusCode entry
  /// point above. For backwards compatibility the default forwards to Cancel(),
  /// so derived classes that already override Cancel() keep working.
  Status CancelImpl() { return impl().Cancel(); }

  /// \brief Legacy cancellation hook. Default: NotImplemented.
  Status Cancel() { return status::NotImplemented("Cancel"); }

 private:
  Derived& impl() { return static_cast<Derived&>(*this); }
};
} // namespace adbc::driver