// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package snowflake

import (
"context"
"database/sql/driver"
"fmt"
"strconv"
"strings"
"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow/go/v13/arrow"
"github.com/apache/arrow/go/v13/arrow/array"
"github.com/apache/arrow/go/v13/arrow/memory"
"github.com/snowflakedb/gosnowflake"
"golang.org/x/exp/constraints"
)
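
// OptionStatementQueueSize controls the size of the record queue used when
// streaming Arrow result batches from Snowflake, e.g.
// stmt.SetOption(OptionStatementQueueSize, "10"). The value must parse as
// an integer (see SetOption).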
const (
OptionStatementQueueSize = "adbc.rpc.result_queue_size"
)
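
// statement implements adbc.Statement for the Snowflake driver. At most one
// of bound or streamBind is set at a time; a non-empty targetTable switches
// the Execute methods into bulk-ingestion mode.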
type statement struct {
cnxn *cnxn
alloc memory.Allocator
queueSize int
query string
targetTable string
append bool
bound arrow.Record
streamBind array.RecordReader
}
// Close releases any relevant resources associated with this statement
// and closes it (particularly if it is a prepared statement).
//
// A statement instance should not be used after Close is called.
func (st *statement) Close() error {
if st.cnxn == nil {
return adbc.Error{
Msg: "statement already closed",
Code: adbc.StatusInvalidState}
}
if st.bound != nil {
st.bound.Release()
st.bound = nil
} else if st.streamBind != nil {
st.streamBind.Release()
st.streamBind = nil
}
st.cnxn = nil
return nil
}
// SetOption sets a string option on this statement
func (st *statement) SetOption(key string, val string) error {
switch key {
case adbc.OptionKeyIngestTargetTable:
st.query = ""
st.targetTable = val
case adbc.OptionKeyIngestMode:
switch val {
case adbc.OptionValueIngestModeAppend:
st.append = true
case adbc.OptionValueIngestModeCreate:
st.append = false
default:
return adbc.Error{
Msg: fmt.Sprintf("invalid statement option %s=%s", key, val),
Code: adbc.StatusInvalidArgument,
}
}
case OptionStatementQueueSize:
sz, err := strconv.Atoi(val)
if err != nil {
return adbc.Error{
Msg: fmt.Sprintf("could not parse '%s' as int for option '%s'", val, key),
Code: adbc.StatusInvalidArgument,
}
}
st.queueSize = sz
default:
return adbc.Error{
Msg: fmt.Sprintf("invalid statement option %s=%s", key, val),
Code: adbc.StatusInvalidArgument,
}
}
return nil
}
// SetSqlQuery sets the query string to be executed.
//
// The query can then be executed with any of the Execute methods.
// For queries expected to be executed repeatedly, Prepare should be
// called before execution.
func (st *statement) SetSqlQuery(query string) error {
st.query = query
st.targetTable = ""
return nil
}
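
// toSnowflakeType maps an Arrow data type to the equivalent Snowflake column
// type for use in CREATE TABLE statements during bulk ingestion. Wrapper
// types (extension, dictionary, run-end encoded) are unwrapped recursively;
// an empty string is returned for types with no mapping.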
func toSnowflakeType(dt arrow.DataType) string {
switch dt.ID() {
case arrow.EXTENSION:
return toSnowflakeType(dt.(arrow.ExtensionType).StorageType())
case arrow.DICTIONARY:
return toSnowflakeType(dt.(*arrow.DictionaryType).ValueType)
case arrow.RUN_END_ENCODED:
return toSnowflakeType(dt.(*arrow.RunEndEncodedType).Encoded())
case arrow.INT8, arrow.INT16, arrow.INT32, arrow.INT64,
arrow.UINT8, arrow.UINT16, arrow.UINT32, arrow.UINT64:
return "integer"
case arrow.FLOAT32, arrow.FLOAT16, arrow.FLOAT64:
return "double"
case arrow.DECIMAL, arrow.DECIMAL256:
dec := dt.(arrow.DecimalType)
return fmt.Sprintf("NUMERIC(%d,%d)", dec.GetPrecision(), dec.GetScale())
case arrow.STRING, arrow.LARGE_STRING:
return "text"
case arrow.BINARY, arrow.LARGE_BINARY:
return "binary"
case arrow.FIXED_SIZE_BINARY:
fsb := dt.(*arrow.FixedSizeBinaryType)
return fmt.Sprintf("binary(%d)", fsb.ByteWidth)
case arrow.BOOL:
return "boolean"
case arrow.TIME32, arrow.TIME64:
t := dt.(arrow.TemporalWithUnit)
prec := int(t.TimeUnit()) * 3
return fmt.Sprintf("time(%d)", prec)
case arrow.DATE32, arrow.DATE64:
return "date"
case arrow.TIMESTAMP:
ts := dt.(*arrow.TimestampType)
prec := int(ts.Unit) * 3
if ts.TimeZone == "" {
return fmt.Sprintf("timestamp_ntz(%d)", prec)
}
return fmt.Sprintf("timestamp_tz(%d)", prec)
case arrow.DENSE_UNION, arrow.SPARSE_UNION:
return "variant"
case arrow.LIST, arrow.LARGE_LIST, arrow.FIXED_SIZE_LIST:
return "array"
case arrow.STRUCT, arrow.MAP:
return "object"
}
return ""
}
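
// initIngest builds a parameterized INSERT statement from the schema of the
// bound record or stream and, unless the ingest mode is append, first
// creates the target table with columns derived via toSnowflakeType.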
func (st *statement) initIngest(ctx context.Context) (string, error) {
var (
createBldr, insertBldr strings.Builder
)
createBldr.WriteString("CREATE TABLE ")
createBldr.WriteString(st.targetTable)
createBldr.WriteString(" (")
insertBldr.WriteString("INSERT INTO ")
insertBldr.WriteString(st.targetTable)
insertBldr.WriteString(" VALUES (")
var schema *arrow.Schema
if st.bound != nil {
schema = st.bound.Schema()
} else {
schema = st.streamBind.Schema()
}
for i, f := range schema.Fields() {
if i != 0 {
insertBldr.WriteString(", ")
createBldr.WriteString(", ")
}
createBldr.WriteString(strconv.Quote(f.Name))
createBldr.WriteString(" ")
ty := toSnowflakeType(f.Type)
if ty == "" {
return "", adbc.Error{
Msg: fmt.Sprintf("unimplemented type conversion for field %s, arrow type: %s", f.Name, f.Type),
Code: adbc.StatusNotImplemented,
}
}
createBldr.WriteString(ty)
if !f.Nullable {
createBldr.WriteString(" NOT NULL")
}
insertBldr.WriteString("?")
}
createBldr.WriteString(")")
insertBldr.WriteString(")")
if !st.append {
// create the table!
createQuery := createBldr.String()
_, err := st.cnxn.cn.ExecContext(ctx, createQuery, nil)
if err != nil {
return "", errToAdbcErr(adbc.StatusInternal, err)
}
}
return insertBldr.String(), nil
}
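
// nativeArrowArr is the subset of Arrow arrays whose Value method returns a
// type the Snowflake driver can bind directly (string or []byte).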
type nativeArrowArr[T string | []byte] interface {
arrow.Array
Value(int) T
}
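
// convToArr converts a string or binary Arrow array into a bind parameter:
// a single scalar for length-1 arrays, otherwise a gosnowflake.Array with
// null slots left as nil.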
func convToArr[T string | []byte](arr nativeArrowArr[T]) interface{} {
if arr.Len() == 1 {
if arr.IsNull(0) {
return nil
}
return arr.Value(0)
}
v := make([]interface{}, arr.Len())
for i := 0; i < arr.Len(); i++ {
if arr.IsNull(i) {
continue
}
v[i] = arr.Value(i)
}
return gosnowflake.Array(&v)
}
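
// convMarshal is the fallback conversion for types the driver cannot bind
// natively: each element is stringified via ValueStr and passed to
// Snowflake as text.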
func convMarshal(arr arrow.Array) interface{} {
if arr.Len() == 1 {
if arr.IsNull(0) {
return nil
}
return arr.ValueStr(0)
}
v := make([]interface{}, arr.Len())
for i := 0; i < arr.Len(); i++ {
if arr.IsNull(i) {
continue
}
v[i] = arr.ValueStr(i)
}
return gosnowflake.Array(&v)
}
// convToSlice converts a numeric Arrow array into a bind parameter. The
// Snowflake driver bindings only support a specific set of Go types
// (int/int32/int64, float32/float64, bool, string, []byte, time.Time), so
// other numeric widths must be cast to an accepted type O.
func convToSlice[T, O constraints.Integer | constraints.Float](arr arrow.Array, vals []T) interface{} {
if arr.Len() == 1 {
if arr.IsNull(0) {
return nil
}
return vals[0]
}
out := make([]interface{}, arr.Len())
for i, v := range vals {
if arr.IsNull(i) {
continue
}
out[i] = O(v)
}
return gosnowflake.Array(&out)
}
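
// getQueryArg converts one Arrow column into a value suitable for binding
// as a query argument, widening numeric types to the sizes the Snowflake
// driver accepts and falling back to string marshalling otherwise.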
func getQueryArg(arr arrow.Array) interface{} {
switch arr := arr.(type) {
case *array.Int8:
v := arr.Int8Values()
return convToSlice[int8, int32](arr, v)
case *array.Uint8:
v := arr.Uint8Values()
return convToSlice[uint8, int32](arr, v)
case *array.Int16:
v := arr.Int16Values()
return convToSlice[int16, int32](arr, v)
case *array.Uint16:
v := arr.Uint16Values()
return convToSlice[uint16, int32](arr, v)
case *array.Int32:
v := arr.Int32Values()
return convToSlice[int32, int32](arr, v)
case *array.Uint32:
v := arr.Uint32Values()
return convToSlice[uint32, int64](arr, v)
case *array.Int64:
v := arr.Int64Values()
return convToSlice[int64, int64](arr, v)
case *array.Uint64:
v := arr.Uint64Values()
return convToSlice[uint64, int64](arr, v)
case *array.Float32:
v := arr.Float32Values()
return convToSlice[float32, float64](arr, v)
case *array.Float64:
v := arr.Float64Values()
return convToSlice[float64, float64](arr, v)
case *array.LargeBinary:
return convToArr[[]byte](arr)
case *array.Binary:
return convToArr[[]byte](arr)
case *array.LargeString:
return convToArr[string](arr)
case *array.String:
return convToArr[string](arr)
default:
// default convert to array of strings and pass to snowflake driver
// not the most efficient, but snowflake doesn't really give a better
// route currently short of writing everything out to a Parquet file
// and then uploading that (which might be preferable)
return convMarshal(arr)
}
}
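
// executeIngest bulk-inserts the bound record or record stream into the
// target table, one INSERT per record batch, and returns the total number
// of rows affected. The bound data is released when ingestion finishes.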
func (st *statement) executeIngest(ctx context.Context) (int64, error) {
if st.streamBind == nil && st.bound == nil {
return -1, adbc.Error{
Msg: "must call Bind before bulk ingestion",
Code: adbc.StatusInvalidState,
}
}
insertQuery, err := st.initIngest(ctx)
if err != nil {
return -1, err
}
// if the ingestion is large enough it might make more sense to
// write this out to a temporary file / stage / etc. and use
// the snowflake bulk loader that way.
//
// on the other hand, according to the documentation,
// https://pkg.go.dev/github.com/snowflakedb/gosnowflake#hdr-Batch_Inserts_and_Binding_Parameters
// the snowflake driver should already be doing this internally.
var n int64
exec := func(rec arrow.Record, args []driver.NamedValue) error {
for i, c := range rec.Columns() {
args[i].Ordinal = i + 1 // driver.NamedValue ordinals are 1-based
args[i].Value = getQueryArg(c)
}
r, err := st.cnxn.cn.ExecContext(ctx, insertQuery, args)
if err != nil {
return errToAdbcErr(adbc.StatusInternal, err)
}
rows, err := r.RowsAffected()
if err == nil {
n += rows
}
return nil
}
if st.bound != nil {
defer func() {
st.bound.Release()
st.bound = nil
}()
args := make([]driver.NamedValue, len(st.bound.Schema().Fields()))
return n, exec(st.bound, args)
}
defer func() {
st.streamBind.Release()
st.streamBind = nil
}()
args := make([]driver.NamedValue, len(st.streamBind.Schema().Fields()))
for st.streamBind.Next() {
rec := st.streamBind.Record()
if err := exec(rec, args); err != nil {
return n, err
}
}
return n, nil
}
// ExecuteQuery executes the current query or prepared statement
// and returns a RecordReader for the results along with the number
// of rows affected, if known (otherwise -1).
//
// This invalidates any prior result sets on this statement.
func (st *statement) ExecuteQuery(ctx context.Context) (array.RecordReader, int64, error) {
if st.targetTable != "" {
n, err := st.executeIngest(ctx)
return nil, n, err
}
if st.query == "" {
return nil, -1, adbc.Error{
Msg: "cannot execute without a query",
Code: adbc.StatusInvalidState,
}
}
// for a bound stream reader we'd need to implement something to
// concatenate RecordReaders which doesn't exist yet. let's put
// that off for now.
if st.streamBind != nil || st.bound != nil {
return nil, -1, adbc.Error{
Msg: "executing non-bulk ingest with bound params not yet implemented",
Code: adbc.StatusNotImplemented,
}
}
loader, err := st.cnxn.cn.QueryArrowStream(ctx, st.query)
if err != nil {
return nil, -1, errToAdbcErr(adbc.StatusInternal, err)
}
rdr, err := newRecordReader(ctx, st.alloc, loader, st.queueSize)
nrec := loader.TotalRows()
return rdr, nrec, err
}
// ExecuteUpdate executes a statement that does not generate a result
// set. It returns the number of rows affected if known, otherwise -1.
func (st *statement) ExecuteUpdate(ctx context.Context) (int64, error) {
if st.targetTable != "" {
return st.executeIngest(ctx)
}
if st.query == "" {
return -1, adbc.Error{
Msg: "cannot execute without a query",
Code: adbc.StatusInvalidState,
}
}
r, err := st.cnxn.cn.ExecContext(ctx, st.query, nil)
if err != nil {
return -1, errToAdbcErr(adbc.StatusIO, err)
}
n, err := r.RowsAffected()
if err != nil {
n = -1
}
return n, nil
}
// Prepare turns this statement into a prepared statement to be executed
// multiple times. This invalidates any prior result sets.
func (st *statement) Prepare(_ context.Context) error {
if st.query == "" {
return adbc.Error{
Code: adbc.StatusInvalidState,
Msg: "cannot prepare statement with no query",
}
}
// Snowflake does not provide a "Prepare" API; this is a no-op.
return nil
}
// SetSubstraitPlan allows setting a serialized Substrait execution
// plan into the query or for querying Substrait-related metadata.
//
// Drivers are not required to support both SQL and Substrait semantics.
// If they do, it may be via converting between representations internally.
//
// Like SetSqlQuery, after this is called the query can be executed
// using any of the Execute methods. If the query is expected to be
// executed repeatedly, Prepare should be called first on the statement.
func (st *statement) SetSubstraitPlan(plan []byte) error {
return adbc.Error{
Msg: "Snowflake does not support Substrait plans",
Code: adbc.StatusNotImplemented,
}
}
// Bind uses an arrow record batch to bind parameters to the query.
//
// This can be used for bulk inserts or for prepared statements.
// The driver will call Release on the passed-in Record when it is done,
// but it may not do so until the statement is closed or another
// record is bound.
func (st *statement) Bind(_ context.Context, values arrow.Record) error {
if st.streamBind != nil {
st.streamBind.Release()
st.streamBind = nil
} else if st.bound != nil {
st.bound.Release()
st.bound = nil
}
st.bound = values
if st.bound != nil {
st.bound.Retain()
}
return nil
}
// BindStream uses a record batch stream to bind parameters for this
// query. This can be used for bulk inserts or prepared statements.
//
// The driver will call Release on the record reader, but may not do this
// until Close is called.
func (st *statement) BindStream(_ context.Context, stream array.RecordReader) error {
if st.streamBind != nil {
st.streamBind.Release()
st.streamBind = nil
} else if st.bound != nil {
st.bound.Release()
st.bound = nil
}
st.streamBind = stream
if st.streamBind != nil {
st.streamBind.Retain()
}
return nil
}
// GetParameterSchema returns an Arrow schema representation of
// the expected parameters to be bound.
//
// This retrieves an Arrow Schema describing the number, names, and
// types of the parameters in a parameterized statement. The fields
// of the schema should be in order of the ordinal position of the
// parameters; named parameters should appear only once.
//
// If the parameter does not have a name, or a name cannot be determined,
// the name of the corresponding field in the schema will be an empty
// string. If the type cannot be determined, the type of the corresponding
// field will be NA (NullType).
//
// This should be called only after calling Prepare.
//
// This should return an error with StatusNotImplemented if the schema
// cannot be determined.
func (st *statement) GetParameterSchema() (*arrow.Schema, error) {
// snowflake's API does not provide any way to determine the schema
return nil, adbc.Error{
Msg:  "GetParameterSchema not supported: Snowflake does not expose parameter schemas",
Code: adbc.StatusNotImplemented,
}
}
// ExecutePartitions executes the current statement and gets the results
// as a partitioned result set.
//
// It returns the Schema of the result set, the collection of partition
// descriptors and the number of rows affected, if known. If unknown,
// the number of rows affected will be -1.
//
// If the driver does not support partitioned results, this will return
// an error with a StatusNotImplemented code.
func (st *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.Partitions, int64, error) {
if st.query == "" {
return nil, adbc.Partitions{}, -1, adbc.Error{
Msg: "cannot execute without a query",
Code: adbc.StatusInvalidState,
}
}
// Snowflake's partitioned results, at least as returned by the driver, are
// not currently portable enough to satisfy the requirements of this
// function.
return nil, adbc.Partitions{}, -1, adbc.Error{
Msg: "ExecutePartitions not implemented for Snowflake",
Code: adbc.StatusNotImplemented,
}
}