// go/adbc/driver/snowflake/statement.go

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package snowflake

import (
	"context"
	"database/sql/driver"
	"fmt"
	"io"
	"strconv"
	"strings"

	"github.com/apache/arrow-adbc/go/adbc"
	"github.com/apache/arrow-go/v18/arrow"
	"github.com/apache/arrow-go/v18/arrow/array"
	"github.com/apache/arrow-go/v18/arrow/memory"
	"github.com/snowflakedb/gosnowflake"
)

const (
	OptionStatementQueryTag                = "adbc.snowflake.statement.query_tag"
	OptionStatementQueueSize               = "adbc.rpc.result_queue_size"
	OptionStatementPrefetchConcurrency     = "adbc.snowflake.rpc.prefetch_concurrency"
	OptionStatementIngestWriterConcurrency = "adbc.snowflake.statement.ingest_writer_concurrency"
	OptionStatementIngestUploadConcurrency = "adbc.snowflake.statement.ingest_upload_concurrency"
	OptionStatementIngestCopyConcurrency   = "adbc.snowflake.statement.ingest_copy_concurrency"
	OptionStatementIngestTargetFileSize    = "adbc.snowflake.statement.ingest_target_file_size"
	OptionStatementIngestCompressionCodec  = "adbc.snowflake.statement.ingest_compression_codec" // TODO(GH-1473): Implement option
	OptionStatementIngestCompressionLevel  = "adbc.snowflake.statement.ingest_compression_level" // TODO(GH-1473): Implement option
)

type statement struct {
	cnxn                *connectionImpl
	alloc               memory.Allocator
	queueSize           int
	prefetchConcurrency int
	useHighPrecision    bool

	query         string
	targetTable   string
	ingestMode    string
	ingestOptions *ingestOptions
	queryTag      string

	bound      arrow.Record
	streamBind array.RecordReader
}

// setQueryContext applies the query tag to the context if one is set.
func (st *statement) setQueryContext(ctx context.Context) context.Context {
	if st.queryTag != "" {
		ctx = gosnowflake.WithQueryTag(ctx, st.queryTag)
	}
	return ctx
}
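// Illustrative usage (a caller-side sketch, not exercised by this package):
// tagging statements so they can be located in Snowflake's QUERY_HISTORY;
// the tag value here is an example only:
//
//	// stmt is an adbc.Statement produced by this driver
//	if err := stmt.SetOption(OptionStatementQueryTag, "nightly-etl"); err != nil {
//		// handle the error
//	}
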
// Close releases any relevant resources associated with this statement
// and closes it (particularly if it is a prepared statement).
//
// A statement instance should not be used after Close is called.
func (st *statement) Close() error {
	if st.cnxn == nil {
		return adbc.Error{
			Msg:  "statement already closed",
			Code: adbc.StatusInvalidState,
		}
	}

	if st.bound != nil {
		st.bound.Release()
		st.bound = nil
	} else if st.streamBind != nil {
		st.streamBind.Release()
		st.streamBind = nil
	}
	st.cnxn = nil
	return nil
}

func (st *statement) GetOption(key string) (string, error) {
	switch key {
	case OptionStatementQueryTag:
		return st.queryTag, nil
	}
	return "", adbc.Error{
		Msg:  fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key),
		Code: adbc.StatusNotFound,
	}
}

func (st *statement) GetOptionBytes(key string) ([]byte, error) {
	return nil, adbc.Error{
		Msg:  fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key),
		Code: adbc.StatusNotFound,
	}
}

func (st *statement) GetOptionInt(key string) (int64, error) {
	switch key {
	case OptionStatementQueueSize:
		return int64(st.queueSize), nil
	}
	return 0, adbc.Error{
		Msg:  fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key),
		Code: adbc.StatusNotFound,
	}
}

func (st *statement) GetOptionDouble(key string) (float64, error) {
	return 0, adbc.Error{
		Msg:  fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key),
		Code: adbc.StatusNotFound,
	}
}
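// Illustrative usage (a caller-side sketch): configuring a statement for
// bulk ingestion before executing it; the table name and mode below are
// assumptions for the example only:
//
//	_ = stmt.SetOption(adbc.OptionKeyIngestTargetTable, "MY_TABLE")
//	_ = stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeCreateAppend)
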
// SetOption sets a string option on this statement.
func (st *statement) SetOption(key string, val string) error {
	switch key {
	case adbc.OptionKeyIngestTargetTable:
		st.query = ""
		st.targetTable = val
	case adbc.OptionKeyIngestMode:
		switch val {
		case adbc.OptionValueIngestModeAppend,
			adbc.OptionValueIngestModeCreate,
			adbc.OptionValueIngestModeReplace,
			adbc.OptionValueIngestModeCreateAppend:
			st.ingestMode = val
		default:
			return adbc.Error{
				Msg:  fmt.Sprintf("[Snowflake] invalid statement option %s=%s", key, val),
				Code: adbc.StatusInvalidArgument,
			}
		}
	case OptionStatementQueueSize,
		OptionStatementPrefetchConcurrency,
		OptionStatementIngestWriterConcurrency,
		OptionStatementIngestUploadConcurrency,
		OptionStatementIngestCopyConcurrency,
		OptionStatementIngestTargetFileSize:
		// all of these options are integer-valued: parse once and
		// delegate validation to SetOptionInt
		v, err := strconv.Atoi(val)
		if err != nil {
			return adbc.Error{
				Msg:  fmt.Sprintf("[Snowflake] could not parse '%s' as int for option '%s'", val, key),
				Code: adbc.StatusInvalidArgument,
			}
		}
		return st.SetOptionInt(key, int64(v))
	case OptionStatementQueryTag:
		st.queryTag = val
		return nil
	case OptionUseHighPrecision:
		switch val {
		case adbc.OptionValueEnabled:
			st.useHighPrecision = true
		case adbc.OptionValueDisabled:
			st.useHighPrecision = false
		default:
			return adbc.Error{
				Msg:  fmt.Sprintf("[Snowflake] invalid statement option %s=%s", key, val),
				Code: adbc.StatusInvalidArgument,
			}
		}
	default:
		return adbc.Error{
			Msg:  fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key),
			Code: adbc.StatusNotImplemented,
		}
	}
	return nil
}

func (st *statement) SetOptionBytes(key string, value []byte) error {
	return adbc.Error{
		Msg:  fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key),
		Code: adbc.StatusNotImplemented,
	}
}

func (st *statement) SetOptionInt(key string, value int64) error {
	switch key {
	case OptionStatementQueueSize:
		if value <= 0 {
			return adbc.Error{
				Msg:  fmt.Sprintf("[Snowflake] Invalid value for statement option '%s': '%d' is not a positive integer", OptionStatementQueueSize, value),
				Code: adbc.StatusInvalidArgument,
			}
		}
		st.queueSize = int(value)
		return nil
	case OptionStatementPrefetchConcurrency:
		if value <= 0 {
			return adbc.Error{
				Msg:  fmt.Sprintf("invalid value ('%d') for option '%s', must be > 0", value, key),
				Code: adbc.StatusInvalidArgument,
			}
		}
		st.prefetchConcurrency = int(value)
		return nil
	case OptionStatementIngestWriterConcurrency:
		if value < 0 {
			return adbc.Error{
				Msg:  fmt.Sprintf("invalid value ('%d') for option '%s', must be >= 0", value, key),
				Code: adbc.StatusInvalidArgument,
			}
		}
		if value == 0 {
			// zero selects the default concurrency
			st.ingestOptions.writerConcurrency = defaultWriterConcurrency
			return nil
		}
		st.ingestOptions.writerConcurrency = uint(value)
		return nil
	case OptionStatementIngestUploadConcurrency:
		if value < 0 {
			return adbc.Error{
				Msg:  fmt.Sprintf("invalid value ('%d') for option '%s', must be >= 0", value, key),
				Code: adbc.StatusInvalidArgument,
			}
		}
		if value == 0 {
			// zero selects the default concurrency
			st.ingestOptions.uploadConcurrency = defaultUploadConcurrency
			return nil
		}
		st.ingestOptions.uploadConcurrency = uint(value)
		return nil
	case OptionStatementIngestCopyConcurrency:
		if value < 0 {
			return adbc.Error{
				Msg:  fmt.Sprintf("invalid value ('%d') for option '%s', must be >= 0", value, key),
				Code: adbc.StatusInvalidArgument,
			}
		}
		st.ingestOptions.copyConcurrency = uint(value)
		return nil
	case OptionStatementIngestTargetFileSize:
		if value < 0 {
			return adbc.Error{
				Msg:  fmt.Sprintf("invalid value ('%d') for option '%s', must be >= 0", value, key),
				Code: adbc.StatusInvalidArgument,
			}
		}
		st.ingestOptions.targetFileSize = uint(value)
		return nil
	}
	return adbc.Error{
		Msg:  fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key),
		Code: adbc.StatusNotImplemented,
	}
}

func (st *statement) SetOptionDouble(key string, value float64) error {
	return adbc.Error{
		Msg:  fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key),
		Code: adbc.StatusNotImplemented,
	}
}
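// Illustrative usage (a caller-side sketch): integer-valued options can also
// be set directly, bypassing string parsing; the values here are examples:
//
//	_ = stmt.SetOptionInt(OptionStatementQueueSize, 200)
//	_ = stmt.SetOptionInt(OptionStatementPrefetchConcurrency, 10)
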
// SetSqlQuery sets the query string to be executed.
//
// The query can then be executed with any of the Execute methods.
// For queries expected to be executed repeatedly, Prepare should be
// called before execution.
func (st *statement) SetSqlQuery(query string) error {
	st.query = query
	st.targetTable = ""
	return nil
}

// toSnowflakeType maps an Arrow data type to the equivalent Snowflake
// column type for CREATE TABLE, returning "" for unsupported types.
func toSnowflakeType(dt arrow.DataType) string {
	switch dt.ID() {
	case arrow.EXTENSION:
		return toSnowflakeType(dt.(arrow.ExtensionType).StorageType())
	case arrow.DICTIONARY:
		return toSnowflakeType(dt.(*arrow.DictionaryType).ValueType)
	case arrow.RUN_END_ENCODED:
		return toSnowflakeType(dt.(*arrow.RunEndEncodedType).Encoded())
	case arrow.INT8, arrow.INT16, arrow.INT32, arrow.INT64,
		arrow.UINT8, arrow.UINT16, arrow.UINT32, arrow.UINT64:
		return "integer"
	case arrow.FLOAT32, arrow.FLOAT16, arrow.FLOAT64:
		return "double"
	case arrow.DECIMAL, arrow.DECIMAL256:
		dec := dt.(arrow.DecimalType)
		return fmt.Sprintf("NUMERIC(%d,%d)", dec.GetPrecision(), dec.GetScale())
	case arrow.STRING, arrow.LARGE_STRING, arrow.STRING_VIEW:
		return "text"
	case arrow.BINARY, arrow.LARGE_BINARY, arrow.BINARY_VIEW:
		return "binary"
	case arrow.FIXED_SIZE_BINARY:
		fsb := dt.(*arrow.FixedSizeBinaryType)
		return fmt.Sprintf("binary(%d)", fsb.ByteWidth)
	case arrow.BOOL:
		return "boolean"
	case arrow.TIME32, arrow.TIME64:
		t := dt.(arrow.TemporalWithUnit)
		prec := int(t.TimeUnit()) * 3
		return fmt.Sprintf("time(%d)", prec)
	case arrow.DATE32, arrow.DATE64:
		return "date"
	case arrow.TIMESTAMP:
		ts := dt.(*arrow.TimestampType)
		prec := int(ts.Unit) * 3
		if ts.TimeZone == "" {
			return fmt.Sprintf("timestamp_ntz(%d)", prec)
		}
		return fmt.Sprintf("timestamp_ltz(%d)", prec)
	case arrow.DENSE_UNION, arrow.SPARSE_UNION:
		return "variant"
	case arrow.LIST, arrow.LARGE_LIST, arrow.FIXED_SIZE_LIST:
		return "array"
	case arrow.STRUCT, arrow.MAP:
		return "object"
	}
	return ""
}
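// For example, toSnowflakeType maps a zoned microsecond timestamp to
// "timestamp_ltz(6)" and a 128-bit decimal to a NUMERIC of the same
// precision and scale:
//
//	toSnowflakeType(&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: "UTC"}) // "timestamp_ltz(6)"
//	toSnowflakeType(&arrow.Decimal128Type{Precision: 38, Scale: 18})                // "NUMERIC(38,18)"
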
// initIngest creates the target table for bulk ingestion if needed,
// deriving its column definitions from the bound record or stream schema.
func (st *statement) initIngest(ctx context.Context) error {
	var createBldr strings.Builder

	createBldr.WriteString("CREATE TABLE ")
	if st.ingestMode == adbc.OptionValueIngestModeCreateAppend {
		createBldr.WriteString(" IF NOT EXISTS ")
	}
	createBldr.WriteString(quoteTblName(st.targetTable))
	createBldr.WriteString(" (")

	var schema *arrow.Schema
	if st.bound != nil {
		schema = st.bound.Schema()
	} else {
		schema = st.streamBind.Schema()
	}

	for i, f := range schema.Fields() {
		if i != 0 {
			createBldr.WriteString(", ")
		}
		createBldr.WriteString(quoteTblName(f.Name))
		createBldr.WriteString(" ")

		ty := toSnowflakeType(f.Type)
		if ty == "" {
			return adbc.Error{
				Msg:  fmt.Sprintf("unimplemented type conversion for field %s, arrow type: %s", f.Name, f.Type),
				Code: adbc.StatusNotImplemented,
			}
		}

		createBldr.WriteString(ty)
		if !f.Nullable {
			createBldr.WriteString(" NOT NULL")
		}
	}
	createBldr.WriteString(")")

	switch st.ingestMode {
	case adbc.OptionValueIngestModeAppend:
		// Do nothing; the target table is assumed to already exist.
	case adbc.OptionValueIngestModeReplace:
		replaceQuery := "DROP TABLE IF EXISTS " + quoteTblName(st.targetTable)
		_, err := st.cnxn.cn.ExecContext(ctx, replaceQuery, nil)
		if err != nil {
			return errToAdbcErr(adbc.StatusInternal, err)
		}
		fallthrough
	case adbc.OptionValueIngestModeCreate:
		fallthrough
	case adbc.OptionValueIngestModeCreateAppend:
		fallthrough
	default:
		// create the table!
		createQuery := createBldr.String()
		_, err := st.cnxn.cn.ExecContext(ctx, createQuery, nil)
		if err != nil {
			return errToAdbcErr(adbc.StatusInternal, err)
		}
	}

	return nil
}

func (st *statement) executeIngest(ctx context.Context) (int64, error) {
	if st.streamBind == nil && st.bound == nil {
		return -1, adbc.Error{
			Msg:  "must call Bind before bulk ingestion",
			Code: adbc.StatusInvalidState,
		}
	}

	err := st.initIngest(ctx)
	if err != nil {
		return -1, err
	}

	if st.bound != nil {
		return st.ingestRecord(ctx)
	}
	return st.ingestStream(ctx)
}

// ExecuteQuery executes the current query or prepared statement
// and returns a RecordReader for the results along with the number
// of rows affected if known, otherwise it will be -1.
//
// This invalidates any prior result sets on this statement.
func (st *statement) ExecuteQuery(ctx context.Context) (array.RecordReader, int64, error) {
	ctx = st.setQueryContext(ctx)

	if st.targetTable != "" {
		n, err := st.executeIngest(ctx)
		return nil, n, err
	}

	if st.query == "" {
		return nil, -1, adbc.Error{
			Msg:  "cannot execute without a query",
			Code: adbc.StatusInvalidState,
		}
	}

	// for a bound stream reader we'd need to implement something to
	// concatenate RecordReaders which doesn't exist yet. let's put
	// that off for now.
	if st.streamBind != nil || st.bound != nil {
		bind := snowflakeBindReader{
			doQuery: func(params []driver.NamedValue) (array.RecordReader, error) {
				loader, err := st.cnxn.cn.QueryArrowStream(ctx, st.query, params...)
				if err != nil {
					return nil, errToAdbcErr(adbc.StatusInternal, err)
				}
				return newRecordReader(ctx, st.alloc, loader, st.queueSize, st.prefetchConcurrency, st.useHighPrecision)
			},
			currentBatch: st.bound,
			stream:       st.streamBind,
		}
		st.bound = nil
		st.streamBind = nil

		rdr := concatReader{}
		err := rdr.Init(&bind)
		if err != nil {
			return nil, -1, err
		}
		return &rdr, -1, nil
	}

	loader, err := st.cnxn.cn.QueryArrowStream(ctx, st.query)
	if err != nil {
		return nil, -1, errToAdbcErr(adbc.StatusInternal, err)
	}

	rdr, err := newRecordReader(ctx, st.alloc, loader, st.queueSize, st.prefetchConcurrency, st.useHighPrecision)
	nrec := loader.TotalRows()
	return rdr, nrec, err
}

// ExecuteUpdate executes a statement that does not generate a result
// set. It returns the number of rows affected if known, otherwise -1.
func (st *statement) ExecuteUpdate(ctx context.Context) (int64, error) {
	ctx = st.setQueryContext(ctx)

	if st.targetTable != "" {
		return st.executeIngest(ctx)
	}

	if st.query == "" {
		return -1, adbc.Error{
			Msg:  "cannot execute without a query",
			Code: adbc.StatusInvalidState,
		}
	}

	if st.streamBind != nil || st.bound != nil {
		numRows := int64(0)
		bind := snowflakeBindReader{
			currentBatch: st.bound,
			stream:       st.streamBind,
		}
		st.bound = nil
		st.streamBind = nil
		defer bind.Release()

		// execute the statement once per bound row, accumulating the
		// affected-row count; any unknown count poisons the total to -1
		for {
			params, err := bind.NextParams()
			if err == io.EOF {
				break
			} else if err != nil {
				return -1, err
			}

			r, err := st.cnxn.cn.ExecContext(ctx, st.query, params)
			if err != nil {
				return -1, errToAdbcErr(adbc.StatusInternal, err)
			}

			n, err := r.RowsAffected()
			if err != nil {
				numRows = -1
			} else if numRows >= 0 {
				numRows += n
			}
		}
		return numRows, nil
	}

	r, err := st.cnxn.cn.ExecContext(ctx, st.query, nil)
	if err != nil {
		return -1, errToAdbcErr(adbc.StatusIO, err)
	}

	n, err := r.RowsAffected()
	if err != nil {
		n = -1
	}
	return n, nil
}
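// Illustrative usage (a caller-side sketch): consuming a query result; the
// reader must be released when done:
//
//	rdr, _, err := stmt.ExecuteQuery(ctx)
//	if err != nil {
//		// handle the error
//	}
//	defer rdr.Release()
//	for rdr.Next() {
//		rec := rdr.Record() // only valid until the next call to Next
//		_ = rec
//	}
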
// ExecuteSchema gets the schema of the result set of a query without executing it.
func (st *statement) ExecuteSchema(ctx context.Context) (*arrow.Schema, error) {
	ctx = st.setQueryContext(ctx)

	if st.targetTable != "" {
		return nil, adbc.Error{
			Msg:  "cannot execute schema for ingestion",
			Code: adbc.StatusInvalidState,
		}
	}

	if st.query == "" {
		return nil, adbc.Error{
			Msg:  "cannot execute without a query",
			Code: adbc.StatusInvalidState,
		}
	}

	if st.streamBind != nil || st.bound != nil {
		return nil, adbc.Error{
			Msg:  "executing schema with bound params not yet implemented",
			Code: adbc.StatusNotImplemented,
		}
	}

	loader, err := st.cnxn.cn.QueryArrowStream(gosnowflake.WithDescribeOnly(ctx), st.query)
	if err != nil {
		return nil, errToAdbcErr(adbc.StatusInternal, err)
	}

	return rowTypesToArrowSchema(ctx, loader, st.useHighPrecision)
}

// Prepare turns this statement into a prepared statement to be executed
// multiple times. This invalidates any prior result sets.
func (st *statement) Prepare(_ context.Context) error {
	if st.query == "" {
		return adbc.Error{
			Code: adbc.StatusInvalidState,
			Msg:  "cannot prepare statement with no query",
		}
	}
	// snowflake doesn't provide a "Prepare" API, so this is a no-op
	return nil
}

// SetSubstraitPlan allows setting a serialized Substrait execution
// plan into the query or for querying Substrait-related metadata.
//
// Drivers are not required to support both SQL and Substrait semantics.
// If they do, it may be via converting between representations internally.
//
// Like SetSqlQuery, after this is called the query can be executed
// using any of the Execute methods. If the query is expected to be
// executed repeatedly, Prepare should be called first on the statement.
func (st *statement) SetSubstraitPlan(plan []byte) error {
	return adbc.Error{
		Msg:  "Snowflake does not support Substrait plans",
		Code: adbc.StatusNotImplemented,
	}
}

// Bind uses an arrow record batch to bind parameters to the query.
//
// This can be used for bulk inserts or for prepared statements.
// The driver will call release on the passed in Record when it is done,
// but it may not do this until the statement is closed or another
// record is bound.
func (st *statement) Bind(_ context.Context, values arrow.Record) error {
	if st.streamBind != nil {
		st.streamBind.Release()
		st.streamBind = nil
	} else if st.bound != nil {
		st.bound.Release()
		st.bound = nil
	}

	st.bound = values
	if st.bound != nil {
		st.bound.Retain()
	}
	return nil
}

// BindStream uses a record batch stream to bind parameters for this
// query. This can be used for bulk inserts or prepared statements.
//
// The driver will call Release on the record reader, but may not do this
// until Close is called.
func (st *statement) BindStream(_ context.Context, stream array.RecordReader) error {
	if st.streamBind != nil {
		st.streamBind.Release()
		st.streamBind = nil
	} else if st.bound != nil {
		st.bound.Release()
		st.bound = nil
	}

	st.streamBind = stream
	if st.streamBind != nil {
		st.streamBind.Retain()
	}
	return nil
}
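// Illustrative usage (a caller-side sketch): binding a record of parameters;
// the statement retains the record, so the caller may release its own
// reference afterwards (buildRecord is a hypothetical helper):
//
//	rec := buildRecord()
//	defer rec.Release()
//	if err := stmt.Bind(ctx, rec); err != nil {
//		// handle the error
//	}
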
// GetParameterSchema returns an Arrow schema representation of
// the expected parameters to be bound.
//
// This retrieves an Arrow Schema describing the number, names, and
// types of the parameters in a parameterized statement. The fields
// of the schema should be in order of the ordinal position of the
// parameters; named parameters should appear only once.
//
// If the parameter does not have a name, or a name cannot be determined,
// the name of the corresponding field in the schema will be an empty
// string. If the type cannot be determined, the type of the corresponding
// field will be NA (NullType).
//
// This should be called only after calling Prepare.
//
// This should return an error with StatusNotImplemented if the schema
// cannot be determined.
func (st *statement) GetParameterSchema() (*arrow.Schema, error) {
	// snowflake's API does not provide any way to determine the schema
	return nil, adbc.Error{
		Code: adbc.StatusNotImplemented,
	}
}

// ExecutePartitions executes the current statement and gets the results
// as a partitioned result set.
//
// It returns the Schema of the result set, the collection of partition
// descriptors and the number of rows affected, if known. If unknown,
// the number of rows affected will be -1.
//
// If the driver does not support partitioned results, this will return
// an error with a StatusNotImplemented code.
func (st *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.Partitions, int64, error) {
	if st.query == "" {
		return nil, adbc.Partitions{}, -1, adbc.Error{
			Msg:  "cannot execute without a query",
			Code: adbc.StatusInvalidState,
		}
	}

	// snowflake partitioned results are not currently portable enough to
	// satisfy the requirements of this function. At least not what is
	// returned from the snowflake driver.
	return nil, adbc.Partitions{}, -1, adbc.Error{
		Msg:  "ExecutePartitions not implemented for Snowflake",
		Code: adbc.StatusNotImplemented,
	}
}