go/adbc/driver/snowflake/binding.go (114 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package snowflake import ( "database/sql" "database/sql/driver" "fmt" "io" "github.com/apache/arrow-adbc/go/adbc" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" ) func convertArrowToNamedValue(batch arrow.Record, index int) ([]driver.NamedValue, error) { // see goTypeToSnowflake in gosnowflake // technically, snowflake can bind an array of values at once, but // only for INSERT, so we can't take advantage of that without // analyzing the query ourselves params := make([]driver.NamedValue, batch.NumCols()) for i, field := range batch.Schema().Fields() { rawColumn := batch.Column(i) params[i].Ordinal = i + 1 switch column := rawColumn.(type) { case *array.Boolean: params[i].Value = sql.NullBool{ Bool: column.Value(index), Valid: column.IsValid(index), } case *array.Float32: // Snowflake only recognizes float64 params[i].Value = sql.NullFloat64{ Float64: float64(column.Value(index)), Valid: column.IsValid(index), } case *array.Float64: params[i].Value = sql.NullFloat64{ Float64: column.Value(index), Valid: column.IsValid(index), } case *array.Int8: // Snowflake only recognizes int64 params[i].Value = sql.NullInt64{ Int64: int64(column.Value(index)), Valid: column.IsValid(index), } case *array.Int16: params[i].Value = sql.NullInt64{ Int64: int64(column.Value(index)), Valid: column.IsValid(index), } case *array.Int32: params[i].Value = sql.NullInt64{ Int64: int64(column.Value(index)), Valid: column.IsValid(index), } case *array.Int64: params[i].Value = sql.NullInt64{ Int64: column.Value(index), Valid: column.IsValid(index), } case *array.String: params[i].Value = sql.NullString{ String: column.Value(index), Valid: column.IsValid(index), } case *array.LargeString: params[i].Value = sql.NullString{ String: column.Value(index), Valid: column.IsValid(index), } default: return nil, adbc.Error{ Code: adbc.StatusNotImplemented, Msg: fmt.Sprintf("[Snowflake] Unsupported bind param '%s' type %s", field.Name, field.Type.String()), } } } return params, nil } type snowflakeBindReader struct { doQuery func([]driver.NamedValue) (array.RecordReader, error) currentBatch arrow.Record nextIndex int64 // may be nil if we bound only a batch stream array.RecordReader } func (r *snowflakeBindReader) Release() { if r.currentBatch != nil { r.currentBatch.Release() r.currentBatch = nil } if r.stream != nil { r.stream.Release() r.stream = nil } } func (r *snowflakeBindReader) Next() (array.RecordReader, error) { params, err := r.NextParams() if err != nil { // includes EOF return nil, err } return r.doQuery(params) } func (r *snowflakeBindReader) NextParams() ([]driver.NamedValue, error) { for r.currentBatch == nil || r.nextIndex >= r.currentBatch.NumRows() { // We can be used both by binding a stream or by binding a // batch. In the latter case, we have to release the batch, // but not in the former case. Unify the cases by always // releasing the batch, adding an "extra" retain so that the // release does not cause issues. if r.currentBatch != nil { r.currentBatch.Release() } r.currentBatch = nil if r.stream != nil && r.stream.Next() { r.currentBatch = r.stream.Record() r.currentBatch.Retain() r.nextIndex = 0 continue } else if r.stream != nil && r.stream.Err() != nil { return nil, r.stream.Err() } else { // no more params return nil, io.EOF } } params, err := convertArrowToNamedValue(r.currentBatch, int(r.nextIndex)) r.nextIndex++ return params, err }