go/adbc/sqldriver/driver.go (576 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package sqldriver
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"io"
"reflect"
"strconv"
"strings"
"time"
"unsafe"
"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/array"
"github.com/apache/arrow-go/v18/arrow/decimal128"
"github.com/apache/arrow-go/v18/arrow/decimal256"
"github.com/apache/arrow-go/v18/arrow/memory"
)
// getIsolationlevel translates a database/sql isolation level into the
// corresponding ADBC isolation level option value.
//
// Levels without an explicit mapping yield the empty string, which callers
// in this file treat as "not supported".
func getIsolationlevel(lvl sql.IsolationLevel) adbc.OptionIsolationLevel {
	// The zero value ("") is returned when no case below matches.
	var mapped adbc.OptionIsolationLevel

	switch lvl {
	case sql.LevelDefault:
		mapped = adbc.LevelDefault
	case sql.LevelReadUncommitted:
		mapped = adbc.LevelReadUncommitted
	case sql.LevelReadCommitted:
		mapped = adbc.LevelReadCommitted
	case sql.LevelRepeatableRead:
		mapped = adbc.LevelRepeatableRead
	case sql.LevelSnapshot:
		mapped = adbc.LevelSnapshot
	case sql.LevelSerializable:
		mapped = adbc.LevelSerializable
	case sql.LevelLinearizable:
		mapped = adbc.LevelLinearizable
	}

	return mapped
}
// parseConnectStr turns a connection string of the form
// "key=value;key2=value2" into an option map.
//
// The string is split on ';' and each segment is split on its first '=';
// surrounding whitespace is trimmed from both key and value. Any segment
// that contains no '=' (including the empty segment produced by an empty
// string or a trailing ';') is rejected with an adbc.Error carrying
// StatusInvalidArgument.
func parseConnectStr(str string) (ret map[string]string, err error) {
	opts := make(map[string]string)

	for _, segment := range strings.Split(str, ";") {
		key, value, found := strings.Cut(segment, "=")
		if !found {
			return nil, &adbc.Error{
				Msg:  "invalid format for connection string",
				Code: adbc.StatusInvalidArgument,
			}
		}
		opts[strings.TrimSpace(key)] = strings.TrimSpace(value)
	}

	return opts, nil
}
// connector pairs an ADBC database with the ADBC driver it came from.
// Its Connect method opens connections on db and its Driver method
// reports drv wrapped in a Driver value.
//
// NOTE: OpenConnector builds this struct with a positional literal, so the
// field order must not change.
type connector struct {
// db is the database on which Connect opens new connections.
db adbc.Database
// drv is the ADBC driver returned (wrapped) by the Driver method.
drv adbc.Driver
}
// Connect opens a new ADBC connection on the connector's database and
// returns it wrapped as a driver.Conn.
//
// No connection caching is done here; the sql package maintains its own
// pool of idle connections for re-use.
//
// The provided context.Context is handed to the database's Open call and
// is not stored on the returned connection. The returned connection is
// only used by one goroutine at a time.
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
	opened, err := c.db.Open(ctx)
	if err != nil {
		return nil, err
	}

	wrapped := &conn{
		Conn: opened,
		drv:  c.db,
	}
	return wrapped, nil
}
// Driver wraps the connector's ADBC driver in a Driver value and returns
// it, mainly to maintain compatibility with the Driver method on sql.DB.
func (c *connector) Driver() driver.Driver {
	return Driver{Driver: c.drv}
}
// Driver adapts an ADBC driver to the database/sql driver interfaces
// through its Open and OpenConnector methods.
type Driver struct {
// Driver is the wrapped ADBC driver; OpenConnector calls its NewDatabase
// method with the parsed connection options.
Driver adbc.Driver
}
// Open creates a connector for name and immediately connects with it
// using context.Background().
//
// The name should be semi-colon separated key-value pairs of the form:
// key=value;key2=value2;.....
//
// No connection caching is done here; the sql package maintains a pool
// of idle connections for efficient re-use.
//
// The returned connection is only used by one goroutine at a time.
func (d Driver) Open(name string) (driver.Conn, error) {
	cnctr, err := d.OpenConnector(name)
	if err != nil {
		return nil, err
	}

	background := context.Background()
	return cnctr.Connect(background)
}
// OpenConnector parses name (same format as Open expects), creates an
// ADBC database from the resulting options, and returns a connector bound
// to that database and to the wrapped ADBC driver.
//
// A malformed name or a failure of NewDatabase is returned unchanged.
func (d Driver) OpenConnector(name string) (driver.Connector, error) {
	options, err := parseConnectStr(name)
	if err != nil {
		return nil, err
	}

	database, err := d.Driver.NewDatabase(options)
	if err != nil {
		return nil, err
	}

	return &connector{db: database, drv: d.Driver}, nil
}
// ctxOptsKey is the private context key under which an option map is
// stored by SetOptionsInCtx and looked up by GetOptionsFromCtx.
type ctxOptsKey struct{}

// SetOptionsInCtx returns a context derived from ctx that carries opts.
// The map is stored as-is (not copied).
func SetOptionsInCtx(ctx context.Context, opts map[string]string) context.Context {
	var key ctxOptsKey
	return context.WithValue(ctx, key, opts)
}

// GetOptionsFromCtx returns the option map previously attached with
// SetOptionsInCtx, or nil when ctx carries no such map.
func GetOptionsFromCtx(ctx context.Context) map[string]string {
	if opts, found := ctx.Value(ctxOptsKey{}).(map[string]string); found {
		return opts
	}
	return nil
}
// conn is a connection to a database. It is not used concurrently by
// multiple goroutines. It is assumed to be stateful.
type conn struct {
// Conn is the underlying ADBC connection that every method delegates to.
Conn adbc.Connection
// drv holds the ADBC database this connection was opened from (set by
// connector.Connect). Despite the name it is a Database, not a Driver.
drv adbc.Database
}
// Close closes the underlying ADBC connection and returns its error.
//
// Close invalidates and potentially stops any current prepared
// statements and transactions, marking this connection as no longer
// in use.
//
// Because the sql package maintains a free pool of connections and
// only calls Close when there's a surplus of idle connections,
// it shouldn't be necessary for drivers to do their own connection
// caching.
//
// Drivers must ensure all network calls made by Close do not block
// indefinitely (e.g. apply a timeout).
// NOTE(review): no context or timeout is passed to the ADBC Close call
// here, so any such bound has to come from the ADBC connection itself.
func (c *conn) Close() error {
return c.Conn.Close()
}
// Query runs query with positional arguments and returns the resulting
// rows. It converts values into driver.NamedValue entries and delegates
// to QueryContext with context.Background().
//
// Ordinals are assigned starting from one, as database/sql/driver defines
// for NamedValue.Ordinal. The rest of this file relies on that (parameter
// binding indexes with Ordinal-1), so a zero-based ordinal would index out
// of range for the first argument.
func (c *conn) Query(query string, values []driver.Value) (driver.Rows, error) {
	namedValues := make([]driver.NamedValue, len(values))
	for i, value := range values {
		namedValues[i] = driver.NamedValue{
			// nb: Name field is optional
			Ordinal: i + 1,
			Value:   value,
		}
	}
	return c.QueryContext(context.Background(), query, namedValues)
}
// QueryContext creates a fresh ADBC statement for query, binds args (if
// any) and executes it, returning the resulting rows.
//
// If setting the query text or executing the statement fails, the
// statement is closed before returning and any Close error is joined with
// the original error.
//
// NOTE(review): on success the statement stays open. The returned rows
// keep a reference to it, but rows.Close only drops that reference and
// does not close the statement.
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
	s, err := c.Conn.NewStatement()
	if err != nil {
		return nil, err
	}

	if err = s.SetSqlQuery(query); err != nil {
		return nil, errors.Join(err, s.Close())
	}

	result, err := (&stmt{stmt: s}).QueryContext(ctx, args)
	if err != nil {
		// Nobody else holds this statement, so release it here rather
		// than leaking it on the failure path.
		return nil, errors.Join(err, s.Close())
	}
	return result, nil
}
// Begin exists only to fulfill the driver.Conn interface and always
// returns an adbc.Error with code StatusNotImplemented. Transactions are
// started through BeginTx (the driver.ConnBeginTx interface) instead.
//
// Deprecated: use BeginTx.
func (c *conn) Begin() (driver.Tx, error) {
return nil, &adbc.Error{Code: adbc.StatusNotImplemented}
}
// BeginTx starts a transaction by disabling autocommit on the underlying
// ADBC connection and applying the requested isolation level and, when
// opts.ReadOnly is set, the read-only option.
//
// It returns an adbc.Error with StatusNotImplemented when the connection
// does not implement adbc.PostInitOptions or when opts.Isolation has no
// ADBC equivalent. Both checks happen before any option is changed, so an
// unsupported request leaves the connection untouched.
//
// If an option fails to apply after autocommit was disabled, autocommit
// is re-enabled before returning so the connection is not left in manual
// commit mode; any error from that restore is joined with the cause.
//
// The returned tx keeps ctx and uses it for Commit and Rollback.
func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
	postopt, ok := c.Conn.(adbc.PostInitOptions)
	if !ok {
		return nil, &adbc.Error{Code: adbc.StatusNotImplemented}
	}

	// Validate the isolation level first: this is a pure lookup and must
	// not leave autocommit disabled when the level is unsupported.
	isolationLevel := getIsolationlevel(sql.IsolationLevel(opts.Isolation))
	if isolationLevel == "" {
		return nil, &adbc.Error{Code: adbc.StatusNotImplemented}
	}

	if err := postopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueDisabled); err != nil {
		return nil, err
	}

	// restoreAutoCommit undoes the autocommit change after a later failure.
	restoreAutoCommit := func(cause error) error {
		return errors.Join(cause, postopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueEnabled))
	}

	if err := postopt.SetOption(adbc.OptionKeyIsolationLevel, string(isolationLevel)); err != nil {
		return nil, restoreAutoCommit(err)
	}

	if opts.ReadOnly {
		if err := postopt.SetOption(adbc.OptionKeyReadOnly, adbc.OptionValueEnabled); err != nil {
			return nil, restoreAutoCommit(err)
		}
	}

	return tx{ctx: ctx, conn: c.Conn}, nil
}
// Prepare returns a prepared statement, bound to this connection. It
// delegates to PrepareContext with context.Background().
func (c *conn) Prepare(query string) (driver.Stmt, error) {
return c.PrepareContext(context.Background(), query)
}
// PrepareContext returns a prepared statement, bound to this connection.
// Context is for the preparation of the statement. The statement must not
// store the context within the statement itself.
//
// After preparing, the statement's parameter schema is requested. An
// adbc.Error whose code is StatusNotImplemented is tolerated and leaves
// the schema unset (NumInput then reports -1). Any other adbc.Error is
// returned to the caller.
//
// On every failure path the ADBC statement is closed before returning,
// and a Close error is joined with the original error.
//
// NOTE(review): errors from GetParameterSchema that errors.As cannot
// match against an adbc.Error value are ignored here, as before.
func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
	s, err := c.Conn.NewStatement()
	if err != nil {
		return nil, err
	}

	if err := s.SetSqlQuery(query); err != nil {
		return nil, errors.Join(err, s.Close())
	}

	if err := s.Prepare(ctx); err != nil {
		return nil, errors.Join(err, s.Close())
	}

	paramSchema, err := s.GetParameterSchema()
	var adbcErr adbc.Error
	if errors.As(err, &adbcErr) && adbcErr.Code != adbc.StatusNotImplemented {
		// The statement is never handed to the caller on this path, so
		// it has to be closed here to avoid leaking it.
		return nil, errors.Join(err, s.Close())
	}

	return &stmt{stmt: s, paramSchema: paramSchema}, nil
}
// tx is the transaction handle returned by conn.BeginTx. Commit and
// Rollback act on conn and then switch autocommit back on.
type tx struct {
// ctx is the context given to BeginTx; it is passed to the ADBC
// Commit/Rollback calls.
ctx context.Context
// conn is the ADBC connection the transaction runs on. Commit and
// Rollback assert it to adbc.PostInitOptions without checking.
conn adbc.Connection
}
// Commit commits the transaction on the underlying connection using the
// context captured at BeginTx, then re-enables autocommit.
//
// If the commit itself fails, that error is returned and autocommit is
// left as it was. Otherwise the result of re-enabling autocommit is
// returned.
func (t tx) Commit() error {
	err := t.conn.Commit(t.ctx)
	if err != nil {
		return err
	}

	// Unchecked assertion: a connection that does not implement
	// adbc.PostInitOptions makes this panic.
	options := t.conn.(adbc.PostInitOptions)
	return options.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueEnabled)
}
// Rollback rolls the transaction back on the underlying connection using
// the context captured at BeginTx, then re-enables autocommit.
//
// If the rollback itself fails, that error is returned and autocommit is
// left as it was. Otherwise the result of re-enabling autocommit is
// returned.
func (t tx) Rollback() error {
	err := t.conn.Rollback(t.ctx)
	if err != nil {
		return err
	}

	// Unchecked assertion: a connection that does not implement
	// adbc.PostInitOptions makes this panic.
	options := t.conn.(adbc.PostInitOptions)
	return options.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueEnabled)
}
// stmt wraps an ADBC statement together with its (optional) parameter
// schema.
type stmt struct {
// stmt is the underlying ADBC statement that is bound and executed.
stmt adbc.Statement
// paramSchema describes the statement's parameters. It may be nil, in
// which case NumInput returns -1 and CheckNamedValue returns
// driver.ErrSkip.
paramSchema *arrow.Schema
}
// Close closes the underlying ADBC statement and returns its error.
func (s *stmt) Close() error {
return s.stmt.Close()
}
// NumInput reports the number of parameters the statement expects, taken
// from the parameter schema. When no parameter schema is known it returns
// -1.
func (s *stmt) NumInput() int {
	if schema := s.paramSchema; schema != nil {
		return schema.NumFields()
	}
	return -1
}
// Exec is not implemented; it ignores args and always returns
// driver.ErrSkip. Execution is provided by ExecContext.
func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
return nil, driver.ErrSkip
}
// Query is not implemented; it ignores args and always returns
// driver.ErrSkip. Querying is provided by QueryContext.
func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
return nil, driver.ErrSkip
}
// checkType reports whether the dynamic type of val is T or *T.
// An untyped nil never matches.
func checkType[T any](val any) bool {
	if _, isValue := val.(T); isValue {
		return true
	}
	_, isPointer := val.(*T)
	return isPointer
}
// isCorrectParamType reports whether val has the Go type (or a pointer to
// the Go type) that this package associates with the Arrow type typ.
//
// Arrow types without a case below are not validated and are accepted
// unconditionally.
func isCorrectParamType(typ arrow.Type, val driver.Value) bool {
	// Default for Arrow types that have no check yet.
	// TODO: add more types here
	ok := true

	switch typ {
	case arrow.BINARY:
		ok = checkType[[]byte](val)
	case arrow.BOOL:
		ok = checkType[bool](val)
	case arrow.INT8:
		ok = checkType[int8](val)
	case arrow.UINT8:
		ok = checkType[uint8](val)
	case arrow.INT16:
		ok = checkType[int16](val)
	case arrow.UINT16:
		ok = checkType[uint16](val)
	case arrow.INT32:
		ok = checkType[int32](val)
	case arrow.UINT32:
		ok = checkType[uint32](val)
	case arrow.INT64:
		ok = checkType[int64](val)
	case arrow.UINT64:
		ok = checkType[uint64](val)
	case arrow.STRING:
		ok = checkType[string](val)
	case arrow.FLOAT32:
		ok = checkType[float32](val)
	case arrow.FLOAT64:
		ok = checkType[float64](val)
	case arrow.DATE32:
		ok = checkType[arrow.Date32](val)
	case arrow.DATE64:
		ok = checkType[arrow.Date64](val)
	case arrow.TIME32:
		ok = checkType[arrow.Time32](val)
	case arrow.TIME64:
		ok = checkType[arrow.Time64](val)
	case arrow.TIMESTAMP:
		ok = checkType[arrow.Timestamp](val)
	case arrow.DECIMAL128:
		ok = checkType[decimal128.Num](val)
	case arrow.DECIMAL256:
		ok = checkType[decimal256.Num](val)
	}

	return ok
}
// CheckNamedValue validates val against the statement's parameter schema.
//
// When no parameter schema is known it returns driver.ErrSkip. Otherwise
// the parameter is located by name (when val.Name is set) or by its
// 1-based ordinal, and its Go type is checked against the schema field's
// Arrow type. Fields of Arrow type NULL accept any value.
//
// An adbc.Error with StatusInvalidArgument is returned when the named
// parameter does not exist, when the ordinal is outside 1..number of
// parameters, or when the value has the wrong type.
func (s *stmt) CheckNamedValue(val *driver.NamedValue) error {
	if s.paramSchema == nil {
		// we don't know the parameter schema, so we can't validate
		// the arguments.
		return driver.ErrSkip
	}

	var field arrow.Field
	if val.Name != "" {
		fields, exists := s.paramSchema.FieldsByName(val.Name)
		if !exists {
			return &adbc.Error{
				Msg:  "could not find parameter named '" + val.Name + "'",
				Code: adbc.StatusInvalidArgument,
			}
		}
		field = fields[0]
	} else {
		// val.Ordinal is 1-based; reject anything below 1 explicitly so a
		// malformed value cannot index before the start of the schema.
		if val.Ordinal < 1 {
			return &adbc.Error{
				Msg:  "parameter ordinal must be at least 1",
				Code: adbc.StatusInvalidArgument,
			}
		}
		if val.Ordinal > s.paramSchema.NumFields() {
			return &adbc.Error{
				Msg:  "too many parameters passed for query",
				Code: adbc.StatusInvalidArgument,
			}
		}
		field = s.paramSchema.Field(val.Ordinal - 1)
	}

	if field.Type.ID() == arrow.NULL {
		return nil
	}

	if !isCorrectParamType(field.Type.ID(), val.Value) {
		return &adbc.Error{
			Code: adbc.StatusInvalidArgument,
			Msg:  "expected parameter of type " + field.Type.String(),
		}
	}
	return nil
}
// arrFromVal builds a length-1 Arrow array holding val and returns it.
// The array's data type is chosen from val's Go type.
//
// Supported Go types are bool, the sized signed/unsigned integers,
// float32/float64, arrow.Date32, arrow.Date64, []byte and string. Any
// other type (including nil) panics. Note that isCorrectParamType also
// accepts pointer forms as well as time, timestamp and decimal values;
// those are not handled here and reach the panic.
//
// The value buffers are created with memory.NewBufferBytes over the bytes
// of the local value (or over the caller's []byte / string data), i.e.
// without an explicit copy in this function.
// NOTE(review): the fixed-width cases reinterpret the value's in-memory
// bytes directly, so the layout follows host byte order; confirm this is
// what downstream drivers expect.
func arrFromVal(val any) arrow.Array {
var (
// buffers[0] is left nil; presumably the validity-bitmap slot, unused
// because the array is created with zero nulls (see NewData below).
buffers = make([]*memory.Buffer, 2)
dt arrow.DataType
)
switch v := val.(type) {
case bool:
dt = arrow.FixedWidthTypes.Boolean
buffers[1] = memory.NewBufferBytes((*[1]byte)(unsafe.Pointer(&v))[:])
case int8:
dt = arrow.PrimitiveTypes.Int8
buffers[1] = memory.NewBufferBytes((*[1]byte)(unsafe.Pointer(&v))[:])
case uint8:
dt = arrow.PrimitiveTypes.Uint8
buffers[1] = memory.NewBufferBytes((*[1]byte)(unsafe.Pointer(&v))[:])
case int16:
dt = arrow.PrimitiveTypes.Int16
buffers[1] = memory.NewBufferBytes((*[2]byte)(unsafe.Pointer(&v))[:])
case uint16:
dt = arrow.PrimitiveTypes.Uint16
buffers[1] = memory.NewBufferBytes((*[2]byte)(unsafe.Pointer(&v))[:])
case int32:
dt = arrow.PrimitiveTypes.Int32
buffers[1] = memory.NewBufferBytes((*[4]byte)(unsafe.Pointer(&v))[:])
case uint32:
dt = arrow.PrimitiveTypes.Uint32
buffers[1] = memory.NewBufferBytes((*[4]byte)(unsafe.Pointer(&v))[:])
case int64:
dt = arrow.PrimitiveTypes.Int64
buffers[1] = memory.NewBufferBytes((*[8]byte)(unsafe.Pointer(&v))[:])
case uint64:
dt = arrow.PrimitiveTypes.Uint64
buffers[1] = memory.NewBufferBytes((*[8]byte)(unsafe.Pointer(&v))[:])
case float32:
dt = arrow.PrimitiveTypes.Float32
buffers[1] = memory.NewBufferBytes((*[4]byte)(unsafe.Pointer(&v))[:])
case float64:
dt = arrow.PrimitiveTypes.Float64
buffers[1] = memory.NewBufferBytes((*[8]byte)(unsafe.Pointer(&v))[:])
case arrow.Date32:
dt = arrow.PrimitiveTypes.Date32
buffers[1] = memory.NewBufferBytes((*[4]byte)(unsafe.Pointer(&v))[:])
case arrow.Date64:
dt = arrow.PrimitiveTypes.Date64
buffers[1] = memory.NewBufferBytes((*[8]byte)(unsafe.Pointer(&v))[:])
case []byte:
dt = arrow.BinaryTypes.Binary
// Variable-length layout: buffers[1] holds the two int32 offsets
// {0, len(v)} and a third buffer holds the bytes themselves.
// NOTE(review): len(v) is narrowed to int32 without a range check.
buffers[1] = memory.NewBufferBytes(arrow.Int32Traits.CastToBytes([]int32{0, int32(len(v))}))
buffers = append(buffers, memory.NewBufferBytes(v))
case string:
dt = arrow.BinaryTypes.String
// Same offsets/data layout as the []byte case.
buffers[1] = memory.NewBufferBytes(arrow.Int32Traits.CastToBytes([]int32{0, int32(len(v))}))
// View the string's bytes as a slice without copying them.
buf := unsafe.Slice(unsafe.StringData(v), len(v))
buffers = append(buffers, memory.NewBufferBytes(buf))
default:
panic(fmt.Sprintf("unsupported type %T", val))
}
// These deferred releases run only when the function returns, i.e. after
// the array below has been constructed from the buffers.
for _, b := range buffers {
if b != nil {
defer b.Release()
}
}
// Length 1, no child data, zero nulls, offset 0.
data := array.NewData(dt, 1, buffers, nil, 0, 0)
defer data.Release()
return array.MakeFromData(data)
}
// createBoundRecord packs the query arguments into a single-row Arrow
// record suitable for binding to a statement.
//
// Every argument becomes one length-1 column built by arrFromVal, and the
// column's field type is taken from that array. Column position and field
// name are chosen as follows:
//   - named argument with a schema: position is the first index of that
//     name in schema, name is the argument name;
//   - named argument without a schema: position is Ordinal-1, name is the
//     argument name;
//   - positional argument: position is Ordinal-1; the field name is the
//     decimal Ordinal when schema is nil and the decimal zero-based
//     position when a schema is present.
//
// NOTE(review): the positional naming differs by one between the two
// cases; confirm whether that is intentional before relying on the names.
//
// The function panics if a position falls outside len(values), if a named
// argument is missing from schema, or if arrFromVal rejects a value.
func createBoundRecord(values []driver.NamedValue, schema *arrow.Schema) arrow.Record {
	count := len(values)
	fields, cols := make([]arrow.Field, count), make([]arrow.Array, count)

	for _, param := range values {
		pos, label := param.Ordinal-1, param.Name
		if label == "" {
			if schema == nil {
				label = strconv.Itoa(param.Ordinal)
			} else {
				label = strconv.Itoa(pos)
			}
		} else if schema != nil {
			pos = schema.FieldIndices(label)[0]
		}

		target := &fields[pos]
		target.Name = label

		column := arrFromVal(param.Value)
		// Released when this function returns, after the record below has
		// been built from the columns.
		defer column.Release()

		target.Type = column.DataType()
		cols[pos] = column
	}

	boundSchema := arrow.NewSchema(fields, nil)
	return array.NewRecord(boundSchema, cols, 1)
}
// ExecContext binds args (when there are any) to the statement as a
// single-row record, runs it with ExecuteUpdate and reports the affected
// row count as a driver.RowsAffected result.
//
// Errors from Bind or ExecuteUpdate are returned unchanged.
func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
	if len(args) != 0 {
		params := createBoundRecord(args, s.paramSchema)
		if bindErr := s.stmt.Bind(ctx, params); bindErr != nil {
			return nil, bindErr
		}
	}

	count, execErr := s.stmt.ExecuteUpdate(ctx)
	if execErr != nil {
		return nil, execErr
	}

	return driver.RowsAffected(count), nil
}
// QueryContext binds args (when there are any) to the statement as a
// single-row record, runs it with ExecuteQuery and wraps the resulting
// record reader in a rows value that also remembers the reported
// affected-row count and this statement.
//
// Errors from Bind or ExecuteQuery are returned unchanged.
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	if len(args) != 0 {
		params := createBoundRecord(args, s.paramSchema)
		if bindErr := s.stmt.Bind(ctx, params); bindErr != nil {
			return nil, bindErr
		}
	}

	reader, count, execErr := s.stmt.ExecuteQuery(ctx)
	if execErr != nil {
		return nil, execErr
	}

	result := &rows{
		rdr:          reader,
		rowsAffected: count,
		stmt:         s,
	}
	return result, nil
}
// rows iterates over the records produced by a record reader, handing
// them out one row at a time through Next.
type rows struct {
// rdr is the record reader the rows are pulled from; Close releases it.
rdr array.RecordReader
// curRow is the index within curRecord of the next row Next will return.
curRow int64
// curRecord is the record currently being iterated, or nil when a new
// one must be fetched from rdr.
curRecord arrow.Record
// rowsAffected is the count reported by ExecuteQuery when the rows were
// created.
rowsAffected int64
// stmt is the statement that produced these rows. Close only clears
// this reference; it does not close the statement.
stmt *stmt
}
// Columns returns the names of the result columns, in schema order, as
// reported by the record reader's schema. The returned slice is never
// nil.
func (r *rows) Columns() (out []string) {
	schemaFields := r.rdr.Schema().Fields()

	names := make([]string, 0, len(schemaFields))
	for _, field := range schemaFields {
		names = append(names, field.Name)
	}
	return names
}
// Close releases the record reader and drops the references to the
// current record and the originating statement. It always returns nil.
//
// Close is idempotent: calling it again after the reader has been
// released is a no-op rather than a nil dereference.
//
// The statement itself is not closed here; only the reference to it is
// cleared.
func (r *rows) Close() error {
	r.curRecord = nil
	r.stmt = nil

	if r.rdr != nil {
		r.rdr.Release()
		r.rdr = nil
	}
	return nil
}
// Next fills dest with the values of the next row and advances the
// cursor.
//
// Records are pulled from the reader on demand; empty records are
// skipped. When the reader is exhausted, Next returns the reader's error
// if it has one and io.EOF otherwise.
//
// Null cells are stored as nil. Date, time and timestamp columns are
// converted to time.Time; all other supported columns are stored as the
// Go value returned by the Arrow array's Value method. Union columns are
// resolved to the child array selected for the row: dense unions are read
// at the row's value offset, sparse unions at the row index itself, and a
// null in the selected child is reported as nil.
//
// A column whose array type is not handled yields an adbc.Error with
// StatusNotImplemented. In that case dest may be partially filled and the
// cursor is not advanced.
func (r *rows) Next(dest []driver.Value) error {
	// Drop a fully consumed record so the loop below fetches the next one.
	if r.curRecord != nil && r.curRow == r.curRecord.NumRows() {
		r.curRecord = nil
	}

	for r.curRecord == nil {
		if !r.rdr.Next() {
			if err := r.rdr.Err(); err != nil {
				return err
			}
			return io.EOF
		}
		r.curRecord = r.rdr.Record()
		r.curRow = 0
		if r.curRecord.NumRows() == 0 {
			r.curRecord = nil
		}
	}

	for i, col := range r.curRecord.Columns() {
		// row is the index used to read from col; it is re-pointed below
		// when col is replaced by a dense union's child array.
		row := int(r.curRow)
		if col.IsNull(row) {
			dest[i] = nil
			continue
		}

		if colUnion, ok := col.(array.Union); ok {
			child := colUnion.Field(colUnion.ChildID(row))
			// Dense unions store each child compactly, so the slot inside
			// the child is given by the value offset, not the parent row.
			if dense, isDense := colUnion.(*array.DenseUnion); isDense {
				row = int(dense.ValueOffset(row))
			}
			col = child
			// The null check above ran against the union itself, so the
			// selected child value has to be checked separately.
			if col.IsNull(row) {
				dest[i] = nil
				continue
			}
		}

		switch col := col.(type) {
		case *array.Boolean:
			dest[i] = col.Value(row)
		case *array.Int8:
			dest[i] = col.Value(row)
		case *array.Uint8:
			dest[i] = col.Value(row)
		case *array.Int16:
			dest[i] = col.Value(row)
		case *array.Uint16:
			dest[i] = col.Value(row)
		case *array.Int32:
			dest[i] = col.Value(row)
		case *array.Uint32:
			dest[i] = col.Value(row)
		case *array.Int64:
			dest[i] = col.Value(row)
		case *array.Uint64:
			dest[i] = col.Value(row)
		case *array.Float32:
			dest[i] = col.Value(row)
		case *array.Float64:
			dest[i] = col.Value(row)
		case *array.String:
			dest[i] = col.Value(row)
		case *array.LargeString:
			dest[i] = col.Value(row)
		case *array.Binary:
			dest[i] = col.Value(row)
		case *array.LargeBinary:
			dest[i] = col.Value(row)
		case *array.Date32:
			dest[i] = col.Value(row).ToTime()
		case *array.Date64:
			dest[i] = col.Value(row).ToTime()
		case *array.Time32:
			dest[i] = col.Value(row).ToTime(col.DataType().(*arrow.Time32Type).Unit)
		case *array.Time64:
			dest[i] = col.Value(row).ToTime(col.DataType().(*arrow.Time64Type).Unit)
		case *array.Timestamp:
			dest[i] = col.Value(row).ToTime(col.DataType().(*arrow.TimestampType).Unit)
		case *array.Decimal128:
			dest[i] = col.Value(row)
		case *array.Decimal256:
			dest[i] = col.Value(row)
		default:
			return &adbc.Error{
				Code: adbc.StatusNotImplemented,
				Msg:  "not yet implemented populating from columns of type " + col.DataType().String(),
			}
		}
	}

	r.curRow++
	return nil
}
// ColumnTypeDatabaseTypeName returns the string form of the Arrow data
// type of the column at index, as given by the reader's schema.
func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
return r.rdr.Schema().Field(index).Type.String()
}
// ColumnTypeNullable reports the Nullable flag of the schema field at
// index. ok is always true.
func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
return r.rdr.Schema().Field(index).Nullable, true
}
// ColumnTypePrecisionScale returns the precision and scale of the column
// at index when its Arrow type is Decimal128 or Decimal256. For every
// other type it returns 0, 0, false.
func (r *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
	field := r.rdr.Schema().Field(index)

	if dec, isDec128 := field.Type.(*arrow.Decimal128Type); isDec128 {
		return int64(dec.Precision), int64(dec.Scale), true
	}
	if dec, isDec256 := field.Type.(*arrow.Decimal256Type); isDec256 {
		return int64(dec.Precision), int64(dec.Scale), true
	}
	return 0, 0, false
}
// ColumnTypeScanType returns the Go type that Next produces for the
// column at index, derived from the column's Arrow type.
//
// Large string and large binary columns map to string and []byte, the
// same Go types Next produces for them. Arrow types without a specific
// mapping report the empty interface type instead of a nil reflect.Type,
// so callers can use the result (e.g. with reflect.New) without a nil
// check. This is also what database/sql reports when a driver does not
// supply a scan type.
func (r *rows) ColumnTypeScanType(index int) reflect.Type {
	switch r.rdr.Schema().Field(index).Type.ID() {
	case arrow.BOOL:
		return reflect.TypeOf(false)
	case arrow.INT8:
		return reflect.TypeOf(int8(0))
	case arrow.UINT8:
		return reflect.TypeOf(uint8(0))
	case arrow.INT16:
		return reflect.TypeOf(int16(0))
	case arrow.UINT16:
		return reflect.TypeOf(uint16(0))
	case arrow.INT32:
		return reflect.TypeOf(int32(0))
	case arrow.UINT32:
		return reflect.TypeOf(uint32(0))
	case arrow.INT64:
		return reflect.TypeOf(int64(0))
	case arrow.UINT64:
		return reflect.TypeOf(uint64(0))
	case arrow.FLOAT32:
		return reflect.TypeOf(float32(0))
	case arrow.FLOAT64:
		return reflect.TypeOf(float64(0))
	case arrow.DECIMAL128:
		return reflect.TypeOf(decimal128.Num{})
	case arrow.DECIMAL256:
		return reflect.TypeOf(decimal256.Num{})
	case arrow.BINARY, arrow.LARGE_BINARY:
		return reflect.TypeOf([]byte{})
	case arrow.STRING, arrow.LARGE_STRING:
		return reflect.TypeOf("")
	case arrow.TIME32, arrow.TIME64, arrow.DATE32, arrow.DATE64, arrow.TIMESTAMP:
		return reflect.TypeOf(time.Time{})
	}

	// Unknown column type: report interface{}.
	return reflect.TypeOf((*any)(nil)).Elem()
}