// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package snowflake

import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"strconv"
"strings"
"time"
"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow-adbc/go/adbc/driver/internal"
"github.com/apache/arrow/go/v13/arrow"
"github.com/apache/arrow/go/v13/arrow/array"
"github.com/snowflakedb/gosnowflake"
)
type snowflakeConn interface {
driver.Conn
driver.ConnBeginTx
driver.ConnPrepareContext
driver.ExecerContext
driver.QueryerContext
driver.Pinger
QueryArrowStream(context.Context, string, ...driver.NamedValue) (gosnowflake.ArrowStreamLoader, error)
}
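
// gosnowflake's driver.Conn is expected to satisfy this interface; the
// concrete value is obtained by type-asserting the connection returned by
// the connector. Roughly (a sketch; error handling elided):
//
//	conn, _ := ctor.Connect(ctx)
//	sc, ok := conn.(snowflakeConn)
//	if !ok {
//		// the underlying driver cannot stream Arrow data
//	}
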
// cnxn implements adbc.Connection for the Snowflake driver.
type cnxn struct {
cn snowflakeConn
db *database
ctor gosnowflake.Connector
sqldb *sql.DB
activeTransaction bool
}
// Metadata methods
// Generally these methods return an array.RecordReader that
// can be consumed to retrieve metadata about the database as Arrow
// data. The returned metadata has an expected schema given in the
// doc strings of the specific methods. Schema fields are nullable
// unless otherwise marked. While no Statement is used in these
// methods, the result set may count as an active statement to the
// driver for the purposes of concurrency management (e.g. if the
// driver has a limit on concurrent active statements and it must
// execute a SQL query internally in order to implement the metadata
// method).
//
// Some methods accept "search pattern" arguments, which are strings
// that can contain the special character "%" to match zero or more
// characters, or "_" to match exactly one character. (See the
// documentation of DatabaseMetaData in JDBC or "Pattern Value Arguments"
// in the ODBC documentation.) Escaping is not currently supported.
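//
// For example, the pattern "ADBC%" matches "ADBC_TEST" and "ADBCDEMO",
// while "ADBC_TEST_" matches "ADBC_TEST1" but not "ADBC_TEST10".
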
// GetInfo returns metadata about the database/driver.
//
// The result is an Arrow dataset with the following schema:
//
// Field Name | Field Type
// ----------------------------|-----------------------------
// info_name | uint32 not null
// info_value | INFO_SCHEMA
//
// INFO_SCHEMA is a dense union with members:
//
// Field Name (Type Code) | Field Type
// ----------------------------|-----------------------------
// string_value (0) | utf8
// bool_value (1) | bool
// int64_value (2) | int64
// int32_bitmask (3) | int32
// string_list (4) | list<utf8>
// int32_to_int32_list_map (5) | map<int32, list<int32>>
//
// Each metadatum is identified by an integer code. The recognized
// codes are defined as constants. Codes [0, 10_000) are reserved
// for ADBC usage. Drivers/vendors will ignore requests for unrecognized
// codes (the row will be omitted from the result).
func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.RecordReader, error) {
const strValTypeID arrow.UnionTypeCode = 0
if len(infoCodes) == 0 {
infoCodes = infoSupportedCodes
}
bldr := array.NewRecordBuilder(c.db.alloc, adbc.GetInfoSchema)
defer bldr.Release()
bldr.Reserve(len(infoCodes))
infoNameBldr := bldr.Field(0).(*array.Uint32Builder)
infoValueBldr := bldr.Field(1).(*array.DenseUnionBuilder)
strInfoBldr := infoValueBldr.Child(0).(*array.StringBuilder)
for _, code := range infoCodes {
switch code {
case adbc.InfoDriverName:
infoNameBldr.Append(uint32(code))
infoValueBldr.Append(strValTypeID)
strInfoBldr.Append(infoDriverName)
case adbc.InfoDriverVersion:
infoNameBldr.Append(uint32(code))
infoValueBldr.Append(strValTypeID)
strInfoBldr.Append(infoDriverVersion)
case adbc.InfoDriverArrowVersion:
infoNameBldr.Append(uint32(code))
infoValueBldr.Append(strValTypeID)
strInfoBldr.Append(infoDriverArrowVersion)
case adbc.InfoVendorName:
infoNameBldr.Append(uint32(code))
infoValueBldr.Append(strValTypeID)
strInfoBldr.Append(infoVendorName)
default:
infoNameBldr.Append(uint32(code))
infoValueBldr.AppendNull()
}
}
final := bldr.NewRecord()
defer final.Release()
return array.NewRecordReader(adbc.GetInfoSchema, []arrow.Record{final})
}
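
// A minimal consumption sketch for GetInfo (hypothetical caller code, not
// part of the driver; assumes an open connection cnxn and "fmt" imported):
//
//	rdr, err := cnxn.GetInfo(ctx, []adbc.InfoCode{adbc.InfoDriverName, adbc.InfoVendorName})
//	if err != nil {
//		return err
//	}
//	defer rdr.Release()
//	for rdr.Next() {
//		fmt.Println(rdr.Record()) // columns: info_name, info_value
//	}
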
// GetObjects gets a hierarchical view of all catalogs, database schemas,
// tables, and columns.
//
// The result is an Arrow Dataset with the following schema:
//
// Field Name | Field Type
// ----------------------------|----------------------------
// catalog_name | utf8
// catalog_db_schemas | list<DB_SCHEMA_SCHEMA>
//
// DB_SCHEMA_SCHEMA is a Struct with the fields:
//
// Field Name | Field Type
// ----------------------------|----------------------------
// db_schema_name | utf8
// db_schema_tables | list<TABLE_SCHEMA>
//
// TABLE_SCHEMA is a Struct with the fields:
//
// Field Name | Field Type
// ----------------------------|----------------------------
// table_name | utf8 not null
// table_type | utf8 not null
// table_columns | list<COLUMN_SCHEMA>
// table_constraints | list<CONSTRAINT_SCHEMA>
//
// COLUMN_SCHEMA is a Struct with the fields:
//
// Field Name | Field Type | Comments
// ----------------------------|---------------------|---------
// column_name | utf8 not null |
// ordinal_position | int32 | (1)
// remarks | utf8 | (2)
// xdbc_data_type | int16 | (3)
// xdbc_type_name | utf8 | (3)
// xdbc_column_size | int32 | (3)
// xdbc_decimal_digits | int16 | (3)
// xdbc_num_prec_radix | int16 | (3)
// xdbc_nullable | int16 | (3)
// xdbc_column_def | utf8 | (3)
// xdbc_sql_data_type | int16 | (3)
// xdbc_datetime_sub | int16 | (3)
// xdbc_char_octet_length | int32 | (3)
// xdbc_is_nullable | utf8 | (3)
// xdbc_scope_catalog | utf8 | (3)
// xdbc_scope_schema | utf8 | (3)
// xdbc_scope_table | utf8 | (3)
// xdbc_is_autoincrement | bool | (3)
// xdbc_is_generatedcolumn | bool | (3)
//
// 1. The column's ordinal position in the table (starting from 1).
// 2. Database-specific description of the column.
// 3. Optional Value. Should be null if not supported by the driver.
// xdbc_values are meant to provide JDBC/ODBC-compatible metadata
// in an agnostic manner.
//
// CONSTRAINT_SCHEMA is a Struct with the fields:
//
// Field Name | Field Type | Comments
// ----------------------------|---------------------|---------
// constraint_name | utf8 |
// constraint_type | utf8 not null | (1)
// constraint_column_names | list<utf8> not null | (2)
// constraint_column_usage | list<USAGE_SCHEMA> | (3)
//
// 1. One of 'CHECK', 'FOREIGN KEY', 'PRIMARY KEY', or 'UNIQUE'.
// 2. The columns on the current table that are constrained, in order.
// 3. For FOREIGN KEY only, the referenced table and columns.
//
// USAGE_SCHEMA is a Struct with fields:
//
// Field Name | Field Type
// ----------------------------|----------------------------
// fk_catalog | utf8
// fk_db_schema | utf8
// fk_table | utf8 not null
// fk_column_name | utf8 not null
//
// For the parameters: if nil is passed, then that parameter will not
// be used for filtering at all. If an empty string is passed, then only
// objects without that property (i.e. catalog or db schema) will be
// returned.
//
// tableName and columnName must be either nil (do not filter by
// table name or column name) or non-empty.
//
// All non-empty, non-nil strings are interpreted as search patterns
// (as described earlier).
func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) {
g := internal.GetObjects{Ctx: ctx, Depth: depth, Catalog: catalog, DbSchema: dbSchema, TableName: tableName, ColumnName: columnName, TableType: tableType}
if err := g.Init(c.db.alloc, c.getObjectsDbSchemas, c.getObjectsTables); err != nil {
return nil, err
}
defer g.Release()
rows, err := c.sqldb.QueryContext(ctx, "SHOW TERSE DATABASES")
if err != nil {
return nil, err
}
defer rows.Close()
var (
created time.Time
name string
kind, dbname, schema sql.NullString
)
for rows.Next() {
if err := rows.Scan(&created, &name, &kind, &dbname, &schema); err != nil {
return nil, errToAdbcErr(adbc.StatusInvalidData, err)
}
// the SNOWFLAKE catalog contains only functions, not tables, so skip it
if name == "SNOWFLAKE" {
continue
}
// the result schema for SHOW TERSE DATABASES is:
// created_on:timestamp, name:text, kind:null, database_name:null, schema_name:null
// the last three columns are always null because they do not apply to
// databases, so only the scanned name is used
g.AppendCatalog(name)
}
return g.Finish()
}
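
// A usage sketch for GetObjects (hypothetical caller; the catalog and table
// pattern values are illustrative):
//
//	catalog := "MY_DB"       // nil would mean "do not filter on catalog"
//	tblPattern := "SALES%"   // search pattern, as described above
//	rdr, err := cnxn.GetObjects(ctx, adbc.ObjectDepthTables, &catalog, nil, &tblPattern, nil, nil)
//	if err != nil {
//		return err
//	}
//	defer rdr.Release()
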
// getObjectsDbSchemas returns a mapping from catalog name to the schema names
// within that catalog, subject to the optional filters.
func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string) (result map[string][]string, err error) {
if depth == adbc.ObjectDepthCatalogs {
return
}
conditions := make([]string, 0)
if catalog != nil && *catalog != "" {
conditions = append(conditions, ` CATALOG_NAME LIKE \'`+*catalog+`\'`)
}
if dbSchema != nil && *dbSchema != "" {
conditions = append(conditions, ` SCHEMA_NAME LIKE \'`+*dbSchema+`\'`)
}
cond := strings.Join(conditions, " AND ")
if cond != "" {
cond = `statement := 'SELECT * FROM (' || statement || ') WHERE ` + cond + `';`
}
result = make(map[string][]string)
const queryPrefix = `DECLARE
c1 CURSOR FOR SELECT DATABASE_NAME FROM INFORMATION_SCHEMA.DATABASES;
res RESULTSET;
counter INTEGER DEFAULT 0;
statement VARCHAR DEFAULT '';
BEGIN
FOR rec IN c1 DO
LET sharelist RESULTSET := (EXECUTE IMMEDIATE 'SHOW SHARES LIKE \'%' || rec.database_name || '%\'');
LET cnt RESULTSET := (SELECT COUNT(*) FROM TABLE(RESULT_SCAN(LAST_QUERY_ID())));
LET cnt_cur CURSOR for cnt;
LET share_cnt INTEGER DEFAULT 0;
OPEN cnt_cur;
FETCH cnt_cur INTO share_cnt;
CLOSE cnt_cur;
IF (share_cnt > 0) THEN
LET c2 CURSOR for sharelist;
LET created_on TIMESTAMP;
LET kind VARCHAR DEFAULT '';
LET share_name VARCHAR DEFAULT '';
LET dbname VARCHAR DEFAULT '';
OPEN c2;
FETCH c2 INTO created_on, kind, share_name, dbname;
CLOSE c2;
IF (dbname = '') THEN
CONTINUE;
END IF;
END IF;
IF (counter > 0) THEN
statement := statement || ' UNION ALL ';
END IF;
statement := statement || ' SELECT CATALOG_NAME, SCHEMA_NAME FROM ' || rec.database_name || '.INFORMATION_SCHEMA.SCHEMATA';
counter := counter + 1;
END FOR;
`
const querySuffix = `
res := (EXECUTE IMMEDIATE :statement);
RETURN TABLE (res);
END;`
query := queryPrefix + cond + querySuffix
var rows *sql.Rows
rows, err = c.sqldb.QueryContext(ctx, query)
if err != nil {
err = errToAdbcErr(adbc.StatusIO, err)
return
}
defer rows.Close()
var catalogName, schemaName string
for rows.Next() {
if err = rows.Scan(&catalogName, &schemaName); err != nil {
err = errToAdbcErr(adbc.StatusIO, err)
return
}
result[catalogName] = append(result[catalogName], schemaName)
}
return
}
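
// The script above is a Snowflake anonymous block that UNIONs the SCHEMATA
// view of every visible database, then optionally wraps the result with the
// caller's filters. The assembled statement is conceptually shaped like
// (illustrative, not literal driver output):
//
//	SELECT * FROM (
//	    SELECT CATALOG_NAME, SCHEMA_NAME FROM DB1.INFORMATION_SCHEMA.SCHEMATA
//	    UNION ALL
//	    SELECT CATALOG_NAME, SCHEMA_NAME FROM DB2.INFORMATION_SCHEMA.SCHEMATA
//	) WHERE CATALOG_NAME LIKE 'MY_DB%'
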
// loc caches the process-local timezone, used to give TIMESTAMP_LTZ
// columns a concrete zone in their Arrow type.
var loc = time.Now().Location()
// toField converts one INFORMATION_SCHEMA.COLUMNS row into an arrow.Field,
// recording Snowflake-specific details (type name, identity info, comment,
// ordinal position) as field metadata.
func toField(name string, isnullable bool, dataType string, numPrec, numPrecRadix, numScale sql.NullInt16, isIdent bool, identGen, identInc, comment sql.NullString, ordinalPos int) (ret arrow.Field) {
ret.Name, ret.Nullable = name, isnullable
switch dataType {
case "NUMBER":
if !numScale.Valid || numScale.Int16 == 0 {
ret.Type = arrow.PrimitiveTypes.Int64
} else {
ret.Type = arrow.PrimitiveTypes.Float64
}
case "FLOAT":
fallthrough
case "DOUBLE":
ret.Type = arrow.PrimitiveTypes.Float64
case "TEXT":
ret.Type = arrow.BinaryTypes.String
case "BINARY":
ret.Type = arrow.BinaryTypes.Binary
case "BOOLEAN":
ret.Type = arrow.FixedWidthTypes.Boolean
case "ARRAY":
fallthrough
case "VARIANT":
fallthrough
case "OBJECT":
// snowflake will return each value as a string
ret.Type = arrow.BinaryTypes.String
case "DATE":
ret.Type = arrow.FixedWidthTypes.Date32
case "TIME":
ret.Type = arrow.FixedWidthTypes.Time64ns
case "DATETIME":
fallthrough
case "TIMESTAMP", "TIMESTAMP_NTZ":
ret.Type = &arrow.TimestampType{Unit: arrow.Nanosecond}
case "TIMESTAMP_LTZ":
ret.Type = &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: loc.String()}
case "TIMESTAMP_TZ":
ret.Type = arrow.FixedWidthTypes.Timestamp_ns
case "GEOGRAPHY":
fallthrough
case "GEOMETRY":
ret.Type = arrow.BinaryTypes.String
}
md := make(map[string]string)
md["TYPE_NAME"] = dataType
if isIdent {
md["IS_IDENTITY"] = "YES"
md["IDENTITY_GENERATION"] = identGen.String
md["IDENTITY_INCREMENT"] = identInc.String
}
if comment.Valid {
md["COMMENT"] = comment.String
}
md["ORDINAL_POSITION"] = strconv.Itoa(ordinalPos)
ret.Metadata = arrow.MetadataFrom(md)
return
}
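
// For example, a NUMBER column with scale 0 maps to int64 while a nonzero
// scale maps to float64 (a sketch; the literal values are illustrative):
//
//	f := toField("ID", false, "NUMBER",
//		sql.NullInt16{Int16: 38, Valid: true}, // precision
//		sql.NullInt16{Int16: 10, Valid: true}, // radix
//		sql.NullInt16{Int16: 0, Valid: true},  // scale 0 -> arrow int64
//		false, sql.NullString{}, sql.NullString{}, sql.NullString{}, 1)
//	_ = f.Type // arrow.PrimitiveTypes.Int64
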
// getObjectsTables returns a mapping from (catalog, schema) to the tables in
// that schema, optionally including each table's Arrow schema when the depth
// requires column information.
func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (result internal.SchemaToTableInfo, err error) {
if depth == adbc.ObjectDepthCatalogs || depth == adbc.ObjectDepthDBSchemas {
return
}
result = make(internal.SchemaToTableInfo)
includeSchema := depth == adbc.ObjectDepthAll || depth == adbc.ObjectDepthColumns
conditions := make([]string, 0)
if catalog != nil && *catalog != "" {
conditions = append(conditions, ` TABLE_CATALOG ILIKE \'`+*catalog+`\'`)
}
if dbSchema != nil && *dbSchema != "" {
conditions = append(conditions, ` TABLE_SCHEMA ILIKE \'`+*dbSchema+`\'`)
}
if tableName != nil && *tableName != "" {
conditions = append(conditions, ` TABLE_NAME ILIKE \'`+*tableName+`\'`)
}
const queryPrefix = `DECLARE
c1 CURSOR FOR SELECT DATABASE_NAME FROM INFORMATION_SCHEMA.DATABASES;
res RESULTSET;
counter INTEGER DEFAULT 0;
statement VARCHAR DEFAULT '';
BEGIN
FOR rec IN c1 DO
LET sharelist RESULTSET := (EXECUTE IMMEDIATE 'SHOW SHARES LIKE \'%' || rec.database_name || '%\'');
LET cnt RESULTSET := (SELECT COUNT(*) FROM TABLE(RESULT_SCAN(LAST_QUERY_ID())));
LET cnt_cur CURSOR for cnt;
LET share_cnt INTEGER DEFAULT 0;
OPEN cnt_cur;
FETCH cnt_cur INTO share_cnt;
CLOSE cnt_cur;
IF (share_cnt > 0) THEN
LET c2 CURSOR for sharelist;
LET created_on TIMESTAMP;
LET kind VARCHAR DEFAULT '';
LET share_name VARCHAR DEFAULT '';
LET dbname VARCHAR DEFAULT '';
OPEN c2;
FETCH c2 INTO created_on, kind, share_name, dbname;
CLOSE c2;
IF (dbname = '') THEN
CONTINUE;
END IF;
END IF;
IF (counter > 0) THEN
statement := statement || ' UNION ALL ';
END IF;
`
const noSchema = `statement := statement || ' SELECT table_catalog, table_schema, table_name, table_type FROM ' || rec.database_name || '.INFORMATION_SCHEMA.TABLES';
counter := counter + 1;
END FOR;
`
const getSchema = `statement := statement ||
' SELECT
table_catalog, table_schema, table_name, column_name,
ordinal_position, is_nullable::boolean, data_type, numeric_precision,
numeric_precision_radix, numeric_scale, is_identity::boolean,
identity_generation, identity_increment, comment
FROM ' || rec.database_name || '.INFORMATION_SCHEMA.COLUMNS';
counter := counter + 1;
END FOR;
`
const querySuffix = `
res := (EXECUTE IMMEDIATE :statement);
RETURN TABLE (res);
END;`
// first populate the tables and table types
var rows *sql.Rows
var tblConditions []string
if len(tableType) > 0 {
tblConditions = append(conditions, ` TABLE_TYPE IN (\'`+strings.Join(tableType, `\',\'`)+`\')`)
} else {
tblConditions = conditions
}
cond := strings.Join(tblConditions, " AND ")
if cond != "" {
cond = `statement := 'SELECT * FROM (' || statement || ') WHERE ` + cond + `';`
}
query := queryPrefix + noSchema + cond + querySuffix
rows, err = c.sqldb.QueryContext(ctx, query)
if err != nil {
err = errToAdbcErr(adbc.StatusIO, err)
return
}
defer rows.Close()
var tblCat, tblSchema, tblName string
var tblType sql.NullString
for rows.Next() {
if err = rows.Scan(&tblCat, &tblSchema, &tblName, &tblType); err != nil {
err = errToAdbcErr(adbc.StatusIO, err)
return
}
key := internal.CatalogAndSchema{
Catalog: tblCat, Schema: tblSchema}
result[key] = append(result[key], internal.TableInfo{
Name: tblName, TableType: tblType.String})
}
if includeSchema {
// if we need to include the schemas of the tables, make a second
// query to fetch the columns and column info
if columnName != nil && *columnName != "" {
conditions = append(conditions, ` column_name ILIKE \'`+*columnName+`\'`)
}
cond = strings.Join(conditions, " AND ")
if cond != "" {
cond = " WHERE " + cond
}
cond = `statement := 'SELECT * FROM (' || statement || ')` + cond +
` ORDER BY table_catalog, table_schema, table_name, ordinal_position';`
query = queryPrefix + getSchema + cond + querySuffix
rows, err = c.sqldb.QueryContext(ctx, query)
if err != nil {
return
}
defer rows.Close()
var (
colName, dataType string
identGen, identIncrement, comment sql.NullString
ordinalPos int
numericPrec, numericPrecRadix, numericScale sql.NullInt16
isNullable, isIdent bool
prevKey internal.CatalogAndSchema
curTableInfo *internal.TableInfo
fieldList = make([]arrow.Field, 0)
)
for rows.Next() {
// order here matches the order of the columns requested in the query
err = rows.Scan(&tblCat, &tblSchema, &tblName, &colName,
&ordinalPos, &isNullable, &dataType, &numericPrec,
&numericPrecRadix, &numericScale, &isIdent, &identGen,
&identIncrement, &comment)
if err != nil {
err = errToAdbcErr(adbc.StatusIO, err)
return
}
key := internal.CatalogAndSchema{Catalog: tblCat, Schema: tblSchema}
if prevKey != key || (curTableInfo != nil && curTableInfo.Name != tblName) {
if len(fieldList) > 0 && curTableInfo != nil {
curTableInfo.Schema = arrow.NewSchema(fieldList, nil)
fieldList = fieldList[:0]
}
info := result[key]
for i := range info {
if info[i].Name == tblName {
curTableInfo = &info[i]
break
}
}
}
prevKey = key
fieldList = append(fieldList, toField(colName, isNullable, dataType, numericPrec, numericPrecRadix, numericScale, isIdent, identGen, identIncrement, comment, ordinalPos))
}
if len(fieldList) > 0 && curTableInfo != nil {
curTableInfo.Schema = arrow.NewSchema(fieldList, nil)
}
}
return
}
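
// The column pass orders results by (table_catalog, table_schema, table_name,
// ordinal_position) so each table's columns arrive contiguously; the loop
// above depends on that ordering to know when to seal a schema. The wrapped
// statement is conceptually (illustrative shape only):
//
//	SELECT * FROM (<union of each db's INFORMATION_SCHEMA.COLUMNS>)
//	WHERE column_name ILIKE 'FOO%'
//	ORDER BY table_catalog, table_schema, table_name, ordinal_position
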
// descToField converts one row of DESC TABLE output into an arrow.Field.
func descToField(name, typ, isnull, primary string, comment sql.NullString) (field arrow.Field, err error) {
field.Name = strings.ToLower(name)
if isnull == "Y" {
field.Nullable = true
}
md := make(map[string]string)
md["DATA_TYPE"] = typ
md["PRIMARY_KEY"] = primary
if comment.Valid {
md["COMMENT"] = comment.String
}
field.Metadata = arrow.MetadataFrom(md)
paren := strings.Index(typ, "(")
if paren == -1 {
// types without params
switch typ {
case "FLOAT":
fallthrough
case "DOUBLE":
field.Type = arrow.PrimitiveTypes.Float64
case "DATE":
field.Type = arrow.FixedWidthTypes.Date32
// ARRAY, OBJECT, and VARIANT are all returned as strings by Snowflake
case "ARRAY", "OBJECT", "VARIANT":
field.Type = arrow.BinaryTypes.String
case "GEOGRAPHY", "GEOMETRY":
field.Type = arrow.BinaryTypes.String
case "BOOLEAN":
field.Type = arrow.FixedWidthTypes.Boolean
default:
err = adbc.Error{
Msg: fmt.Sprintf("Snowflake Data Type %s not implemented", typ),
Code: adbc.StatusNotImplemented,
}
}
return
}
prefix := typ[:paren]
switch prefix {
case "VARCHAR", "TEXT":
field.Type = arrow.BinaryTypes.String
case "BINARY", "VARBINARY":
field.Type = arrow.BinaryTypes.Binary
case "NUMBER":
comma := strings.Index(typ, ",")
scale, err := strconv.ParseInt(typ[comma+1:len(typ)-1], 10, 32)
if err != nil {
return field, adbc.Error{
Msg: "could not parse Scale from type '" + typ + "'",
Code: adbc.StatusInvalidData,
}
}
if scale == 0 {
field.Type = arrow.PrimitiveTypes.Int64
} else {
field.Type = arrow.PrimitiveTypes.Float64
}
case "TIME":
field.Type = arrow.FixedWidthTypes.Time64ns
case "DATETIME":
fallthrough
case "TIMESTAMP", "TIMESTAMP_NTZ":
field.Type = &arrow.TimestampType{Unit: arrow.Nanosecond}
case "TIMESTAMP_LTZ":
field.Type = &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: loc.String()}
case "TIMESTAMP_TZ":
field.Type = arrow.FixedWidthTypes.Timestamp_ns
default:
err = adbc.Error{
Msg: fmt.Sprintf("Snowflake Data Type %s not implemented", typ),
Code: adbc.StatusNotImplemented,
}
}
return
}
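
// For instance, a DESC TABLE type of "NUMBER(38,2)" yields a float64 field
// while "NUMBER(38,0)" yields int64 (a sketch; the column name is illustrative):
//
//	f, err := descToField("AMOUNT", "NUMBER(38,2)", "Y", "N", sql.NullString{})
//	// err == nil; f.Name == "amount"; f.Type == arrow.PrimitiveTypes.Float64
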
func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (*arrow.Schema, error) {
tblParts := make([]string, 0, 3)
if catalog != nil {
tblParts = append(tblParts, strconv.Quote(strings.ToUpper(*catalog)))
}
if dbSchema != nil {
tblParts = append(tblParts, strconv.Quote(strings.ToUpper(*dbSchema)))
}
tblParts = append(tblParts, strconv.Quote(strings.ToUpper(tableName)))
fullyQualifiedTable := strings.Join(tblParts, ".")
rows, err := c.sqldb.QueryContext(ctx, `DESC TABLE `+fullyQualifiedTable)
if err != nil {
return nil, errToAdbcErr(adbc.StatusIO, err)
}
defer rows.Close()
var (
name, typ, kind, isnull, primary, unique string
def, check, expr, comment, policyName sql.NullString
fields = []arrow.Field{}
)
for rows.Next() {
err := rows.Scan(&name, &typ, &kind, &isnull, &def, &primary, &unique,
&check, &expr, &comment, &policyName)
if err != nil {
return nil, errToAdbcErr(adbc.StatusIO, err)
}
f, err := descToField(name, typ, isnull, primary, comment)
if err != nil {
return nil, err
}
fields = append(fields, f)
}
sc := arrow.NewSchema(fields, nil)
return sc, nil
}
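
// Usage sketch (hypothetical caller; assumes "fmt" imported). Note that
// catalog, schema, and table names are upper-cased and quoted before the
// DESC TABLE call:
//
//	sc, err := cnxn.GetTableSchema(ctx, nil, nil, "adbc_test") // resolves "ADBC_TEST"
//	if err != nil {
//		return err
//	}
//	fmt.Println(sc) // one arrow.Field per column
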
// GetTableTypes returns a list of the table types in the database.
//
// The result is an arrow dataset with the following schema:
//
// Field Name | Field Type
// ----------------|--------------
// table_type | utf8 not null
func (c *cnxn) GetTableTypes(_ context.Context) (array.RecordReader, error) {
bldr := array.NewRecordBuilder(c.db.alloc, adbc.TableTypesSchema)
defer bldr.Release()
bldr.Field(0).(*array.StringBuilder).AppendValues([]string{"BASE TABLE", "TEMPORARY TABLE", "VIEW"}, nil)
final := bldr.NewRecord()
defer final.Release()
return array.NewRecordReader(adbc.TableTypesSchema, []arrow.Record{final})
}
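
// Consumption sketch (hypothetical caller; assumes "fmt" imported):
//
//	rdr, err := cnxn.GetTableTypes(ctx)
//	if err != nil {
//		return err
//	}
//	defer rdr.Release()
//	for rdr.Next() {
//		fmt.Println(rdr.Record()) // ["BASE TABLE" "TEMPORARY TABLE" "VIEW"]
//	}
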
// Commit commits any pending transactions on this connection. It should
// only be used if autocommit is disabled.
//
// Behavior is undefined if this is mixed with SQL transaction statements.
func (c *cnxn) Commit(ctx context.Context) error {
if !c.activeTransaction {
return adbc.Error{
Msg: "no active transaction, cannot commit",
Code: adbc.StatusInvalidState,
}
}
_, err := c.cn.ExecContext(context.Background(), "COMMIT", nil)
if err != nil {
return errToAdbcErr(adbc.StatusInternal, err)
}
_, err = c.cn.ExecContext(ctx, "BEGIN", nil)
return errToAdbcErr(adbc.StatusInternal, err)
}
// Rollback rolls back any pending transactions. Only used if autocommit
// is disabled.
//
// Behavior is undefined if this is mixed with SQL transaction statements.
func (c *cnxn) Rollback(ctx context.Context) error {
if !c.activeTransaction {
return adbc.Error{
Msg: "no active transaction, cannot rollback",
Code: adbc.StatusInvalidState,
}
}
_, err := c.cn.ExecContext(context.Background(), "ROLLBACK", nil)
if err != nil {
return errToAdbcErr(adbc.StatusInternal, err)
}
_, err = c.cn.ExecContext(ctx, "BEGIN", nil)
return errToAdbcErr(adbc.StatusInternal, err)
}
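
// The expected transaction flow, as a sketch (assumes c is this connection
// and statements are executed in between):
//
//	// disabling autocommit issues BEGIN, starting a transaction
//	if err := c.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueDisabled); err != nil {
//		return err
//	}
//	// ... run statements ...
//	if err := c.Commit(ctx); err != nil {
//		return err
//	}
//	// Commit immediately issues a new BEGIN, so more work can follow
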
// NewStatement initializes a new statement object tied to this connection
func (c *cnxn) NewStatement() (adbc.Statement, error) {
return &statement{
alloc: c.db.alloc,
cnxn: c,
}, nil
}
// Close closes this connection and releases any associated resources.
func (c *cnxn) Close() error {
if c.sqldb == nil || c.cn == nil {
return adbc.Error{
Msg: "[Snowflake] connection already closed",
Code: adbc.StatusInvalidState,
}
}
if err := c.sqldb.Close(); err != nil {
return err
}
c.sqldb = nil
defer func() {
c.cn = nil
}()
return c.cn.Close()
}
// ReadPartition constructs a statement for a partition of a query. The
// results can then be read independently using the returned RecordReader.
//
// A partition can be retrieved by using ExecutePartitions on a statement.
func (c *cnxn) ReadPartition(ctx context.Context, serializedPartition []byte) (array.RecordReader, error) {
return nil, adbc.Error{
Code: adbc.StatusNotImplemented,
Msg: "ReadPartition not yet implemented for snowflake driver",
}
}
// SetOption sets a string option on this connection. Currently only
// adbc.OptionKeyAutoCommit is supported.
func (c *cnxn) SetOption(key, value string) error {
switch key {
case adbc.OptionKeyAutoCommit:
switch value {
case adbc.OptionValueEnabled:
if c.activeTransaction {
_, err := c.cn.ExecContext(context.Background(), "COMMIT", nil)
if err != nil {
return errToAdbcErr(adbc.StatusInternal, err)
}
c.activeTransaction = false
}
_, err := c.cn.ExecContext(context.Background(), "ALTER SESSION SET AUTOCOMMIT = true", nil)
return err
case adbc.OptionValueDisabled:
if !c.activeTransaction {
_, err := c.cn.ExecContext(context.Background(), "BEGIN", nil)
if err != nil {
return errToAdbcErr(adbc.StatusInternal, err)
}
c.activeTransaction = true
}
_, err := c.cn.ExecContext(context.Background(), "ALTER SESSION SET AUTOCOMMIT = false", nil)
return err
default:
return adbc.Error{
Msg: "[Snowflake] invalid value for option " + key + ": " + value,
Code: adbc.StatusInvalidArgument,
}
}
default:
return adbc.Error{
Msg: "[Snowflake] unknown connection option " + key + ": " + value,
Code: adbc.StatusInvalidArgument,
}
}
}