c/driver/postgresql/connection.cc (1,014 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. #include "connection.h" #include <array> #include <cassert> #include <cinttypes> #include <cmath> #include <cstring> #include <memory> #include <optional> #include <sstream> #include <string> #include <string_view> #include <unordered_map> #include <utility> #include <vector> #include <arrow-adbc/adbc.h> #include <fmt/format.h> #include <libpq-fe.h> #include "database.h" #include "driver/common/utils.h" #include "driver/framework/objects.h" #include "driver/framework/utility.h" #include "error.h" #include "result_helper.h" using adbc::driver::Result; using adbc::driver::Status; namespace adbcpq { namespace { constexpr std::string_view kConnectionOptionTransactionStatus = "adbc.postgresql.transaction_status"; static const uint32_t kSupportedInfoCodes[] = { ADBC_INFO_VENDOR_NAME, ADBC_INFO_VENDOR_VERSION, ADBC_INFO_DRIVER_NAME, ADBC_INFO_DRIVER_VERSION, ADBC_INFO_DRIVER_ARROW_VERSION, ADBC_INFO_DRIVER_ADBC_VERSION, }; static const std::unordered_map<std::string, std::string> kPgTableTypes = { {"table", "r"}, {"view", "v"}, {"materialized_view", "m"}, {"toast_table", "t"}, {"foreign_table", "f"}, {"partitioned_table", "p"}}; static const char* kCatalogQueryAll = "SELECT datname FROM pg_catalog.pg_database"; // catalog_name is not a parameter here or on any other queries // because it will always be the currently connected database. static const char* kSchemaQueryAll = "SELECT nspname FROM pg_catalog.pg_namespace WHERE " "nspname !~ '^pg_' AND nspname <> 'information_schema'"; // Parameterized on schema_name, relkind // Note that when binding relkind as a string it must look like {"r", "v", ...} // (i.e., double quotes). Binding a binary list<string> element also works. static const char* kTablesQueryAll = "SELECT c.relname, CASE c.relkind WHEN 'r' THEN 'table' WHEN 'v' THEN 'view' " "WHEN 'm' THEN 'materialized view' WHEN 't' THEN 'TOAST table' " "WHEN 'f' THEN 'foreign table' WHEN 'p' THEN 'partitioned table' END " "AS reltype FROM pg_catalog.pg_class c " "LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace " "WHERE pg_catalog.pg_table_is_visible(c.oid) AND n.nspname = $1 AND c.relkind = " "ANY($2)"; // Parameterized on schema_name, table_name static const char* kColumnsQueryAll = "SELECT attr.attname, attr.attnum, " "pg_catalog.col_description(cls.oid, attr.attnum) " "FROM pg_catalog.pg_attribute AS attr " "INNER JOIN pg_catalog.pg_class AS cls ON attr.attrelid = cls.oid " "INNER JOIN pg_catalog.pg_namespace AS nsp ON nsp.oid = cls.relnamespace " "WHERE attr.attnum > 0 AND NOT attr.attisdropped " "AND nsp.nspname LIKE $1 AND cls.relname LIKE $2"; // Parameterized on schema_name, table_name static const char* kConstraintsQueryAll = "WITH fk_unnest AS ( " " SELECT " " con.conname, " " 'FOREIGN KEY' AS contype, " " conrelid, " " UNNEST(con.conkey) AS conkey, " " confrelid, " " UNNEST(con.confkey) AS confkey " " FROM pg_catalog.pg_constraint AS con " " INNER JOIN pg_catalog.pg_class AS cls ON cls.oid = conrelid " " INNER JOIN pg_catalog.pg_namespace AS nsp ON nsp.oid = cls.relnamespace " " WHERE con.contype = 'f' AND nsp.nspname = $1 " " AND cls.relname = $2 " "), " "fk_names AS ( " " SELECT " " fk_unnest.conname, " " fk_unnest.contype, " " fk_unnest.conkey, " " fk_unnest.confkey, " " attr.attname, " " fnsp.nspname AS fschema, " " fcls.relname AS ftable, " " fattr.attname AS fattname " " FROM fk_unnest " " INNER JOIN pg_catalog.pg_class AS cls ON cls.oid = fk_unnest.conrelid " " INNER JOIN pg_catalog.pg_class AS fcls ON fcls.oid = fk_unnest.confrelid " " INNER JOIN pg_catalog.pg_namespace AS fnsp ON fnsp.oid = fcls.relnamespace" " INNER JOIN pg_catalog.pg_attribute AS attr ON attr.attnum = " "fk_unnest.conkey " " AND attr.attrelid = fk_unnest.conrelid " " LEFT JOIN pg_catalog.pg_attribute AS fattr ON fattr.attnum = " "fk_unnest.confkey " " AND fattr.attrelid = fk_unnest.confrelid " "), " "fkeys AS ( " " SELECT " " conname, " " contype, " " ARRAY_AGG(attname ORDER BY conkey) AS colnames, " " fschema, " " ftable, " " ARRAY_AGG(fattname ORDER BY confkey) AS fcolnames " " FROM fk_names " " GROUP BY " " conname, " " contype, " " fschema, " " ftable " "), " "other_constraints AS ( " " SELECT con.conname, CASE con.contype WHEN 'c' THEN 'CHECK' WHEN 'u' THEN " " 'UNIQUE' WHEN 'p' THEN 'PRIMARY KEY' END AS contype, " " ARRAY_AGG(attr.attname) AS colnames " " FROM pg_catalog.pg_constraint AS con " " CROSS JOIN UNNEST(conkey) AS conkeys " " INNER JOIN pg_catalog.pg_class AS cls ON cls.oid = con.conrelid " " INNER JOIN pg_catalog.pg_namespace AS nsp ON nsp.oid = cls.relnamespace " " INNER JOIN pg_catalog.pg_attribute AS attr ON attr.attnum = conkeys " " AND cls.oid = attr.attrelid " " WHERE con.contype IN ('c', 'u', 'p') AND nsp.nspname = $1 " " AND cls.relname = $2 " " GROUP BY conname, contype " ") " "SELECT " " conname, contype, colnames, fschema, ftable, fcolnames " "FROM fkeys " "UNION ALL " "SELECT " " conname, contype, colnames, NULL, NULL, NULL " "FROM other_constraints"; class PostgresGetObjectsHelper : public adbc::driver::GetObjectsHelper { public: explicit PostgresGetObjectsHelper(PGconn* conn) : current_database_(PQdb(conn)), all_catalogs_(conn, kCatalogQueryAll), some_catalogs_(conn, CatalogQuery()), all_schemas_(conn, kSchemaQueryAll), some_schemas_(conn, SchemaQuery()), all_tables_(conn, kTablesQueryAll), some_tables_(conn, TablesQuery()), all_columns_(conn, kColumnsQueryAll), some_columns_(conn, ColumnsQuery()), all_constraints_(conn, kConstraintsQueryAll), some_constraints_(conn, ConstraintsQuery()) {} // Allow Redshift to execute this query without constraints // TODO(paleolimbot): Investigate to see if we can simplify the constraits query so that // it works on both! void SetEnableConstraints(bool enable_constraints) { enable_constraints_ = enable_constraints; } Status Load(adbc::driver::GetObjectsDepth depth, std::optional<std::string_view> catalog_filter, std::optional<std::string_view> schema_filter, std::optional<std::string_view> table_filter, std::optional<std::string_view> column_filter, const std::vector<std::string_view>& table_types) override { return Status::Ok(); } Status LoadCatalogs(std::optional<std::string_view> catalog_filter) override { if (catalog_filter.has_value()) { UNWRAP_STATUS(some_catalogs_.Execute({std::string(*catalog_filter)})); next_catalog_ = some_catalogs_.Row(-1); } else { UNWRAP_STATUS(all_catalogs_.Execute()); next_catalog_ = all_catalogs_.Row(-1); } return Status::Ok(); }; Result<std::optional<std::string_view>> NextCatalog() override { next_catalog_ = next_catalog_.Next(); if (!next_catalog_.IsValid()) { return std::nullopt; } return next_catalog_[0].value(); } Status LoadSchemas(std::string_view catalog, std::optional<std::string_view> schema_filter) override { // PostgreSQL can only list for the current database if (catalog != current_database_) { return Status::Ok(); } if (schema_filter.has_value()) { UNWRAP_STATUS(some_schemas_.Execute({std::string(*schema_filter)})); next_schema_ = some_schemas_.Row(-1); } else { UNWRAP_STATUS(all_schemas_.Execute()); next_schema_ = all_schemas_.Row(-1); } return Status::Ok(); }; Result<std::optional<std::string_view>> NextSchema() override { next_schema_ = next_schema_.Next(); if (!next_schema_.IsValid()) { return std::nullopt; } return next_schema_[0].value(); } Status LoadTables(std::string_view catalog, std::string_view schema, std::optional<std::string_view> table_filter, const std::vector<std::string_view>& table_types) override { std::string table_types_bind = TableTypesArrayLiteral(table_types); if (table_filter.has_value()) { UNWRAP_STATUS(some_tables_.Execute( {std::string(schema), table_types_bind, std::string(*table_filter)})); next_table_ = some_tables_.Row(-1); } else { UNWRAP_STATUS(all_tables_.Execute({std::string(schema), table_types_bind})); next_table_ = all_tables_.Row(-1); } return Status::Ok(); }; Result<std::optional<Table>> NextTable() override { next_table_ = next_table_.Next(); if (!next_table_.IsValid()) { return std::nullopt; } return Table{next_table_[0].value(), next_table_[1].value()}; } Status LoadColumns(std::string_view catalog, std::string_view schema, std::string_view table, std::optional<std::string_view> column_filter) override { if (column_filter.has_value()) { UNWRAP_STATUS(some_columns_.Execute( {std::string(schema), std::string(table), std::string(*column_filter)})); next_column_ = some_columns_.Row(-1); } else { UNWRAP_STATUS(all_columns_.Execute({std::string(schema), std::string(table)})); next_column_ = all_columns_.Row(-1); } if (enable_constraints_) { if (column_filter.has_value()) { UNWRAP_STATUS(some_constraints_.Execute( {std::string(schema), std::string(table), std::string(*column_filter)})) next_constraint_ = some_constraints_.Row(-1); } else { UNWRAP_STATUS( all_constraints_.Execute({std::string(schema), std::string(table)})); next_constraint_ = all_constraints_.Row(-1); } } return Status::Ok(); }; Result<std::optional<Column>> NextColumn() override { next_column_ = next_column_.Next(); if (!next_column_.IsValid()) { return std::nullopt; } Column col; col.column_name = next_column_[0].value(); UNWRAP_RESULT(int64_t ordinal_position, next_column_[1].ParseInteger()); col.ordinal_position = static_cast<int32_t>(ordinal_position); if (!next_column_[2].is_null) { col.remarks = next_column_[2].value(); } return col; } Result<std::optional<Constraint>> NextConstraint() override { next_constraint_ = next_constraint_.Next(); if (!next_constraint_.IsValid()) { return std::nullopt; } Constraint out; out.name = next_constraint_[0].data; out.type = next_constraint_[1].data; UNWRAP_RESULT(constraint_fcolumn_names_, next_constraint_[2].ParseTextArray()); std::vector<std::string_view> fcolumn_names_view; for (const std::string& item : constraint_fcolumn_names_) { fcolumn_names_view.push_back(item); } out.column_names = std::move(fcolumn_names_view); if (out.type == "FOREIGN KEY") { assert(!next_constraint_[3].is_null); assert(!next_constraint_[3].is_null); assert(!next_constraint_[4].is_null); assert(!next_constraint_[5].is_null); out.usage = std::vector<ConstraintUsage>(); UNWRAP_RESULT(constraint_fkey_names_, next_constraint_[5].ParseTextArray()); for (const auto& item : constraint_fkey_names_) { ConstraintUsage usage; usage.catalog = current_database_; usage.schema = next_constraint_[3].data; usage.table = next_constraint_[4].data; usage.column = item; out.usage->push_back(usage); } } return out; } private: std::string current_database_; // Ready-to-Execute() queries PqResultHelper all_catalogs_; PqResultHelper some_catalogs_; PqResultHelper all_schemas_; PqResultHelper some_schemas_; PqResultHelper all_tables_; PqResultHelper some_tables_; PqResultHelper all_columns_; PqResultHelper some_columns_; PqResultHelper all_constraints_; PqResultHelper some_constraints_; // On Redshift, the constraints query fails bool enable_constraints_{true}; // Iterator state for the catalogs/schema/table/column queries PqResultRow next_catalog_; PqResultRow next_schema_; PqResultRow next_table_; PqResultRow next_column_; PqResultRow next_constraint_; // Owning variants required because the framework versions of these // are all based on string_view and the result helper can only parse arrays // into std::vector<std::string>. std::vector<std::string> constraint_fcolumn_names_; std::vector<std::string> constraint_fkey_names_; // Queries that are slightly modified versions of the generic queries that allow // the filter for that level to be passed through as a parameter. Defined here // because global strings should be const char* according to cpplint and using // the + operator to concatenate them is the most concise way to construct them. // Parameterized on catalog_name static std::string CatalogQuery() { return std::string(kCatalogQueryAll) + " WHERE datname = $1"; } // Parameterized on schema_name static std::string SchemaQuery() { return std::string(kSchemaQueryAll) + " AND nspname = $1"; } // Parameterized on schema_name, relkind, table_name static std::string TablesQuery() { return std::string(kTablesQueryAll) + " AND c.relname LIKE $3"; } // Parameterized on schema_name, table_name, column_name static std::string ColumnsQuery() { return std::string(kColumnsQueryAll) + " AND attr.attname LIKE $3"; } // Parameterized on schema_name, table_name, column_name static std::string ConstraintsQuery() { return std::string(kConstraintsQueryAll) + " WHERE conname LIKE $3"; } std::string TableTypesArrayLiteral(const std::vector<std::string_view>& table_types) { std::stringstream table_types_bind; table_types_bind << "{"; int table_types_bind_len = 0; if (table_types.empty()) { for (const auto& item : kPgTableTypes) { if (table_types_bind_len > 0) { table_types_bind << ", "; } table_types_bind << "\"" << item.second << "\""; table_types_bind_len++; } } else { for (auto type : table_types) { const auto maybe_item = kPgTableTypes.find(std::string(type)); if (maybe_item == kPgTableTypes.end()) { continue; } if (table_types_bind_len > 0) { table_types_bind << ", "; } table_types_bind << "\"" << maybe_item->second << "\""; table_types_bind_len++; } } table_types_bind << "}"; return table_types_bind.str(); } }; // A notice processor that does nothing with notices. In the future we can log // these, but this suppresses the default of printing to stderr. void SilentNoticeProcessor(void* /*arg*/, const char* /*message*/) {} } // namespace AdbcStatusCode PostgresConnection::Cancel(struct AdbcError* error) { // > errbuf must be a char array of size errbufsize (the recommended size is // > 256 bytes). // https://www.postgresql.org/docs/current/libpq-cancel.html char errbuf[256]; // > The return value is 1 if the cancel request was successfully dispatched // > and 0 if not. if (PQcancel(cancel_, errbuf, sizeof(errbuf)) != 1) { SetError(error, "[libpq] Failed to cancel operation: %s", errbuf); return ADBC_STATUS_UNKNOWN; } return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnection::Commit(struct AdbcError* error) { if (autocommit_) { SetError(error, "%s", "[libpq] Cannot commit when autocommit is enabled"); return ADBC_STATUS_INVALID_STATE; } PGTransactionStatusType txn_status = PQtransactionStatus(conn_); if (txn_status == PQTRANS_IDLE) { // https://github.com/apache/arrow-adbc/issues/2673: don't rollback if the // transaction is idle, since it won't have any effect and PostgreSQL will // issue a warning on the server side return ADBC_STATUS_OK; } PGresult* result = PQexec(conn_, "COMMIT; BEGIN TRANSACTION"); if (PQresultStatus(result) != PGRES_COMMAND_OK) { AdbcStatusCode code = SetError(error, result, "%s%s", "[libpq] Failed to commit: ", PQerrorMessage(conn_)); PQclear(result); return code; } PQclear(result); return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnection::GetInfo(struct AdbcConnection* connection, const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { if (!info_codes) { info_codes = kSupportedInfoCodes; info_codes_length = sizeof(kSupportedInfoCodes) / sizeof(kSupportedInfoCodes[0]); } std::vector<adbc::driver::InfoValue> infos; for (size_t i = 0; i < info_codes_length; i++) { switch (info_codes[i]) { case ADBC_INFO_VENDOR_NAME: infos.push_back({info_codes[i], std::string(VendorName())}); break; case ADBC_INFO_VENDOR_VERSION: { if (VendorName() == "Redshift") { const std::array<int, 3>& version = VendorVersion(); std::string version_string = std::to_string(version[0]) + "." + std::to_string(version[1]) + "." + std::to_string(version[2]); infos.push_back({info_codes[i], std::move(version_string)}); } else { // Gives a version in the form 140000 instead of 14.0.0 const char* stmt = "SHOW server_version_num"; auto result_helper = PqResultHelper{conn_, std::string(stmt)}; RAISE_STATUS(error, result_helper.Execute()); auto it = result_helper.begin(); if (it == result_helper.end()) { SetError(error, "[libpq] PostgreSQL returned no rows for '%s'", stmt); return ADBC_STATUS_INTERNAL; } const char* server_version_num = (*it)[0].data; infos.push_back({info_codes[i], server_version_num}); } break; } case ADBC_INFO_DRIVER_NAME: infos.push_back({info_codes[i], "ADBC PostgreSQL Driver"}); break; case ADBC_INFO_DRIVER_VERSION: // TODO(lidavidm): fill in driver version infos.push_back({info_codes[i], "(unknown)"}); break; case ADBC_INFO_DRIVER_ARROW_VERSION: infos.push_back({info_codes[i], NANOARROW_VERSION}); break; case ADBC_INFO_DRIVER_ADBC_VERSION: infos.push_back({info_codes[i], ADBC_VERSION_1_1_0}); break; default: // Ignore continue; } } RAISE_ADBC(adbc::driver::MakeGetInfoStream(infos, out).ToAdbc(error)); return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnection::GetObjects( struct AdbcConnection* connection, int c_depth, const char* catalog, const char* db_schema, const char* table_name, const char** table_type, const char* column_name, struct ArrowArrayStream* out, struct AdbcError* error) { PostgresGetObjectsHelper helper(conn_); helper.SetEnableConstraints(VendorName() != "Redshift"); const auto catalog_filter = catalog ? std::make_optional(std::string_view(catalog)) : std::nullopt; const auto schema_filter = db_schema ? std::make_optional(std::string_view(db_schema)) : std::nullopt; const auto table_filter = table_name ? std::make_optional(std::string_view(table_name)) : std::nullopt; const auto column_filter = column_name ? std::make_optional(std::string_view(column_name)) : std::nullopt; std::vector<std::string_view> table_type_filter; while (table_type && *table_type) { if (*table_type) { table_type_filter.push_back(std::string_view(*table_type)); } table_type++; } using adbc::driver::GetObjectsDepth; GetObjectsDepth depth = GetObjectsDepth::kColumns; switch (c_depth) { case ADBC_OBJECT_DEPTH_CATALOGS: depth = GetObjectsDepth::kCatalogs; break; case ADBC_OBJECT_DEPTH_COLUMNS: depth = GetObjectsDepth::kColumns; break; case ADBC_OBJECT_DEPTH_DB_SCHEMAS: depth = GetObjectsDepth::kSchemas; break; case ADBC_OBJECT_DEPTH_TABLES: depth = GetObjectsDepth::kTables; break; default: return Status::InvalidArgument("[libpq] GetObjects: invalid depth ", c_depth) .ToAdbc(error); } auto status = BuildGetObjects(&helper, depth, catalog_filter, schema_filter, table_filter, column_filter, table_type_filter, out); RAISE_STATUS(error, helper.Close()); RAISE_STATUS(error, status); return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnection::GetOption(const char* option, char* value, size_t* length, struct AdbcError* error) { std::string output; if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_CATALOG) == 0) { output = PQdb(conn_); } else if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA) == 0) { PqResultHelper result_helper{conn_, "SELECT CURRENT_SCHEMA()"}; RAISE_STATUS(error, result_helper.Execute()); auto it = result_helper.begin(); if (it == result_helper.end()) { SetError(error, "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA()'"); return ADBC_STATUS_INTERNAL; } output = (*it)[0].data; } else if (std::strcmp(option, ADBC_CONNECTION_OPTION_AUTOCOMMIT) == 0) { output = autocommit_ ? ADBC_OPTION_VALUE_ENABLED : ADBC_OPTION_VALUE_DISABLED; } else if (std::strcmp(option, kConnectionOptionTransactionStatus.data()) == 0) { switch (PQtransactionStatus(conn_)) { case PQTRANS_IDLE: output = "idle"; break; case PQTRANS_ACTIVE: output = "active"; break; case PQTRANS_INTRANS: output = "intrans"; break; case PQTRANS_INERROR: output = "inerror"; break; case PQTRANS_UNKNOWN: default: output = "unknown"; break; } } else { return ADBC_STATUS_NOT_FOUND; } if (output.size() + 1 <= *length) { std::memcpy(value, output.c_str(), output.size() + 1); } *length = output.size() + 1; return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnection::GetOptionBytes(const char* option, uint8_t* value, size_t* length, struct AdbcError* error) { return ADBC_STATUS_NOT_FOUND; } AdbcStatusCode PostgresConnection::GetOptionInt(const char* option, int64_t* value, struct AdbcError* error) { return ADBC_STATUS_NOT_FOUND; } AdbcStatusCode PostgresConnection::GetOptionDouble(const char* option, double* value, struct AdbcError* error) { return ADBC_STATUS_NOT_FOUND; } AdbcStatusCode PostgresConnectionGetStatisticsImpl(PGconn* conn, const char* db_schema, const char* table_name, struct ArrowSchema* schema, struct ArrowArray* array, struct AdbcError* error) { // Set up schema auto uschema = nanoarrow::UniqueSchema(); { ArrowSchemaInit(uschema.get()); CHECK_NA(INTERNAL, ArrowSchemaSetTypeStruct(uschema.get(), /*num_columns=*/2), error); CHECK_NA(INTERNAL, ArrowSchemaSetType(uschema->children[0], NANOARROW_TYPE_STRING), error); CHECK_NA(INTERNAL, ArrowSchemaSetName(uschema->children[0], "catalog_name"), error); CHECK_NA(INTERNAL, ArrowSchemaSetType(uschema->children[1], NANOARROW_TYPE_LIST), error); CHECK_NA(INTERNAL, ArrowSchemaSetName(uschema->children[1], "catalog_db_schemas"), error); CHECK_NA(INTERNAL, ArrowSchemaSetTypeStruct(uschema->children[1]->children[0], 2), error); uschema->children[1]->flags &= ~ARROW_FLAG_NULLABLE; struct ArrowSchema* db_schema_schema = uschema->children[1]->children[0]; CHECK_NA(INTERNAL, ArrowSchemaSetType(db_schema_schema->children[0], NANOARROW_TYPE_STRING), error); CHECK_NA(INTERNAL, ArrowSchemaSetName(db_schema_schema->children[0], "db_schema_name"), error); CHECK_NA(INTERNAL, ArrowSchemaSetType(db_schema_schema->children[1], NANOARROW_TYPE_LIST), error); CHECK_NA(INTERNAL, ArrowSchemaSetName(db_schema_schema->children[1], "db_schema_statistics"), error); CHECK_NA(INTERNAL, ArrowSchemaSetTypeStruct(db_schema_schema->children[1]->children[0], 5), error); db_schema_schema->children[1]->flags &= ~ARROW_FLAG_NULLABLE; struct ArrowSchema* statistics_schema = db_schema_schema->children[1]->children[0]; CHECK_NA(INTERNAL, ArrowSchemaSetType(statistics_schema->children[0], NANOARROW_TYPE_STRING), error); CHECK_NA(INTERNAL, ArrowSchemaSetName(statistics_schema->children[0], "table_name"), error); statistics_schema->children[0]->flags &= ~ARROW_FLAG_NULLABLE; CHECK_NA(INTERNAL, ArrowSchemaSetType(statistics_schema->children[1], NANOARROW_TYPE_STRING), error); CHECK_NA(INTERNAL, ArrowSchemaSetName(statistics_schema->children[1], "column_name"), error); CHECK_NA(INTERNAL, ArrowSchemaSetType(statistics_schema->children[2], NANOARROW_TYPE_INT16), error); CHECK_NA(INTERNAL, ArrowSchemaSetName(statistics_schema->children[2], "statistic_key"), error); statistics_schema->children[2]->flags &= ~ARROW_FLAG_NULLABLE; CHECK_NA(INTERNAL, ArrowSchemaSetTypeUnion(statistics_schema->children[3], NANOARROW_TYPE_DENSE_UNION, 4), error); CHECK_NA(INTERNAL, ArrowSchemaSetName(statistics_schema->children[3], "statistic_value"), error); statistics_schema->children[3]->flags &= ~ARROW_FLAG_NULLABLE; CHECK_NA(INTERNAL, ArrowSchemaSetType(statistics_schema->children[4], NANOARROW_TYPE_BOOL), error); CHECK_NA( INTERNAL, ArrowSchemaSetName(statistics_schema->children[4], "statistic_is_approximate"), error); statistics_schema->children[4]->flags &= ~ARROW_FLAG_NULLABLE; struct ArrowSchema* value_schema = statistics_schema->children[3]; CHECK_NA(INTERNAL, ArrowSchemaSetType(value_schema->children[0], NANOARROW_TYPE_INT64), error); CHECK_NA(INTERNAL, ArrowSchemaSetName(value_schema->children[0], "int64"), error); CHECK_NA(INTERNAL, ArrowSchemaSetType(value_schema->children[1], NANOARROW_TYPE_UINT64), error); CHECK_NA(INTERNAL, ArrowSchemaSetName(value_schema->children[1], "uint64"), error); CHECK_NA(INTERNAL, ArrowSchemaSetType(value_schema->children[2], NANOARROW_TYPE_DOUBLE), error); CHECK_NA(INTERNAL, ArrowSchemaSetName(value_schema->children[2], "float64"), error); CHECK_NA(INTERNAL, ArrowSchemaSetType(value_schema->children[3], NANOARROW_TYPE_BINARY), error); CHECK_NA(INTERNAL, ArrowSchemaSetName(value_schema->children[3], "binary"), error); } // Set up builders struct ArrowError na_error = {0}; CHECK_NA_DETAIL(INTERNAL, ArrowArrayInitFromSchema(array, uschema.get(), &na_error), &na_error, error); CHECK_NA(INTERNAL, ArrowArrayStartAppending(array), error); struct ArrowArray* catalog_name_col = array->children[0]; struct ArrowArray* catalog_db_schemas_col = array->children[1]; struct ArrowArray* catalog_db_schemas_items = catalog_db_schemas_col->children[0]; struct ArrowArray* db_schema_name_col = catalog_db_schemas_items->children[0]; struct ArrowArray* db_schema_statistics_col = catalog_db_schemas_items->children[1]; struct ArrowArray* db_schema_statistics_items = db_schema_statistics_col->children[0]; struct ArrowArray* statistics_table_name_col = db_schema_statistics_items->children[0]; struct ArrowArray* statistics_column_name_col = db_schema_statistics_items->children[1]; struct ArrowArray* statistics_key_col = db_schema_statistics_items->children[2]; struct ArrowArray* statistics_value_col = db_schema_statistics_items->children[3]; struct ArrowArray* statistics_is_approximate_col = db_schema_statistics_items->children[4]; // struct ArrowArray* value_int64_col = statistics_value_col->children[0]; // struct ArrowArray* value_uint64_col = statistics_value_col->children[1]; struct ArrowArray* value_float64_col = statistics_value_col->children[2]; // struct ArrowArray* value_binary_col = statistics_value_col->children[3]; // Query (could probably be massively improved) std::string query = R"( WITH class AS ( SELECT nspname, relname, reltuples FROM pg_namespace INNER JOIN pg_class ON pg_class.relnamespace = pg_namespace.oid ) SELECT tablename, attname, null_frac, avg_width, n_distinct, reltuples FROM pg_stats INNER JOIN class ON pg_stats.schemaname = class.nspname AND pg_stats.tablename = class.relname WHERE pg_stats.schemaname = $1 AND tablename LIKE $2 ORDER BY tablename )"; CHECK_NA(INTERNAL, ArrowArrayAppendString(catalog_name_col, ArrowCharView(PQdb(conn))), error); CHECK_NA(INTERNAL, ArrowArrayAppendString(db_schema_name_col, ArrowCharView(db_schema)), error); constexpr int8_t kStatsVariantFloat64 = 2; std::string prev_table; { PqResultHelper result_helper{conn, query}; RAISE_STATUS(error, result_helper.Execute({db_schema, table_name ? table_name : "%"})); for (PqResultRow row : result_helper) { auto reltuples = row[5].ParseDouble(); if (!reltuples) { SetError(error, "[libpq] Invalid double value in reltuples: '%s'", row[5].data); return ADBC_STATUS_INTERNAL; } if (std::strcmp(prev_table.c_str(), row[0].data) != 0) { CHECK_NA(INTERNAL, ArrowArrayAppendString(statistics_table_name_col, ArrowStringView{row[0].data, row[0].len}), error); CHECK_NA(INTERNAL, ArrowArrayAppendNull(statistics_column_name_col, 1), error); CHECK_NA(INTERNAL, ArrowArrayAppendInt(statistics_key_col, ADBC_STATISTIC_ROW_COUNT_KEY), error); CHECK_NA(INTERNAL, ArrowArrayAppendDouble(value_float64_col, *reltuples), error); CHECK_NA(INTERNAL, ArrowArrayFinishUnionElement(statistics_value_col, kStatsVariantFloat64), error); CHECK_NA(INTERNAL, ArrowArrayAppendInt(statistics_is_approximate_col, 1), error); CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schema_statistics_items), error); prev_table = std::string(row[0].data, row[0].len); } auto null_frac = row[2].ParseDouble(); if (!null_frac) { SetError(error, "[libpq] Invalid double value in null_frac: '%s'", row[2].data); return ADBC_STATUS_INTERNAL; } CHECK_NA(INTERNAL, ArrowArrayAppendString(statistics_table_name_col, ArrowStringView{row[0].data, row[0].len}), error); CHECK_NA(INTERNAL, ArrowArrayAppendString(statistics_column_name_col, ArrowStringView{row[1].data, row[1].len}), error); CHECK_NA(INTERNAL, ArrowArrayAppendInt(statistics_key_col, ADBC_STATISTIC_NULL_COUNT_KEY), error); CHECK_NA(INTERNAL, ArrowArrayAppendDouble(value_float64_col, *null_frac * *reltuples), error); CHECK_NA(INTERNAL, ArrowArrayFinishUnionElement(statistics_value_col, kStatsVariantFloat64), error); CHECK_NA(INTERNAL, ArrowArrayAppendInt(statistics_is_approximate_col, 1), error); CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schema_statistics_items), error); auto average_byte_width = row[3].ParseDouble(); if (!average_byte_width) { SetError(error, "[libpq] Invalid double value in avg_width: '%s'", row[3].data); return ADBC_STATUS_INTERNAL; } CHECK_NA(INTERNAL, ArrowArrayAppendString(statistics_table_name_col, ArrowStringView{row[0].data, row[0].len}), error); CHECK_NA(INTERNAL, ArrowArrayAppendString(statistics_column_name_col, ArrowStringView{row[1].data, row[1].len}), error); CHECK_NA( INTERNAL, ArrowArrayAppendInt(statistics_key_col, ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_KEY), error); CHECK_NA(INTERNAL, ArrowArrayAppendDouble(value_float64_col, *average_byte_width), error); CHECK_NA(INTERNAL, ArrowArrayFinishUnionElement(statistics_value_col, kStatsVariantFloat64), error); CHECK_NA(INTERNAL, ArrowArrayAppendInt(statistics_is_approximate_col, 1), error); CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schema_statistics_items), error); auto n_distinct = row[4].ParseDouble(); if (!n_distinct) { SetError(error, "[libpq] Invalid double value in avg_width: '%s'", row[4].data); return ADBC_STATUS_INTERNAL; } CHECK_NA(INTERNAL, ArrowArrayAppendString(statistics_table_name_col, ArrowStringView{row[0].data, row[0].len}), error); CHECK_NA(INTERNAL, ArrowArrayAppendString(statistics_column_name_col, ArrowStringView{row[1].data, row[1].len}), error); CHECK_NA(INTERNAL, ArrowArrayAppendInt(statistics_key_col, ADBC_STATISTIC_DISTINCT_COUNT_KEY), error); // > If greater than zero, the estimated number of distinct values in // > the column. If less than zero, the negative of the number of // > distinct values divided by the number of rows. // https://www.postgresql.org/docs/current/view-pg-stats.html CHECK_NA(INTERNAL, ArrowArrayAppendDouble( value_float64_col, *n_distinct > 0 ? *n_distinct : (std::fabs(*n_distinct) * *reltuples)), error); CHECK_NA(INTERNAL, ArrowArrayFinishUnionElement(statistics_value_col, kStatsVariantFloat64), error); CHECK_NA(INTERNAL, ArrowArrayAppendInt(statistics_is_approximate_col, 1), error); CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schema_statistics_items), error); } } CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schema_statistics_col), error); CHECK_NA(INTERNAL, ArrowArrayFinishElement(catalog_db_schemas_items), error); CHECK_NA(INTERNAL, ArrowArrayFinishElement(catalog_db_schemas_col), error); CHECK_NA(INTERNAL, ArrowArrayFinishElement(array), error); CHECK_NA_DETAIL(INTERNAL, ArrowArrayFinishBuildingDefault(array, &na_error), &na_error, error); uschema.move(schema); return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnection::GetStatistics(const char* catalog, const char* db_schema, const char* table_name, bool approximate, struct ArrowArrayStream* out, struct AdbcError* error) { // Simplify our jobs here if (!approximate) { SetError(error, "[libpq] Exact statistics are not implemented"); return ADBC_STATUS_NOT_IMPLEMENTED; } else if (!db_schema) { SetError(error, "[libpq] Must request statistics for a single schema"); return ADBC_STATUS_NOT_IMPLEMENTED; } else if (catalog && std::strcmp(catalog, PQdb(conn_)) != 0) { SetError(error, "[libpq] Can only request statistics for current catalog"); return ADBC_STATUS_NOT_IMPLEMENTED; } struct ArrowSchema schema; std::memset(&schema, 0, sizeof(schema)); struct ArrowArray array; std::memset(&array, 0, sizeof(array)); AdbcStatusCode status = PostgresConnectionGetStatisticsImpl( conn_, db_schema, table_name, &schema, &array, error); if (status != ADBC_STATUS_OK) { if (schema.release) schema.release(&schema); if (array.release) array.release(&array); return status; } adbc::driver::MakeArrayStream(&schema, &array, out); return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnectionGetStatisticNamesImpl(struct ArrowSchema* schema, struct ArrowArray* array, struct AdbcError* error) { auto uschema = nanoarrow::UniqueSchema(); ArrowSchemaInit(uschema.get()); CHECK_NA(INTERNAL, ArrowSchemaSetType(uschema.get(), NANOARROW_TYPE_STRUCT), error); CHECK_NA(INTERNAL, ArrowSchemaAllocateChildren(uschema.get(), /*num_columns=*/2), error); ArrowSchemaInit(uschema.get()->children[0]); CHECK_NA(INTERNAL, ArrowSchemaSetType(uschema.get()->children[0], NANOARROW_TYPE_STRING), error); CHECK_NA(INTERNAL, ArrowSchemaSetName(uschema.get()->children[0], "statistic_name"), error); uschema.get()->children[0]->flags &= ~ARROW_FLAG_NULLABLE; ArrowSchemaInit(uschema.get()->children[1]); CHECK_NA(INTERNAL, ArrowSchemaSetType(uschema.get()->children[1], NANOARROW_TYPE_INT16), error); CHECK_NA(INTERNAL, ArrowSchemaSetName(uschema.get()->children[1], "statistic_key"), error); uschema.get()->children[1]->flags &= ~ARROW_FLAG_NULLABLE; CHECK_NA(INTERNAL, ArrowArrayInitFromSchema(array, uschema.get(), NULL), error); CHECK_NA(INTERNAL, ArrowArrayStartAppending(array), error); CHECK_NA(INTERNAL, ArrowArrayFinishBuildingDefault(array, NULL), error); uschema.move(schema); return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnection::GetStatisticNames(struct ArrowArrayStream* out, struct AdbcError* error) { // We don't support any extended statistics, just return an empty stream struct ArrowSchema schema; std::memset(&schema, 0, sizeof(schema)); struct ArrowArray array; std::memset(&array, 0, sizeof(array)); AdbcStatusCode status = PostgresConnectionGetStatisticNamesImpl(&schema, &array, error); if (status != ADBC_STATUS_OK) { if (schema.release) schema.release(&schema); if (array.release) array.release(&array); return status; } adbc::driver::MakeArrayStream(&schema, &array, out); return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog, const char* db_schema, const char* table_name, struct ArrowSchema* schema, struct AdbcError* error) { AdbcStatusCode final_status = ADBC_STATUS_OK; char* quoted = PQescapeIdentifier(conn_, table_name, strlen(table_name)); std::string table_name_str(quoted); PQfreemem(quoted); if (db_schema != nullptr) { quoted = PQescapeIdentifier(conn_, db_schema, strlen(db_schema)); table_name_str = std::string(quoted) + "." + table_name_str; PQfreemem(quoted); } std::string query = "SELECT attname, atttypid " "FROM pg_catalog.pg_class AS cls " "INNER JOIN pg_catalog.pg_attribute AS attr ON cls.oid = attr.attrelid " "INNER JOIN pg_catalog.pg_type AS typ ON attr.atttypid = typ.oid " "WHERE attr.attnum >= 0 AND cls.oid = $1::regclass::oid " "ORDER BY attr.attnum"; std::vector<std::string> params = {table_name_str}; PqResultHelper result_helper = PqResultHelper{conn_, std::string(query.c_str())}; RAISE_STATUS(error, result_helper.Execute(params)); auto uschema = nanoarrow::UniqueSchema(); ArrowSchemaInit(uschema.get()); CHECK_NA(INTERNAL, ArrowSchemaSetTypeStruct(uschema.get(), result_helper.NumRows()), error); int row_counter = 0; for (auto row : result_helper) { const char* colname = row[0].data; const Oid pg_oid = static_cast<uint32_t>(std::strtol(row[1].data, /*str_end=*/nullptr, /*base=*/10)); PostgresType pg_type; if (type_resolver_->FindWithDefault(pg_oid, &pg_type) != NANOARROW_OK) { SetError(error, "%s%d%s%s%s%" PRIu32, "Error resolving type code for column #", row_counter + 1, " (\"", colname, "\") with oid ", pg_oid); final_status = ADBC_STATUS_NOT_IMPLEMENTED; break; } CHECK_NA(INTERNAL, pg_type.WithFieldName(colname).SetSchema(uschema->children[row_counter], std::string(VendorName())), error); row_counter++; } uschema.move(schema); return final_status; } AdbcStatusCode PostgresConnection::GetTableTypes(struct AdbcConnection* connection, struct ArrowArrayStream* out, struct AdbcError* error) { std::vector<std::string> table_types; table_types.reserve(kPgTableTypes.size()); for (auto const& table_type : kPgTableTypes) { table_types.push_back(table_type.first); } RAISE_STATUS(error, adbc::driver::MakeTableTypesStream(table_types, out)); return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnection::Init(struct AdbcDatabase* database, struct AdbcError* error) { if (!database || !database->private_data) { SetError(error, "[libpq] Must provide an initialized AdbcDatabase"); return ADBC_STATUS_INVALID_ARGUMENT; } database_ = *reinterpret_cast<std::shared_ptr<PostgresDatabase>*>(database->private_data); type_resolver_ = database_->type_resolver(); RAISE_ADBC(database_->Connect(&conn_, error)); cancel_ = PQgetCancel(conn_); if (!cancel_) { SetError(error, "[libpq] Could not initialize PGcancel"); return ADBC_STATUS_UNKNOWN; } std::ignore = PQsetNoticeProcessor(conn_, SilentNoticeProcessor, nullptr); for (const auto& [key, value] : post_init_options_) { RAISE_ADBC(SetOption(key.data(), value.data(), error)); } post_init_options_.clear(); return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnection::Release(struct AdbcError* error) { if (cancel_) { PQfreeCancel(cancel_); cancel_ = nullptr; } if (conn_) { return database_->Disconnect(&conn_, error); } return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnection::Rollback(struct AdbcError* error) { if (autocommit_) { SetError(error, "%s", "[libpq] Cannot rollback when autocommit is enabled"); return ADBC_STATUS_INVALID_STATE; } PGTransactionStatusType txn_status = PQtransactionStatus(conn_); if (txn_status == PQTRANS_IDLE) { // https://github.com/apache/arrow-adbc/issues/2673: don't rollback if the // transaction is idle, since it won't have any effect and PostgreSQL will // issue a warning on the server side return ADBC_STATUS_OK; } PGresult* result = PQexec(conn_, "ROLLBACK AND CHAIN"); if (PQresultStatus(result) != PGRES_COMMAND_OK) { SetError(error, "%s%s", "[libpq] Failed to rollback: ", PQerrorMessage(conn_)); PQclear(result); return ADBC_STATUS_IO; } PQclear(result); return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnection::SetOption(const char* key, const char* value, struct AdbcError* error) { if (std::strcmp(key, ADBC_CONNECTION_OPTION_AUTOCOMMIT) == 0) { bool autocommit = true; if (std::strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) { autocommit = true; } else if (std::strcmp(value, ADBC_OPTION_VALUE_DISABLED) == 0) { autocommit = false; } else { SetError(error, "%s%s%s%s", "[libpq] Invalid value for option ", key, ": ", value); return ADBC_STATUS_INVALID_ARGUMENT; } if (!conn_) { post_init_options_.emplace_back(key, value); return ADBC_STATUS_OK; } if (autocommit != autocommit_) { const char* query = autocommit ? "COMMIT" : "BEGIN TRANSACTION"; PGresult* result = PQexec(conn_, query); if (PQresultStatus(result) != PGRES_COMMAND_OK) { SetError(error, "%s%s", "[libpq] Failed to update autocommit: ", PQerrorMessage(conn_)); PQclear(result); return ADBC_STATUS_IO; } PQclear(result); autocommit_ = autocommit; } return ADBC_STATUS_OK; } else if (std::strcmp(key, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA) == 0) { if (!conn_) { post_init_options_.emplace_back(key, value); return ADBC_STATUS_OK; } // PostgreSQL doesn't accept a parameter here char* value_esc = PQescapeIdentifier(conn_, value, strlen(value)); if (!value_esc) { SetError(error, "[libpq] Could not escape identifier: %s", PQerrorMessage(conn_)); return ADBC_STATUS_INTERNAL; } std::string query = fmt::format("SET search_path TO {}", value_esc); PQfreemem(value_esc); PqResultHelper result_helper{conn_, query}; RAISE_STATUS(error, result_helper.Execute()); return ADBC_STATUS_OK; } SetError(error, "%s%s", "[libpq] Unknown option ", key); return ADBC_STATUS_NOT_IMPLEMENTED; } AdbcStatusCode PostgresConnection::SetOptionBytes(const char* key, const uint8_t* value, size_t length, struct AdbcError* error) { SetError(error, "%s%s", "[libpq] Unknown option ", key); return ADBC_STATUS_NOT_IMPLEMENTED; } AdbcStatusCode PostgresConnection::SetOptionDouble(const char* key, double value, struct AdbcError* error) { SetError(error, "%s%s", "[libpq] Unknown option ", key); return ADBC_STATUS_NOT_IMPLEMENTED; } AdbcStatusCode PostgresConnection::SetOptionInt(const char* key, int64_t value, struct AdbcError* error) { SetError(error, "%s%s", "[libpq] Unknown option ", key); return ADBC_STATUS_NOT_IMPLEMENTED; } std::string_view PostgresConnection::VendorName() { return database_->VendorName(); } const std::array<int, 3>& PostgresConnection::VendorVersion() { return database_->VendorVersion(); } } // namespace adbcpq