c/driver/postgresql/statement.cc (908 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "statement.h"
#include <array>
#include <cerrno>
#include <cinttypes>
#include <cstring>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include <adbc.h>
#include <libpq-fe.h>
#include <nanoarrow/nanoarrow.hpp>
#include "common/utils.h"
#include "connection.h"
#include "postgres_copy_reader.h"
#include "postgres_type.h"
#include "postgres_util.h"
#include "vendor/portable-snippets/safe-math.h"
namespace adbcpq {
namespace {
/// The flag indicating to PostgreSQL that we want binary-format values.
constexpr int kPgBinaryFormat = 1;
/// One-value ArrowArrayStream used to unify the implementations of Bind
struct OneValueStream {
struct ArrowSchema schema;
struct ArrowArray array;
static int GetSchema(struct ArrowArrayStream* self, struct ArrowSchema* out) {
OneValueStream* stream = static_cast<OneValueStream*>(self->private_data);
return ArrowSchemaDeepCopy(&stream->schema, out);
}
static int GetNext(struct ArrowArrayStream* self, struct ArrowArray* out) {
OneValueStream* stream = static_cast<OneValueStream*>(self->private_data);
*out = stream->array;
stream->array.release = nullptr;
return 0;
}
static const char* GetLastError(struct ArrowArrayStream* self) { return NULL; }
static void Release(struct ArrowArrayStream* self) {
OneValueStream* stream = static_cast<OneValueStream*>(self->private_data);
if (stream->schema.release) {
stream->schema.release(&stream->schema);
stream->schema.release = nullptr;
}
if (stream->array.release) {
stream->array.release(&stream->array);
stream->array.release = nullptr;
}
delete stream;
self->release = nullptr;
}
};
/// Helper to manage resources with RAII
template <typename T>
struct Releaser {
static void Release(T* value) {
if (value->release) {
value->release(value);
}
}
};
template <>
struct Releaser<struct ArrowArrayView> {
static void Release(struct ArrowArrayView* value) {
if (value->storage_type != NANOARROW_TYPE_UNINITIALIZED) {
ArrowArrayViewReset(value);
}
}
};
template <typename Resource>
struct Handle {
Resource value;
Handle() { std::memset(&value, 0, sizeof(value)); }
~Handle() { Releaser<Resource>::Release(&value); }
Resource* operator->() { return &value; }
};
/// Build an PostgresType object from a PGresult*
AdbcStatusCode ResolvePostgresType(const PostgresTypeResolver& type_resolver,
PGresult* result, PostgresType* out,
struct AdbcError* error) {
ArrowError na_error;
const int num_fields = PQnfields(result);
PostgresType root_type(PostgresTypeId::kRecord);
for (int i = 0; i < num_fields; i++) {
const Oid pg_oid = PQftype(result, i);
PostgresType pg_type;
if (type_resolver.Find(pg_oid, &pg_type, &na_error) != NANOARROW_OK) {
SetError(error, "%s%d%s%s%s%d", "[libpq] Column #", i + 1, " (\"",
PQfname(result, i), "\") has unknown type code ", pg_oid);
return ADBC_STATUS_NOT_IMPLEMENTED;
}
root_type.AppendChild(PQfname(result, i), pg_type);
}
*out = root_type;
return ADBC_STATUS_OK;
}
/// Helper to manage bind parameters with a prepared statement
struct BindStream {
Handle<struct ArrowArrayStream> bind;
Handle<struct ArrowSchema> bind_schema;
struct ArrowSchemaView bind_schema_view;
std::vector<struct ArrowSchemaView> bind_schema_fields;
// OIDs for parameter types
std::vector<uint32_t> param_types;
std::vector<char*> param_values;
std::vector<int> param_lengths;
std::vector<int> param_formats;
std::vector<size_t> param_values_offsets;
std::vector<char> param_values_buffer;
// XXX: this assumes fixed-length fields only - will need more
// consideration to deal with variable-length fields
bool has_tz_field = false;
std::string tz_setting;
struct ArrowError na_error;
explicit BindStream(struct ArrowArrayStream&& bind) {
this->bind.value = std::move(bind);
std::memset(&na_error, 0, sizeof(na_error));
}
template <typename Callback>
AdbcStatusCode Begin(Callback&& callback, struct AdbcError* error) {
CHECK_NA(INTERNAL, bind->get_schema(&bind.value, &bind_schema.value), error);
CHECK_NA(
INTERNAL,
ArrowSchemaViewInit(&bind_schema_view, &bind_schema.value, /*error*/ nullptr),
error);
if (bind_schema_view.type != ArrowType::NANOARROW_TYPE_STRUCT) {
SetError(error, "%s", "[libpq] Bind parameters must have type STRUCT");
return ADBC_STATUS_INVALID_STATE;
}
bind_schema_fields.resize(bind_schema->n_children);
for (size_t i = 0; i < bind_schema_fields.size(); i++) {
CHECK_NA(INTERNAL,
ArrowSchemaViewInit(&bind_schema_fields[i], bind_schema->children[i],
/*error*/ nullptr),
error);
}
return std::move(callback)();
}
AdbcStatusCode SetParamTypes(const PostgresTypeResolver& type_resolver,
struct AdbcError* error) {
param_types.resize(bind_schema->n_children);
param_values.resize(bind_schema->n_children);
param_lengths.resize(bind_schema->n_children);
param_formats.resize(bind_schema->n_children, kPgBinaryFormat);
param_values_offsets.reserve(bind_schema->n_children);
for (size_t i = 0; i < bind_schema_fields.size(); i++) {
PostgresTypeId type_id;
switch (bind_schema_fields[i].type) {
case ArrowType::NANOARROW_TYPE_INT8:
case ArrowType::NANOARROW_TYPE_INT16:
type_id = PostgresTypeId::kInt2;
param_lengths[i] = 2;
break;
case ArrowType::NANOARROW_TYPE_INT32:
type_id = PostgresTypeId::kInt4;
param_lengths[i] = 4;
break;
case ArrowType::NANOARROW_TYPE_INT64:
type_id = PostgresTypeId::kInt8;
param_lengths[i] = 8;
break;
case ArrowType::NANOARROW_TYPE_FLOAT:
type_id = PostgresTypeId::kFloat4;
param_lengths[i] = 4;
break;
case ArrowType::NANOARROW_TYPE_DOUBLE:
type_id = PostgresTypeId::kFloat8;
param_lengths[i] = 8;
break;
case ArrowType::NANOARROW_TYPE_STRING:
type_id = PostgresTypeId::kText;
param_lengths[i] = 0;
break;
case ArrowType::NANOARROW_TYPE_BINARY:
type_id = PostgresTypeId::kBytea;
param_lengths[i] = 0;
break;
case ArrowType::NANOARROW_TYPE_DATE32:
type_id = PostgresTypeId::kDate;
param_lengths[i] = 4;
break;
case ArrowType::NANOARROW_TYPE_TIMESTAMP:
type_id = PostgresTypeId::kTimestamp;
param_lengths[i] = 8;
break;
case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
type_id = PostgresTypeId::kInterval;
param_lengths[i] = 16;
break;
default:
SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
static_cast<uint64_t>(i + 1), " ('", bind_schema->children[i]->name,
"') has unsupported parameter type ",
ArrowTypeString(bind_schema_fields[i].type));
return ADBC_STATUS_NOT_IMPLEMENTED;
}
param_types[i] = type_resolver.GetOID(type_id);
if (param_types[i] == 0) {
SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
static_cast<uint64_t>(i + 1), " ('", bind_schema->children[i]->name,
"') has type with no corresponding PostgreSQL type ",
ArrowTypeString(bind_schema_fields[i].type));
return ADBC_STATUS_NOT_IMPLEMENTED;
}
}
size_t param_values_length = 0;
for (int length : param_lengths) {
param_values_offsets.push_back(param_values_length);
param_values_length += length;
}
param_values_buffer.resize(param_values_length);
return ADBC_STATUS_OK;
}
AdbcStatusCode Prepare(PGconn* conn, const std::string& query, struct AdbcError* error,
const bool autocommit) {
// tz-aware timestamps require special handling to set the timezone to UTC
// prior to sending over the binary protocol; must be reset after execute
for (int64_t col = 0; col < bind_schema->n_children; col++) {
if ((bind_schema_fields[col].type == ArrowType::NANOARROW_TYPE_TIMESTAMP) &&
(strcmp("", bind_schema_fields[col].timezone))) {
has_tz_field = true;
if (autocommit) {
PGresult* begin_result = PQexec(conn, "BEGIN");
if (PQresultStatus(begin_result) != PGRES_COMMAND_OK) {
SetError(error, "[libpq] Failed to begin transaction for timezone data: %s",
PQerrorMessage(conn));
PQclear(begin_result);
return ADBC_STATUS_IO;
}
PQclear(begin_result);
}
PGresult* get_tz_result = PQexec(conn, "SELECT current_setting('TIMEZONE')");
if (PQresultStatus(get_tz_result) != PGRES_TUPLES_OK) {
SetError(error, "[libpq] Could not query current timezone: %s",
PQerrorMessage(conn));
PQclear(get_tz_result);
return ADBC_STATUS_IO;
}
tz_setting = std::string(PQgetvalue(get_tz_result, 0, 0));
PQclear(get_tz_result);
PGresult* set_utc_result = PQexec(conn, "SET TIME ZONE 'UTC'");
if (PQresultStatus(set_utc_result) != PGRES_COMMAND_OK) {
SetError(error, "[libpq] Failed to set time zone to UTC: %s",
PQerrorMessage(conn));
PQclear(set_utc_result);
return ADBC_STATUS_IO;
}
PQclear(set_utc_result);
break;
}
}
PGresult* result = PQprepare(conn, /*stmtName=*/"", query.c_str(),
/*nParams=*/bind_schema->n_children, param_types.data());
if (PQresultStatus(result) != PGRES_COMMAND_OK) {
SetError(error, "[libpq] Failed to prepare query: %s\nQuery was:%s",
PQerrorMessage(conn), query.c_str());
PQclear(result);
return ADBC_STATUS_IO;
}
PQclear(result);
return ADBC_STATUS_OK;
}
AdbcStatusCode Execute(PGconn* conn, int64_t* rows_affected, struct AdbcError* error) {
if (rows_affected) *rows_affected = 0;
PGresult* result = nullptr;
while (true) {
Handle<struct ArrowArray> array;
int res = bind->get_next(&bind.value, &array.value);
if (res != 0) {
SetError(error,
"[libpq] Failed to read next batch from stream of bind parameters: "
"(%d) %s %s",
res, std::strerror(res), bind->get_last_error(&bind.value));
return ADBC_STATUS_IO;
}
if (!array->release) break;
Handle<struct ArrowArrayView> array_view;
// TODO: include error messages
CHECK_NA(
INTERNAL,
ArrowArrayViewInitFromSchema(&array_view.value, &bind_schema.value, nullptr),
error);
CHECK_NA(INTERNAL, ArrowArrayViewSetArray(&array_view.value, &array.value, nullptr),
error);
for (int64_t row = 0; row < array->length; row++) {
for (int64_t col = 0; col < array_view->n_children; col++) {
if (ArrowArrayViewIsNull(array_view->children[col], row)) {
param_values[col] = nullptr;
continue;
} else {
param_values[col] = param_values_buffer.data() + param_values_offsets[col];
}
switch (bind_schema_fields[col].type) {
case ArrowType::NANOARROW_TYPE_INT8: {
const int16_t val =
array_view->children[col]->buffer_views[1].data.as_int8[row];
const uint16_t value = ToNetworkInt16(val);
std::memcpy(param_values[col], &value, sizeof(int16_t));
break;
}
case ArrowType::NANOARROW_TYPE_INT16: {
const uint16_t value = ToNetworkInt16(
array_view->children[col]->buffer_views[1].data.as_int16[row]);
std::memcpy(param_values[col], &value, sizeof(int16_t));
break;
}
case ArrowType::NANOARROW_TYPE_INT32: {
const uint32_t value = ToNetworkInt32(
array_view->children[col]->buffer_views[1].data.as_int32[row]);
std::memcpy(param_values[col], &value, sizeof(int32_t));
break;
}
case ArrowType::NANOARROW_TYPE_INT64: {
const int64_t value = ToNetworkInt64(
array_view->children[col]->buffer_views[1].data.as_int64[row]);
std::memcpy(param_values[col], &value, sizeof(int64_t));
break;
}
case ArrowType::NANOARROW_TYPE_FLOAT: {
const uint32_t value = ToNetworkFloat4(
array_view->children[col]->buffer_views[1].data.as_float[row]);
std::memcpy(param_values[col], &value, sizeof(uint32_t));
break;
}
case ArrowType::NANOARROW_TYPE_DOUBLE: {
const uint64_t value = ToNetworkFloat8(
array_view->children[col]->buffer_views[1].data.as_double[row]);
std::memcpy(param_values[col], &value, sizeof(uint64_t));
break;
}
case ArrowType::NANOARROW_TYPE_STRING:
case ArrowType::NANOARROW_TYPE_BINARY: {
const ArrowBufferView view =
ArrowArrayViewGetBytesUnsafe(array_view->children[col], row);
// TODO: overflow check?
param_lengths[col] = static_cast<int>(view.size_bytes);
param_values[col] = const_cast<char*>(view.data.as_char);
break;
}
case ArrowType::NANOARROW_TYPE_DATE32: {
// 2000-01-01
constexpr int32_t kPostgresDateEpoch = 10957;
const int32_t raw_value =
array_view->children[col]->buffer_views[1].data.as_int32[row];
if (raw_value < INT32_MIN + kPostgresDateEpoch) {
SetError(error, "[libpq] Field #%" PRId64 "%s%s%s%" PRId64 "%s", col + 1,
"('", bind_schema->children[col]->name, "') Row #", row + 1,
"has value which exceeds postgres date limits");
return ADBC_STATUS_INVALID_ARGUMENT;
}
const uint32_t value = ToNetworkInt32(raw_value - kPostgresDateEpoch);
std::memcpy(param_values[col], &value, sizeof(int32_t));
break;
}
case ArrowType::NANOARROW_TYPE_TIMESTAMP: {
int64_t val = array_view->children[col]->buffer_views[1].data.as_int64[row];
// 2000-01-01 00:00:00.000000 in microseconds
constexpr int64_t kPostgresTimestampEpoch = 946684800000000;
psnip_safe_bool overflow_safe = true;
auto unit = bind_schema_fields[col].time_unit;
switch (unit) {
case NANOARROW_TIME_UNIT_SECOND:
overflow_safe = psnip_safe_int64_mul(&val, val, 1000000);
break;
case NANOARROW_TIME_UNIT_MILLI:
overflow_safe = psnip_safe_int64_mul(&val, val, 1000);
break;
case NANOARROW_TIME_UNIT_MICRO:
break;
case NANOARROW_TIME_UNIT_NANO:
val /= 1000;
break;
}
if (!overflow_safe) {
SetError(error, "[libpq] Field #%" PRId64 "%s%s%s%" PRId64 "%s", col + 1,
" (' ", bind_schema->children[col]->name, " ') Row # ", row + 1,
" has value which exceeds postgres timestamp limits");
return ADBC_STATUS_INVALID_ARGUMENT;
}
const uint64_t value = ToNetworkInt64(val - kPostgresTimestampEpoch);
std::memcpy(param_values[col], &value, sizeof(int64_t));
break;
}
case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: {
struct ArrowInterval interval;
ArrowIntervalInit(&interval, NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO);
ArrowArrayViewGetIntervalUnsafe(array_view->children[col], row, &interval);
const uint32_t months = ToNetworkInt32(interval.months);
const uint32_t days = ToNetworkInt32(interval.days);
const uint64_t ms = ToNetworkInt64(interval.ns / 1000);
std::memcpy(param_values[col], &ms, sizeof(uint64_t));
std::memcpy(param_values[col] + sizeof(uint64_t), &days, sizeof(uint32_t));
std::memcpy(param_values[col] + sizeof(uint64_t) + sizeof(uint32_t),
&months, sizeof(uint32_t));
break;
}
default:
SetError(error, "%s%" PRId64 "%s%s%s%s", "[libpq] Field #", col + 1, " ('",
bind_schema->children[col]->name,
"') has unsupported type for ingestion ",
ArrowTypeString(bind_schema_fields[col].type));
return ADBC_STATUS_NOT_IMPLEMENTED;
}
}
result = PQexecPrepared(conn, /*stmtName=*/"",
/*nParams=*/bind_schema->n_children, param_values.data(),
param_lengths.data(), param_formats.data(),
/*resultFormat=*/0 /*text*/);
if (PQresultStatus(result) != PGRES_COMMAND_OK) {
SetError(error, "%s%s", "[libpq] Failed to execute prepared statement: ",
PQerrorMessage(conn));
PQclear(result);
return ADBC_STATUS_IO;
}
PQclear(result);
}
if (rows_affected) *rows_affected += array->length;
if (has_tz_field) {
std::string reset_query = "SET TIME ZONE '" + tz_setting + "'";
PGresult* reset_tz_result = PQexec(conn, reset_query.c_str());
if (PQresultStatus(reset_tz_result) != PGRES_COMMAND_OK) {
SetError(error, "[libpq] Failed to reset time zone: %s", PQerrorMessage(conn));
PQclear(reset_tz_result);
return ADBC_STATUS_IO;
}
PQclear(reset_tz_result);
PGresult* commit_result = PQexec(conn, "COMMIT");
if (PQresultStatus(commit_result) != PGRES_COMMAND_OK) {
SetError(error, "[libpq] Failed to commit transaction: %s",
PQerrorMessage(conn));
PQclear(commit_result);
return ADBC_STATUS_IO;
}
PQclear(commit_result);
}
}
return ADBC_STATUS_OK;
}
};
} // namespace
int TupleReader::GetSchema(struct ArrowSchema* out) {
int na_res = copy_reader_->GetSchema(out);
if (out->release == nullptr) {
StringBuilderAppend(&error_builder_,
"[libpq] Result set was already consumed or freed");
return EINVAL;
} else if (na_res != NANOARROW_OK) {
// e.g., Can't allocate memory
StringBuilderAppend(&error_builder_, "[libpq] Error copying schema");
}
return na_res;
}
int TupleReader::InitQueryAndFetchFirst(struct ArrowError* error) {
ResetQuery();
// Fetch + parse the header
int get_copy_res = PQgetCopyData(conn_, &pgbuf_, /*async=*/0);
data_.size_bytes = get_copy_res;
data_.data.as_char = pgbuf_;
if (get_copy_res == -2) {
StringBuilderAppend(&error_builder_, "[libpq] Fetch header failed: %s",
PQerrorMessage(conn_));
return EIO;
}
int na_res = copy_reader_->ReadHeader(&data_, error);
if (na_res != NANOARROW_OK) {
StringBuilderAppend(&error_builder_, "[libpq] ReadHeader failed: %s", error->message);
return EIO;
}
return NANOARROW_OK;
}
int TupleReader::AppendRowAndFetchNext(struct ArrowError* error) {
// Parse the result (the header AND the first row are included in the first
// call to PQgetCopyData())
int na_res = copy_reader_->ReadRecord(&data_, error);
if (na_res != NANOARROW_OK && na_res != ENODATA) {
StringBuilderAppend(&error_builder_,
"[libpq] ReadRecord failed at row %" PRId64 ": %s", row_id_,
error->message);
return na_res;
}
row_id_++;
// Fetch + check
PQfreemem(pgbuf_);
pgbuf_ = nullptr;
int get_copy_res = PQgetCopyData(conn_, &pgbuf_, /*async=*/0);
data_.size_bytes = get_copy_res;
data_.data.as_char = pgbuf_;
if (get_copy_res == -2) {
StringBuilderAppend(&error_builder_,
"[libpq] PQgetCopyData failed at row %" PRId64 ": %s", row_id_,
PQerrorMessage(conn_));
return EIO;
} else if (get_copy_res == -1) {
// Returned when COPY has finished successfully
return ENODATA;
} else if ((copy_reader_->array_size_approx_bytes() + get_copy_res) >=
batch_size_hint_bytes_) {
// Appending the next row will result in an array larger than requested.
// Return EOVERFLOW to force GetNext() to build the current result and return.
return EOVERFLOW;
} else {
return NANOARROW_OK;
}
}
int TupleReader::BuildOutput(struct ArrowArray* out, struct ArrowError* error) {
if (copy_reader_->array_size_approx_bytes() == 0) {
out->release = nullptr;
return NANOARROW_OK;
}
int na_res = copy_reader_->GetArray(out, error);
if (na_res != NANOARROW_OK) {
StringBuilderAppend(&error_builder_, "[libpq] Failed to build result array: %s",
error->message);
return na_res;
}
return NANOARROW_OK;
}
void TupleReader::ResetQuery() {
// Clear result
if (result_) {
PQclear(result_);
result_ = nullptr;
}
// Reset result buffer
if (pgbuf_ != nullptr) {
PQfreemem(pgbuf_);
pgbuf_ = nullptr;
}
// Clear the error builder
error_builder_.size = 0;
row_id_ = -1;
}
int TupleReader::GetNext(struct ArrowArray* out) {
if (!copy_reader_) {
out->release = nullptr;
return 0;
}
struct ArrowError error;
error.message[0] = '\0';
if (row_id_ == -1) {
NANOARROW_RETURN_NOT_OK(InitQueryAndFetchFirst(&error));
row_id_++;
}
int na_res;
do {
na_res = AppendRowAndFetchNext(&error);
if (na_res == EOVERFLOW) {
// The result would be too big to return if we appended the row. When EOVERFLOW is
// returned, the copy reader leaves the output in a valid state. The data is left in
// pg_buf_/data_ and will attempt to be appended on the next call to GetNext()
return BuildOutput(out, &error);
}
} while (na_res == NANOARROW_OK);
if (na_res != ENODATA) {
return na_res;
}
// Finish the result properly and return the last result. Note that BuildOutput() may
// set tmp.release = nullptr if there were zero rows in the copy reader (can
// occur in an overflow scenario).
struct ArrowArray tmp;
NANOARROW_RETURN_NOT_OK(BuildOutput(&tmp, &error));
// Clear the copy reader to mark this reader as finished
copy_reader_.reset();
// Check the server-side response
result_ = PQgetResult(conn_);
const int pq_status = PQresultStatus(result_);
if (pq_status != PGRES_COMMAND_OK) {
StringBuilderAppend(&error_builder_, "[libpq] Query failed [%d]: %s", pq_status,
PQresultErrorMessage(result_));
if (tmp.release != nullptr) {
tmp.release(&tmp);
}
return EIO;
}
ResetQuery();
ArrowArrayMove(&tmp, out);
return NANOARROW_OK;
}
void TupleReader::Release() {
StringBuilderReset(&error_builder_);
if (result_) {
PQclear(result_);
result_ = nullptr;
}
if (pgbuf_) {
PQfreemem(pgbuf_);
pgbuf_ = nullptr;
}
}
void TupleReader::ExportTo(struct ArrowArrayStream* stream) {
stream->get_schema = &GetSchemaTrampoline;
stream->get_next = &GetNextTrampoline;
stream->get_last_error = &GetLastErrorTrampoline;
stream->release = &ReleaseTrampoline;
stream->private_data = this;
}
int TupleReader::GetSchemaTrampoline(struct ArrowArrayStream* self,
struct ArrowSchema* out) {
if (!self || !self->private_data) return EINVAL;
TupleReader* reader = static_cast<TupleReader*>(self->private_data);
return reader->GetSchema(out);
}
int TupleReader::GetNextTrampoline(struct ArrowArrayStream* self,
struct ArrowArray* out) {
if (!self || !self->private_data) return EINVAL;
TupleReader* reader = static_cast<TupleReader*>(self->private_data);
return reader->GetNext(out);
}
const char* TupleReader::GetLastErrorTrampoline(struct ArrowArrayStream* self) {
if (!self || !self->private_data) return nullptr;
TupleReader* reader = static_cast<TupleReader*>(self->private_data);
return reader->last_error();
}
void TupleReader::ReleaseTrampoline(struct ArrowArrayStream* self) {
if (!self || !self->private_data) return;
TupleReader* reader = static_cast<TupleReader*>(self->private_data);
reader->Release();
self->private_data = nullptr;
self->release = nullptr;
}
AdbcStatusCode PostgresStatement::New(struct AdbcConnection* connection,
struct AdbcError* error) {
if (!connection || !connection->private_data) {
SetError(error, "%s", "[libpq] Must provide an initialized AdbcConnection");
return ADBC_STATUS_INVALID_ARGUMENT;
}
connection_ =
*reinterpret_cast<std::shared_ptr<PostgresConnection>*>(connection->private_data);
type_resolver_ = connection_->type_resolver();
reader_.conn_ = connection_->conn();
return ADBC_STATUS_OK;
}
AdbcStatusCode PostgresStatement::Bind(struct ArrowArray* values,
struct ArrowSchema* schema,
struct AdbcError* error) {
if (!values || !values->release) {
SetError(error, "%s", "[libpq] Must provide non-NULL array");
return ADBC_STATUS_INVALID_ARGUMENT;
} else if (!schema || !schema->release) {
SetError(error, "%s", "[libpq] Must provide non-NULL schema");
return ADBC_STATUS_INVALID_ARGUMENT;
}
if (bind_.release) bind_.release(&bind_);
// Make a one-value stream
bind_.private_data = new OneValueStream{*schema, *values};
bind_.get_schema = &OneValueStream::GetSchema;
bind_.get_next = &OneValueStream::GetNext;
bind_.get_last_error = &OneValueStream::GetLastError;
bind_.release = &OneValueStream::Release;
std::memset(values, 0, sizeof(*values));
std::memset(schema, 0, sizeof(*schema));
return ADBC_STATUS_OK;
}
AdbcStatusCode PostgresStatement::Bind(struct ArrowArrayStream* stream,
struct AdbcError* error) {
if (!stream || !stream->release) {
SetError(error, "%s", "[libpq] Must provide non-NULL stream");
return ADBC_STATUS_INVALID_ARGUMENT;
}
// Move stream
if (bind_.release) bind_.release(&bind_);
bind_ = *stream;
std::memset(stream, 0, sizeof(*stream));
return ADBC_STATUS_OK;
}
AdbcStatusCode PostgresStatement::CreateBulkTable(
const struct ArrowSchema& source_schema,
const std::vector<struct ArrowSchemaView>& source_schema_fields,
struct AdbcError* error) {
std::string create = "CREATE TABLE ";
create += ingest_.target;
create += " (";
for (size_t i = 0; i < source_schema_fields.size(); i++) {
if (i > 0) create += ", ";
create += source_schema.children[i]->name;
switch (source_schema_fields[i].type) {
case ArrowType::NANOARROW_TYPE_INT8:
case ArrowType::NANOARROW_TYPE_INT16:
create += " SMALLINT";
break;
case ArrowType::NANOARROW_TYPE_INT32:
create += " INTEGER";
break;
case ArrowType::NANOARROW_TYPE_INT64:
create += " BIGINT";
break;
case ArrowType::NANOARROW_TYPE_FLOAT:
create += " REAL";
break;
case ArrowType::NANOARROW_TYPE_DOUBLE:
create += " DOUBLE PRECISION";
break;
case ArrowType::NANOARROW_TYPE_STRING:
create += " TEXT";
break;
case ArrowType::NANOARROW_TYPE_BINARY:
create += " BYTEA";
break;
case ArrowType::NANOARROW_TYPE_DATE32:
create += " DATE";
break;
case ArrowType::NANOARROW_TYPE_TIMESTAMP:
if (strcmp("", source_schema_fields[i].timezone)) {
create += " TIMESTAMPTZ";
} else {
create += " TIMESTAMP";
}
break;
case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
create += " INTERVAL";
break;
default:
SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
static_cast<uint64_t>(i + 1), " ('", source_schema.children[i]->name,
"') has unsupported type for ingestion ",
ArrowTypeString(source_schema_fields[i].type));
return ADBC_STATUS_NOT_IMPLEMENTED;
}
}
create += ")";
SetError(error, "%s%s", "[libpq] ", create.c_str());
PGresult* result = PQexecParams(connection_->conn(), create.c_str(), /*nParams=*/0,
/*paramTypes=*/nullptr, /*paramValues=*/nullptr,
/*paramLengths=*/nullptr, /*paramFormats=*/nullptr,
/*resultFormat=*/1 /*(binary)*/);
if (PQresultStatus(result) != PGRES_COMMAND_OK) {
SetError(error, "[libpq] Failed to create table: %s\nQuery was: %s",
PQerrorMessage(connection_->conn()), create.c_str());
PQclear(result);
return ADBC_STATUS_IO;
}
PQclear(result);
return ADBC_STATUS_OK;
}
AdbcStatusCode PostgresStatement::ExecutePreparedStatement(
struct ArrowArrayStream* stream, int64_t* rows_affected, struct AdbcError* error) {
if (!bind_.release) {
// TODO: set an empty stream just to unify the code paths
SetError(error, "%s",
"[libpq] Prepared statements without parameters are not implemented");
return ADBC_STATUS_NOT_IMPLEMENTED;
}
if (stream) {
// TODO:
SetError(error, "%s",
"[libpq] Prepared statements returning result sets are not implemented");
return ADBC_STATUS_NOT_IMPLEMENTED;
}
BindStream bind_stream(std::move(bind_));
std::memset(&bind_, 0, sizeof(bind_));
RAISE_ADBC(bind_stream.Begin([&]() { return ADBC_STATUS_OK; }, error));
RAISE_ADBC(bind_stream.SetParamTypes(*type_resolver_, error));
RAISE_ADBC(
bind_stream.Prepare(connection_->conn(), query_, error, connection_->autocommit()));
RAISE_ADBC(bind_stream.Execute(connection_->conn(), rows_affected, error));
return ADBC_STATUS_OK;
}
AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream,
int64_t* rows_affected,
struct AdbcError* error) {
ClearResult();
if (prepared_) {
if (bind_.release || !stream) {
return ExecutePreparedStatement(stream, rows_affected, error);
}
// XXX: don't use a prepared statement to execute a no-parameter
// result-set-returning query for now, since we can't easily get
// access to COPY there. (This might have to become sequential
// executions of COPY (EXECUTE ($n, ...)) TO STDOUT which won't
// get all the benefits of a prepared statement.) At preparation
// time we don't know whether the query will be used with a result
// set or not without analyzing the query (we could prepare both?)
// and https://stackoverflow.com/questions/69233792 suggests that
// you can't PREPARE a query containing COPY.
}
if (!stream && !ingest_.target.empty()) {
return ExecuteUpdateBulk(rows_affected, error);
}
if (query_.empty()) {
SetError(error, "%s", "[libpq] Must SetSqlQuery before ExecuteQuery");
return ADBC_STATUS_INVALID_STATE;
}
// 1. Prepare the query to get the schema
{
// TODO: we should pipeline here and assume this will succeed
PGresult* result = PQprepare(connection_->conn(), /*stmtName=*/"", query_.c_str(),
/*nParams=*/0, nullptr);
if (PQresultStatus(result) != PGRES_COMMAND_OK) {
SetError(error,
"[libpq] Failed to execute query: could not infer schema: failed to "
"prepare query: %s\nQuery was:%s",
PQerrorMessage(connection_->conn()), query_.c_str());
PQclear(result);
return ADBC_STATUS_IO;
}
PQclear(result);
result = PQdescribePrepared(connection_->conn(), /*stmtName=*/"");
if (PQresultStatus(result) != PGRES_COMMAND_OK) {
SetError(error,
"[libpq] Failed to execute query: could not infer schema: failed to "
"describe prepared statement: %s\nQuery was:%s",
PQerrorMessage(connection_->conn()), query_.c_str());
PQclear(result);
return ADBC_STATUS_IO;
}
// Resolve the information from the PGresult into a PostgresType
PostgresType root_type;
AdbcStatusCode status =
ResolvePostgresType(*type_resolver_, result, &root_type, error);
PQclear(result);
if (status != ADBC_STATUS_OK) return status;
// Initialize the copy reader and infer the output schema (i.e., error for
// unsupported types before issuing the COPY query)
reader_.copy_reader_.reset(new PostgresCopyStreamReader());
reader_.copy_reader_->Init(root_type);
struct ArrowError na_error;
int na_res = reader_.copy_reader_->InferOutputSchema(&na_error);
if (na_res != NANOARROW_OK) {
SetError(error, "[libpq] Failed to infer output schema: %s", na_error.message);
return na_res;
}
// If the caller did not request a result set or if there are no
// inferred output columns (e.g. a CREATE or UPDATE), then don't
// use COPY (which would fail anyways)
if (!stream || root_type.n_children() == 0) {
RAISE_ADBC(ExecuteUpdateQuery(rows_affected, error));
if (stream) {
struct ArrowSchema schema;
std::memset(&schema, 0, sizeof(schema));
RAISE_NA(reader_.copy_reader_->GetSchema(&schema));
nanoarrow::EmptyArrayStream::MakeUnique(&schema).move(stream);
}
return ADBC_STATUS_OK;
}
// This resolves the reader specific to each PostgresType -> ArrowSchema
// conversion. It is unlikely that this will fail given that we have just
// inferred these conversions ourselves.
na_res = reader_.copy_reader_->InitFieldReaders(&na_error);
if (na_res != NANOARROW_OK) {
SetError(error, "[libpq] Failed to initialize field readers: %s", na_error.message);
return na_res;
}
}
// 2. Execute the query with COPY to get binary tuples
{
std::string copy_query = "COPY (" + query_ + ") TO STDOUT (FORMAT binary)";
reader_.result_ =
PQexecParams(connection_->conn(), copy_query.c_str(), /*nParams=*/0,
/*paramTypes=*/nullptr, /*paramValues=*/nullptr,
/*paramLengths=*/nullptr, /*paramFormats=*/nullptr, kPgBinaryFormat);
if (PQresultStatus(reader_.result_) != PGRES_COPY_OUT) {
SetError(error,
"[libpq] Failed to execute query: could not begin COPY: %s\nQuery was: %s",
PQerrorMessage(connection_->conn()), copy_query.c_str());
ClearResult();
return ADBC_STATUS_IO;
}
// Result is read from the connection, not the result, but we won't clear it here
}
reader_.ExportTo(stream);
if (rows_affected) *rows_affected = -1;
return ADBC_STATUS_OK;
}
AdbcStatusCode PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected,
struct AdbcError* error) {
if (!bind_.release) {
SetError(error, "%s", "[libpq] Must Bind() before Execute() for bulk ingestion");
return ADBC_STATUS_INVALID_STATE;
}
BindStream bind_stream(std::move(bind_));
std::memset(&bind_, 0, sizeof(bind_));
RAISE_ADBC(bind_stream.Begin(
[&]() -> AdbcStatusCode {
if (!ingest_.append) {
// CREATE TABLE
return CreateBulkTable(bind_stream.bind_schema.value,
bind_stream.bind_schema_fields, error);
}
return ADBC_STATUS_OK;
},
error));
RAISE_ADBC(bind_stream.SetParamTypes(*type_resolver_, error));
std::string insert = "INSERT INTO ";
insert += ingest_.target;
insert += " VALUES (";
for (size_t i = 0; i < bind_stream.bind_schema_fields.size(); i++) {
if (i > 0) insert += ", ";
insert += "$";
insert += std::to_string(i + 1);
}
insert += ")";
RAISE_ADBC(
bind_stream.Prepare(connection_->conn(), insert, error, connection_->autocommit()));
RAISE_ADBC(bind_stream.Execute(connection_->conn(), rows_affected, error));
return ADBC_STATUS_OK;
}
AdbcStatusCode PostgresStatement::ExecuteUpdateQuery(int64_t* rows_affected,
struct AdbcError* error) {
// NOTE: must prepare first (used in ExecuteQuery)
PGresult* result =
PQexecPrepared(connection_->conn(), /*stmtName=*/"", /*nParams=*/0,
/*paramValues=*/nullptr, /*paramLengths=*/nullptr,
/*paramFormats=*/nullptr, /*resultFormat=*/kPgBinaryFormat);
if (PQresultStatus(result) != PGRES_COMMAND_OK) {
SetError(error, "[libpq] Failed to execute query: %s\nQuery was:%s",
PQerrorMessage(connection_->conn()), query_.c_str());
PQclear(result);
return ADBC_STATUS_IO;
}
if (rows_affected) *rows_affected = PQntuples(reader_.result_);
PQclear(result);
return ADBC_STATUS_OK;
}
AdbcStatusCode PostgresStatement::GetParameterSchema(struct ArrowSchema* schema,
struct AdbcError* error) {
return ADBC_STATUS_NOT_IMPLEMENTED;
}
AdbcStatusCode PostgresStatement::Prepare(struct AdbcError* error) {
if (query_.empty()) {
SetError(error, "%s", "[libpq] Must SetSqlQuery() before Prepare()");
return ADBC_STATUS_INVALID_STATE;
}
// Don't actually prepare until execution time, so we know the
// parameter types
prepared_ = true;
return ADBC_STATUS_OK;
}
AdbcStatusCode PostgresStatement::Release(struct AdbcError* error) {
ClearResult();
if (bind_.release) {
bind_.release(&bind_);
}
return ADBC_STATUS_OK;
}
AdbcStatusCode PostgresStatement::SetSqlQuery(const char* query,
struct AdbcError* error) {
ingest_.target.clear();
query_ = query;
prepared_ = false;
return ADBC_STATUS_OK;
}
AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value,
struct AdbcError* error) {
if (std::strcmp(key, ADBC_INGEST_OPTION_TARGET_TABLE) == 0) {
query_.clear();
ingest_.target = value;
} else if (std::strcmp(key, ADBC_INGEST_OPTION_MODE) == 0) {
if (std::strcmp(value, ADBC_INGEST_OPTION_MODE_CREATE) == 0) {
ingest_.append = false;
} else if (std::strcmp(value, ADBC_INGEST_OPTION_MODE_APPEND) == 0) {
ingest_.append = true;
} else {
SetError(error, "[libpq] Invalid value '%s' for option '%s'", value, key);
return ADBC_STATUS_INVALID_ARGUMENT;
}
} else if (std::strcmp(value, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES)) {
int64_t int_value = std::atol(value);
if (int_value <= 0) {
SetError(error, "[libpq] Invalid value '%s' for option '%s'", value, key);
return ADBC_STATUS_INVALID_ARGUMENT;
}
this->reader_.batch_size_hint_bytes_ = int_value;
} else {
SetError(error, "[libq] Unknown statement option '%s'", key);
return ADBC_STATUS_NOT_IMPLEMENTED;
}
return ADBC_STATUS_OK;
}
void PostgresStatement::ClearResult() {
// TODO: we may want to synchronize here for safety
reader_.Release();
}
} // namespace adbcpq