AdbcStatusCode Execute()

in c/driver/postgresql/statement.cc [317:509]


  // Execute the (unnamed) prepared statement once for every row of every batch
  // read from the bound parameter stream.
  //
  // For each row, every column value is serialized into the per-column slot of
  // param_values_buffer in PostgreSQL binary wire format (network byte order),
  // or the parameter is set to NULL.  STRING/BINARY values are passed by
  // pointer directly into the Arrow buffer instead of being copied.
  //
  // \param[in] conn The libpq connection to execute on.
  // \param[out] rows_affected Optional.  Reset to 0 on entry and incremented by
  //   the batch length after each batch has been executed completely.
  // \param[out] error Receives a message on failure.
  // \return ADBC_STATUS_OK once the stream is exhausted; ADBC_STATUS_IO if
  //   reading the stream or running a statement fails;
  //   ADBC_STATUS_INVALID_ARGUMENT if a date/timestamp value cannot be
  //   represented relative to the PostgreSQL epoch;
  //   ADBC_STATUS_NOT_IMPLEMENTED for an unsupported column type.  CHECK_NA
  //   presumably returns an INTERNAL status if a nanoarrow call fails.
  AdbcStatusCode Execute(PGconn* conn, int64_t* rows_affected, struct AdbcError* error) {
    if (rows_affected) *rows_affected = 0;

    while (true) {
      Handle<struct ArrowArray> array;
      int res = bind->get_next(&bind.value, &array.value);
      if (res != 0) {
        // The Arrow C stream interface allows get_last_error() to return NULL
        // when no description is available; never hand NULL to a %s conversion.
        const char* last_error = bind->get_last_error(&bind.value);
        SetError(error,
                 "[libpq] Failed to read next batch from stream of bind parameters: "
                 "(%d) %s %s",
                 res, std::strerror(res), last_error ? last_error : "");
        return ADBC_STATUS_IO;
      }
      // A released (empty) array marks the end of the stream.
      if (!array->release) break;

      Handle<struct ArrowArrayView> array_view;
      // TODO: include error messages
      CHECK_NA(
          INTERNAL,
          ArrowArrayViewInitFromSchema(&array_view.value, &bind_schema.value, nullptr),
          error);
      CHECK_NA(INTERNAL, ArrowArrayViewSetArray(&array_view.value, &array.value, nullptr),
               error);

      for (int64_t row = 0; row < array->length; row++) {
        for (int64_t col = 0; col < array_view->n_children; col++) {
          if (ArrowArrayViewIsNull(array_view->children[col], row)) {
            // A null parameter pointer is how libpq represents SQL NULL.
            param_values[col] = nullptr;
            continue;
          } else {
            param_values[col] = param_values_buffer.data() + param_values_offsets[col];
          }
          switch (bind_schema_fields[col].type) {
            case ArrowType::NANOARROW_TYPE_INT8: {
              // Widened to 16 bits before being written out.
              const int16_t val =
                  array_view->children[col]->buffer_views[1].data.as_int8[row];
              const uint16_t value = ToNetworkInt16(val);
              std::memcpy(param_values[col], &value, sizeof(int16_t));
              break;
            }
            case ArrowType::NANOARROW_TYPE_INT16: {
              const uint16_t value = ToNetworkInt16(
                  array_view->children[col]->buffer_views[1].data.as_int16[row]);
              std::memcpy(param_values[col], &value, sizeof(int16_t));
              break;
            }
            case ArrowType::NANOARROW_TYPE_INT32: {
              const uint32_t value = ToNetworkInt32(
                  array_view->children[col]->buffer_views[1].data.as_int32[row]);
              std::memcpy(param_values[col], &value, sizeof(int32_t));
              break;
            }
            case ArrowType::NANOARROW_TYPE_INT64: {
              const int64_t value = ToNetworkInt64(
                  array_view->children[col]->buffer_views[1].data.as_int64[row]);
              std::memcpy(param_values[col], &value, sizeof(int64_t));
              break;
            }
            case ArrowType::NANOARROW_TYPE_FLOAT: {
              const uint32_t value = ToNetworkFloat4(
                  array_view->children[col]->buffer_views[1].data.as_float[row]);
              std::memcpy(param_values[col], &value, sizeof(uint32_t));
              break;
            }
            case ArrowType::NANOARROW_TYPE_DOUBLE: {
              const uint64_t value = ToNetworkFloat8(
                  array_view->children[col]->buffer_views[1].data.as_double[row]);
              std::memcpy(param_values[col], &value, sizeof(uint64_t));
              break;
            }
            case ArrowType::NANOARROW_TYPE_STRING:
            case ArrowType::NANOARROW_TYPE_BINARY: {
              // Variable-length data is not copied: point the parameter at the
              // Arrow buffer and record the per-row length.
              const ArrowBufferView view =
                  ArrowArrayViewGetBytesUnsafe(array_view->children[col], row);
              // TODO: overflow check?
              // NOTE(review): STRING/BINARY presumably use 32-bit offsets, so the
              // size should fit in an int; revisit if LARGE_* types are added.
              param_lengths[col] = static_cast<int>(view.size_bytes);
              param_values[col] = const_cast<char*>(view.data.as_char);
              break;
            }
            case ArrowType::NANOARROW_TYPE_DATE32: {
              // 2000-01-01
              constexpr int32_t kPostgresDateEpoch = 10957;
              const int32_t raw_value =
                  array_view->children[col]->buffer_views[1].data.as_int32[row];
              // Re-basing onto the PostgreSQL epoch must not underflow int32.
              if (raw_value < INT32_MIN + kPostgresDateEpoch) {
                SetError(error,
                         "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64
                         " has value which exceeds postgres date limits",
                         col + 1, bind_schema->children[col]->name, row + 1);
                return ADBC_STATUS_INVALID_ARGUMENT;
              }

              const uint32_t value = ToNetworkInt32(raw_value - kPostgresDateEpoch);
              std::memcpy(param_values[col], &value, sizeof(int32_t));
              break;
            }
            case ArrowType::NANOARROW_TYPE_TIMESTAMP: {
              int64_t val = array_view->children[col]->buffer_views[1].data.as_int64[row];

              // 2000-01-01 00:00:00.000000 in microseconds
              constexpr int64_t kPostgresTimestampEpoch = 946684800000000;
              psnip_safe_bool overflow_safe = true;

              auto unit = bind_schema_fields[col].time_unit;

              // Normalize the value to microseconds.
              switch (unit) {
                case NANOARROW_TIME_UNIT_SECOND:
                  overflow_safe = psnip_safe_int64_mul(&val, val, 1000000);
                  break;
                case NANOARROW_TIME_UNIT_MILLI:
                  overflow_safe = psnip_safe_int64_mul(&val, val, 1000);
                  break;
                case NANOARROW_TIME_UNIT_MICRO:
                  break;
                case NANOARROW_TIME_UNIT_NANO:
                  val /= 1000;
                  break;
              }

              // Re-basing onto the PostgreSQL epoch subtracts a positive
              // constant; guard against signed underflow (undefined behavior),
              // mirroring the DATE32 check above.
              if (overflow_safe && val < INT64_MIN + kPostgresTimestampEpoch) {
                overflow_safe = false;
              }

              if (!overflow_safe) {
                SetError(error,
                         "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64
                         " has value which exceeds postgres timestamp limits",
                         col + 1, bind_schema->children[col]->name, row + 1);

                return ADBC_STATUS_INVALID_ARGUMENT;
              }

              const uint64_t value = ToNetworkInt64(val - kPostgresTimestampEpoch);
              std::memcpy(param_values[col], &value, sizeof(int64_t));
              break;
            }
            case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: {
              struct ArrowInterval interval;
              ArrowIntervalInit(&interval, NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO);
              ArrowArrayViewGetIntervalUnsafe(array_view->children[col], row, &interval);

              const uint32_t months = ToNetworkInt32(interval.months);
              const uint32_t days = ToNetworkInt32(interval.days);
              // Nanoseconds are truncated to microseconds.
              const uint64_t usec = ToNetworkInt64(interval.ns / 1000);

              // Written as: 8 bytes of microseconds, 4 bytes of days, 4 bytes of
              // months.
              std::memcpy(param_values[col], &usec, sizeof(uint64_t));
              std::memcpy(param_values[col] + sizeof(uint64_t), &days, sizeof(uint32_t));
              std::memcpy(param_values[col] + sizeof(uint64_t) + sizeof(uint32_t),
                          &months, sizeof(uint32_t));
              break;
            }
            default:
              SetError(error, "%s%" PRId64 "%s%s%s%s", "[libpq] Field #", col + 1, " ('",
                       bind_schema->children[col]->name,
                       "') has unsupported type for ingestion ",
                       ArrowTypeString(bind_schema_fields[col].type));
              return ADBC_STATUS_NOT_IMPLEMENTED;
          }
        }

        // Run the unnamed prepared statement with this row's parameters.
        PGresult* result =
            PQexecPrepared(conn, /*stmtName=*/"",
                           /*nParams=*/bind_schema->n_children, param_values.data(),
                           param_lengths.data(), param_formats.data(),
                           /*resultFormat=*/0 /*text*/);

        if (PQresultStatus(result) != PGRES_COMMAND_OK) {
          SetError(error, "%s%s", "[libpq] Failed to execute prepared statement: ",
                   PQerrorMessage(conn));
          PQclear(result);
          return ADBC_STATUS_IO;
        }

        PQclear(result);
      }
      if (rows_affected) *rows_affected += array->length;

      // NOTE(review): this runs after every batch, not once after the stream is
      // drained, and is skipped on the early-return error paths above.  The
      // matching BEGIN / SET TIME ZONE is not visible here -- confirm that
      // multi-batch streams with timezone-aware fields behave as intended.
      if (has_tz_field) {
        std::string reset_query = "SET TIME ZONE '" + tz_setting + "'";
        PGresult* reset_tz_result = PQexec(conn, reset_query.c_str());
        if (PQresultStatus(reset_tz_result) != PGRES_COMMAND_OK) {
          SetError(error, "[libpq] Failed to reset time zone: %s", PQerrorMessage(conn));
          PQclear(reset_tz_result);
          return ADBC_STATUS_IO;
        }
        PQclear(reset_tz_result);

        PGresult* commit_result = PQexec(conn, "COMMIT");
        if (PQresultStatus(commit_result) != PGRES_COMMAND_OK) {
          SetError(error, "[libpq] Failed to commit transaction: %s",
                   PQerrorMessage(conn));
          PQclear(commit_result);
          return ADBC_STATUS_IO;
        }
        PQclear(commit_result);
      }
    }
    return ADBC_STATUS_OK;
  }