go/adbc/driver/flightsql/flightsql_connection.go (954 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package flightsql import ( "bytes" "context" "encoding/json" "fmt" "io" "math" "strings" "github.com/apache/arrow-adbc/go/adbc" "github.com/apache/arrow-adbc/go/adbc/driver/internal" "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/flight" "github.com/apache/arrow-go/v18/arrow/flight/flightsql" "github.com/apache/arrow-go/v18/arrow/flight/flightsql/schema_ref" flightproto "github.com/apache/arrow-go/v18/arrow/flight/gen/flight" "github.com/apache/arrow-go/v18/arrow/ipc" "github.com/bluele/gcache" "google.golang.org/grpc" grpccodes "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" grpcstatus "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" ) type connectionImpl struct { driverbase.ConnectionImplBase cl *flightsql.Client db *databaseImpl clientCache gcache.Cache hdrs metadata.MD timeouts timeoutOption txn *flightsql.Txn supportInfo support } type flightSqlMetadata struct { internal.DefaultXdbcMetadataBuilder columnMetadata *flightsql.ColumnMetadata } func (md *flightSqlMetadata) SetMetadata(metadata arrow.Metadata) { md.columnMetadata = &flightsql.ColumnMetadata{Data: &metadata} } func (md *flightSqlMetadata) SetXdbcScopeCatalog(b *array.StringBuilder) { if v, ok := md.columnMetadata.CatalogName(); ok { b.Append(v) } else { md.DefaultXdbcMetadataBuilder.SetXdbcScopeCatalog(b) } } func (md *flightSqlMetadata) SetXdbcScopeSchema(b *array.StringBuilder) { if v, ok := md.columnMetadata.SchemaName(); ok { b.Append(v) } else { md.DefaultXdbcMetadataBuilder.SetXdbcScopeSchema(b) } } func (md *flightSqlMetadata) SetXdbcScopeTable(b *array.StringBuilder) { if v, ok := md.columnMetadata.TableName(); ok { b.Append(v) } else { md.DefaultXdbcMetadataBuilder.SetXdbcScopeTable(b) } } func (md *flightSqlMetadata) SetXdbcSqlDataType(columnType arrow.DataType, b *array.Int16Builder) { b.Append(int16(internal.ToXdbcDataType(columnType))) } func (md *flightSqlMetadata) SetXdbcTypeName(b *array.StringBuilder) { if v, ok := md.columnMetadata.TypeName(); ok { b.Append(v) } else { md.DefaultXdbcMetadataBuilder.SetXdbcTypeName(b) } } func (md *flightSqlMetadata) SetXdbcIsAutoincrement(builder *array.BooleanBuilder) { if v, ok := md.columnMetadata.IsAutoIncrement(); ok { builder.Append(v) } else { md.DefaultXdbcMetadataBuilder.SetXdbcIsAutoincrement(builder) } } func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) { // To avoid an N+1 query problem, we assume result sets here will fit in memory and build up a single response. g := internal.GetObjects{Ctx: ctx, Depth: depth, Catalog: catalog, DbSchema: dbSchema, TableName: tableName, ColumnName: columnName, TableType: tableType} if err := g.Init(c.Base().Alloc, c.GetObjectsDbSchemas, c.GetObjectsTables, &flightSqlMetadata{}); err != nil { return nil, err } defer g.Release() catalogs, err := c.GetObjectsCatalogs(ctx, catalog) if err != nil { return nil, err } foundCatalog := false for _, catalog := range catalogs { g.AppendCatalog(catalog) foundCatalog = true } // Implementations like Dremio report no catalogs, but still have schemas if !foundCatalog && depth != adbc.ObjectDepthCatalogs { g.AppendCatalog("") } return g.Finish() } // GetCurrentCatalog implements driverbase.CurrentNamespacer. func (c *connectionImpl) GetCurrentCatalog() (string, error) { options, err := c.getSessionOptions(context.Background()) if err != nil { return "", err } if catalog, ok := options["catalog"]; ok { if val, ok := catalog.(string); ok { return val, nil } return "", c.Base().ErrorHelper.Errorf(adbc.StatusInternal, "server returned non-string catalog %#v", catalog) } return "", c.Base().ErrorHelper.Errorf(adbc.StatusNotFound, "current catalog not supported") } // GetCurrentDbSchema implements driverbase.CurrentNamespacer. func (c *connectionImpl) GetCurrentDbSchema() (string, error) { options, err := c.getSessionOptions(context.Background()) if err != nil { return "", err } if schema, ok := options["schema"]; ok { if val, ok := schema.(string); ok { return val, nil } return "", c.Base().ErrorHelper.Errorf(adbc.StatusInternal, "server returned non-string schema %#v", schema) } return "", c.Base().ErrorHelper.Errorf(adbc.StatusNotFound, "current schema not supported") } // SetCurrentCatalog implements driverbase.CurrentNamespacer. func (c *connectionImpl) SetCurrentCatalog(value string) error { return c.setSessionOptions(context.Background(), "catalog", value) } // SetCurrentDbSchema implements driverbase.CurrentNamespacer. func (c *connectionImpl) SetCurrentDbSchema(value string) error { return c.setSessionOptions(context.Background(), "schema", value) } func (c *connectionImpl) SetAutocommit(enabled bool) error { if enabled && c.txn == nil { // no-op don't even error if the server didn't support transactions return nil } if !c.supportInfo.transactions { return errNoTransactionSupport } ctx := metadata.NewOutgoingContext(context.Background(), c.hdrs) var err error if c.txn != nil { if err = c.txn.Commit(ctx, c.timeouts); err != nil { return adbc.Error{ Msg: "[Flight SQL] failed to update autocommit: " + err.Error(), Code: adbc.StatusIO, } } } if enabled { c.txn = nil return nil } if c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts); err != nil { return adbc.Error{ Msg: "[Flight SQL] failed to update autocommit: " + err.Error(), Code: adbc.StatusIO, } } return nil } var adbcToFlightSQLInfo = map[adbc.InfoCode]flightsql.SqlInfo{ adbc.InfoVendorName: flightsql.SqlInfoFlightSqlServerName, adbc.InfoVendorVersion: flightsql.SqlInfoFlightSqlServerVersion, adbc.InfoVendorArrowVersion: flightsql.SqlInfoFlightSqlServerArrowVersion, adbc.InfoVendorSql: flightsql.SqlInfoFlightSqlServerSql, adbc.InfoVendorSubstrait: flightsql.SqlInfoFlightSqlServerSubstrait, adbc.InfoVendorSubstraitMinVersion: flightsql.SqlInfoFlightSqlServerSubstraitMinVersion, adbc.InfoVendorSubstraitMaxVersion: flightsql.SqlInfoFlightSqlServerSubstraitMaxVersion, } func doGet(ctx context.Context, cl *flightsql.Client, endpoint *flight.FlightEndpoint, clientCache gcache.Cache, opts ...grpc.CallOption) (rdr *flight.Reader, err error) { if len(endpoint.Location) == 0 { return cl.DoGet(ctx, endpoint.Ticket, opts...) } var ( cc interface{} hasFallback bool ) for _, loc := range endpoint.Location { if loc.Uri == flight.LocationReuseConnection { hasFallback = true continue } cc, err = clientCache.Get(loc.Uri) if err != nil { continue } conn := cc.(*flightsql.Client) rdr, err = conn.DoGet(ctx, endpoint.Ticket, opts...) if err != nil { continue } return } if hasFallback { return cl.DoGet(ctx, endpoint.Ticket, opts...) } return nil, err } func (c *connectionImpl) getSessionOptions(ctx context.Context) (map[string]interface{}, error) { ctx = metadata.NewOutgoingContext(ctx, c.hdrs) var header, trailer metadata.MD rawOptions, err := c.cl.GetSessionOptions(ctx, &flight.GetSessionOptionsRequest{}, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) if err != nil { // We're going to make a bit of a concession to backwards compatibility // here and ignore UNIMPLEMENTED or INVALID_ARGUMENT grpcStatus := grpcstatus.Convert(err) if grpcStatus.Code() == grpccodes.InvalidArgument || grpcStatus.Code() == grpccodes.Unimplemented { return map[string]interface{}{}, nil } return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetSessionOptions") } options := make(map[string]interface{}, len(rawOptions.SessionOptions)) for k, rawValue := range rawOptions.SessionOptions { switch v := rawValue.OptionValue.(type) { case *flightproto.SessionOptionValue_BoolValue: options[k] = v.BoolValue case *flightproto.SessionOptionValue_DoubleValue: options[k] = v.DoubleValue case *flightproto.SessionOptionValue_Int64Value: options[k] = v.Int64Value case *flightproto.SessionOptionValue_StringValue: options[k] = v.StringValue case *flightproto.SessionOptionValue_StringListValue_: if v.StringListValue.Values == nil { options[k] = make([]string, 0) } else { options[k] = v.StringListValue.Values } case nil: options[k] = nil default: return nil, adbc.Error{ Code: adbc.StatusNotImplemented, Msg: fmt.Sprintf("[FlightSQL] Unknown session option type %#v", rawValue), } } } return options, nil } func (c *connectionImpl) setSessionOptions(ctx context.Context, key string, val interface{}) error { req := flight.SetSessionOptionsRequest{} hdrs := make([]string, 0) for k, vv := range c.hdrs { for _, v := range vv { hdrs = append(hdrs, k, v) } } ctx = metadata.AppendToOutgoingContext(ctx, hdrs...) var err error req.SessionOptions, err = flight.NewSessionOptionValues(map[string]any{key: val}) if err != nil { return adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] Invalid session option %s=%#v: %s", key, val, err.Error()), Code: adbc.StatusInvalidArgument, } } var header, trailer metadata.MD errors, err := c.cl.SetSessionOptions(ctx, &req, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) if err != nil { return adbcFromFlightStatusWithDetails(err, header, trailer, "SetSessionOptions") } if len(errors.Errors) > 0 { msg := strings.Builder{} fmt.Fprint(&msg, "[Flight SQL] Could not set option(s) ") first := true for k, v := range errors.Errors { if !first { fmt.Fprint(&msg, ", ") } first = false errmsg := "unknown error" switch v.Value { case flightproto.SetSessionOptionsResult_INVALID_NAME: errmsg = "invalid name" case flightproto.SetSessionOptionsResult_INVALID_VALUE: errmsg = "invalid value" case flightproto.SetSessionOptionsResult_ERROR: errmsg = "error setting option" } fmt.Fprintf(&msg, "'%s' (%s)", k, errmsg) } return adbc.Error{ Msg: msg.String(), Code: adbc.StatusInvalidArgument, } } return nil } func getSessionOption[T any](options map[string]interface{}, key string, defaultVal T, valueType string) (T, error) { rawValue, ok := options[key] if !ok { return defaultVal, adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] unknown session option '%s'", key), Code: adbc.StatusNotFound, } } value, ok := rawValue.(T) if !ok { return defaultVal, adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] session option %s=%#v is not %s value", key, rawValue, valueType), Code: adbc.StatusNotFound, } } return value, nil } func (c *connectionImpl) GetOption(key string) (string, error) { if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) headers := c.hdrs.Get(name) if len(headers) > 0 { return headers[0], nil } return "", adbc.Error{ Msg: "[Flight SQL] unknown header", Code: adbc.StatusNotFound, } } switch key { case OptionTimeoutFetch: return c.timeouts.fetchTimeout.String(), nil case OptionTimeoutQuery: return c.timeouts.queryTimeout.String(), nil case OptionTimeoutUpdate: return c.timeouts.updateTimeout.String(), nil case OptionSessionOptions: options, err := c.getSessionOptions(context.Background()) if err != nil { return "", err } encoded, err := json.Marshal(options) if err != nil { return "", adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] Could not encode option values: %s", err.Error()), Code: adbc.StatusInternal, } } return string(encoded), nil } switch { case strings.HasPrefix(key, OptionSessionOptionPrefix): options, err := c.getSessionOptions(context.Background()) if err != nil { return "", err } name := key[len(OptionSessionOptionPrefix):] return getSessionOption(options, name, "", "a string") case strings.HasPrefix(key, OptionBoolSessionOptionPrefix): options, err := c.getSessionOptions(context.Background()) if err != nil { return "", err } name := key[len(OptionBoolSessionOptionPrefix):] v, err := getSessionOption(options, name, false, "a boolean") if err != nil { return "", err } if v { return adbc.OptionValueEnabled, nil } return adbc.OptionValueDisabled, nil case strings.HasPrefix(key, OptionStringListSessionOptionPrefix): options, err := c.getSessionOptions(context.Background()) if err != nil { return "", err } name := key[len(OptionStringListSessionOptionPrefix):] v, err := getSessionOption[[]string](options, name, nil, "a string list") if err != nil { return "", err } encoded, err := json.Marshal(v) if err != nil { return "", adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] Could not encode option value: %s", err.Error()), Code: adbc.StatusInternal, } } return string(encoded), nil } return "", adbc.Error{ Msg: "[Flight SQL] unknown connection option", Code: adbc.StatusNotFound, } } func (c *connectionImpl) GetOptionBytes(key string) ([]byte, error) { switch key { case OptionSessionOptions: options, err := c.getSessionOptions(context.Background()) if err != nil { return nil, err } encoded, err := json.Marshal(options) if err != nil { return nil, adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] Could not encode option values: %s", err.Error()), Code: adbc.StatusInternal, } } return encoded, nil } return nil, adbc.Error{ Msg: "[Flight SQL] unknown connection option", Code: adbc.StatusNotFound, } } func (c *connectionImpl) GetOptionInt(key string) (int64, error) { switch key { case OptionTimeoutFetch: fallthrough case OptionTimeoutQuery: fallthrough case OptionTimeoutUpdate: val, err := c.GetOptionDouble(key) if err != nil { return 0, err } return int64(val), nil } if strings.HasPrefix(key, OptionSessionOptionPrefix) { options, err := c.getSessionOptions(context.Background()) if err != nil { return 0, err } name := key[len(OptionSessionOptionPrefix):] return getSessionOption(options, name, int64(0), "an integer") } return c.ConnectionImplBase.GetOptionInt(key) } func (c *connectionImpl) GetOptionDouble(key string) (float64, error) { switch key { case OptionTimeoutFetch: return c.timeouts.fetchTimeout.Seconds(), nil case OptionTimeoutQuery: return c.timeouts.queryTimeout.Seconds(), nil case OptionTimeoutUpdate: return c.timeouts.updateTimeout.Seconds(), nil } if strings.HasPrefix(key, OptionSessionOptionPrefix) { options, err := c.getSessionOptions(context.Background()) if err != nil { return 0, err } name := key[len(OptionSessionOptionPrefix):] return getSessionOption(options, name, float64(0.0), "a floating-point") } return c.ConnectionImplBase.GetOptionDouble(key) } func (c *connectionImpl) SetOption(key, value string) error { if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) if value == "" { c.hdrs.Delete(name) } else { c.hdrs.Append(name, value) } return nil } switch key { case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate: return c.timeouts.setTimeoutString(key, value) } switch { case strings.HasPrefix(key, OptionSessionOptionPrefix): name := key[len(OptionSessionOptionPrefix):] return c.setSessionOptions(context.Background(), name, value) case strings.HasPrefix(key, OptionBoolSessionOptionPrefix): name := key[len(OptionBoolSessionOptionPrefix):] switch value { case adbc.OptionValueEnabled: return c.setSessionOptions(context.Background(), name, true) case adbc.OptionValueDisabled: return c.setSessionOptions(context.Background(), name, false) default: return adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] invalid boolean session option value %s=%s", name, value), Code: adbc.StatusNotImplemented, } } case strings.HasPrefix(key, OptionStringListSessionOptionPrefix): name := key[len(OptionStringListSessionOptionPrefix):] stringlist := make([]string, 0) if err := json.Unmarshal([]byte(value), &stringlist); err != nil { return adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] invalid string list session option value %s=%s: %s", name, value, err.Error()), Code: adbc.StatusNotImplemented, } } return c.setSessionOptions(context.Background(), name, stringlist) case strings.HasPrefix(key, OptionEraseSessionOptionPrefix): name := key[len(OptionEraseSessionOptionPrefix):] return c.setSessionOptions(context.Background(), name, nil) } return c.ConnectionImplBase.SetOption(key, value) } func (c *connectionImpl) SetOptionInt(key string, value int64) error { switch key { case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate: return c.timeouts.setTimeout(key, float64(value)) } if strings.HasPrefix(key, OptionSessionOptionPrefix) { name := key[len(OptionSessionOptionPrefix):] return c.setSessionOptions(context.Background(), name, value) } return c.ConnectionImplBase.SetOptionInt(key, value) } func (c *connectionImpl) SetOptionDouble(key string, value float64) error { switch key { case OptionTimeoutFetch: fallthrough case OptionTimeoutQuery: fallthrough case OptionTimeoutUpdate: return c.timeouts.setTimeout(key, value) } if strings.HasPrefix(key, OptionSessionOptionPrefix) { name := key[len(OptionSessionOptionPrefix):] return c.setSessionOptions(context.Background(), name, value) } return c.ConnectionImplBase.SetOptionDouble(key, value) } func (c *connectionImpl) PrepareDriverInfo(ctx context.Context, infoCodes []adbc.InfoCode) error { driverInfo := c.DriverInfo if len(infoCodes) == 0 { infoCodes = driverInfo.InfoSupportedCodes() } translated := make([]flightsql.SqlInfo, 0, len(infoCodes)) for _, code := range infoCodes { if t, ok := adbcToFlightSQLInfo[code]; ok { translated = append(translated, t) } } // None of the requested info codes are available on the server, so just return the local info if len(translated) == 0 { return nil } ctx = metadata.NewOutgoingContext(ctx, c.hdrs) var header, trailer metadata.MD info, err := c.cl.GetSqlInfo(ctx, translated, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) // Just return local driver info if GetSqlInfo hasn't been implemented on the server if grpcstatus.Code(err) == grpccodes.Unimplemented { return nil } if err != nil { return adbcFromFlightStatus(err, "GetInfo(GetSqlInfo)") } // No error, go get the SqlInfo from the server for i, endpoint := range info.Endpoint { var header, trailer metadata.MD rdr, err := doGet(ctx, c.cl, endpoint, c.clientCache, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) if err != nil { return adbcFromFlightStatusWithDetails(err, header, trailer, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location) } for rdr.Next() { rec := rdr.Record() field := rec.Column(0).(*array.Uint32) info := rec.Column(1).(*array.DenseUnion) var adbcInfoCode adbc.InfoCode for i := 0; i < int(rec.NumRows()); i++ { var found bool idx := int(info.ValueOffset(i)) flightSqlInfoCode := flightsql.SqlInfo(field.Value(i)) for infocode := range adbcToFlightSQLInfo { if adbcToFlightSQLInfo[infocode] == flightSqlInfoCode { adbcInfoCode = infocode found = true break } } // SqlInfo on the server that does not have an explicit mapping to ADBC is ignored if !found { continue } var v any switch arr := info.Field(info.ChildID(i)).(type) { case *array.String: v = strings.Clone(arr.Value(idx)) case *array.Boolean: v = arr.Value(idx) default: return adbc.Error{ Msg: fmt.Sprintf("unsupported field_type %T for info_value", arr), Code: adbc.StatusInvalidArgument, } } if err := driverInfo.RegisterInfoCode(adbcInfoCode, v); err != nil { return err } } } if err := checkContext(rdr.Err(), ctx); err != nil { return adbcFromFlightStatusWithDetails(err, header, trailer, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location) } } return nil } // Helper function to read and validate a metadata stream func (c *connectionImpl) readInfo(ctx context.Context, expectedSchema *arrow.Schema, info *flight.FlightInfo, opts ...grpc.CallOption) (array.RecordReader, error) { // use a default queueSize for the reader rdr, err := newRecordReader(ctx, c.db.Alloc, c.cl, info, c.clientCache, 5, opts...) if err != nil { return nil, adbcFromFlightStatus(err, "DoGet") } if !rdr.Schema().Equal(expectedSchema) { rdr.Release() return nil, adbc.Error{ Msg: fmt.Sprintf("Invalid schema returned for: expected %s, got %s", expectedSchema.String(), rdr.Schema().String()), Code: adbc.StatusInternal, } } return rdr, nil } func (c *connectionImpl) GetObjectsCatalogs(ctx context.Context, catalog *string) ([]string, error) { var ( header, trailer metadata.MD numCatalogs int64 ) ctx = metadata.NewOutgoingContext(ctx, c.hdrs) // To avoid an N+1 query problem, we assume result sets here will fit in memory and build up a single response. info, err := c.cl.GetCatalogs(ctx, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) if err != nil { return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetCatalogs)") } if info.TotalRecords > 0 { numCatalogs = info.TotalRecords } header = metadata.MD{} trailer = metadata.MD{} rdr, err := c.readInfo(ctx, schema_ref.Catalogs, info, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetCatalogs)") } defer rdr.Release() catalogs := make([]string, 0, numCatalogs) for rdr.Next() { arr := rdr.Record().Column(0).(*array.String) for i := 0; i < arr.Len(); i++ { // XXX: force copy since accessor is unsafe catalogName := string([]byte(arr.Value(i))) catalogs = append(catalogs, catalogName) } } if err := checkContext(rdr.Err(), ctx); err != nil { return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetCatalogs)") } return catalogs, nil } // Helper function to build up a map of catalogs to DB schemas func (c *connectionImpl) GetObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string) (result map[string][]string, err error) { if depth == adbc.ObjectDepthCatalogs { return } ctx = metadata.NewOutgoingContext(ctx, c.hdrs) result = make(map[string][]string) var header, trailer metadata.MD // Pre-populate the map of which schemas are in which catalogs info, err := c.cl.GetDBSchemas(ctx, &flightsql.GetDBSchemasOpts{DbSchemaFilterPattern: dbSchema}, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) if err != nil { return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetDBSchemas)") } header = metadata.MD{} trailer = metadata.MD{} rdr, err := c.readInfo(ctx, schema_ref.DBSchemas, info, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetDBSchemas)") } defer rdr.Release() for rdr.Next() { // Nullable catalog := rdr.Record().Column(0).(*array.String) // Non-nullable dbSchema := rdr.Record().Column(1).(*array.String) for i := 0; i < catalog.Len(); i++ { catalogName := "" if !catalog.IsNull(i) { catalogName = string([]byte(catalog.Value(i))) } result[catalogName] = append(result[catalogName], string([]byte(dbSchema.Value(i)))) } } if err := checkContext(rdr.Err(), ctx); err != nil { return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetCatalogs)") } return } func (c *connectionImpl) GetObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (result internal.SchemaToTableInfo, err error) { if depth == adbc.ObjectDepthCatalogs || depth == adbc.ObjectDepthDBSchemas { return } ctx = metadata.NewOutgoingContext(ctx, c.hdrs) result = make(map[internal.CatalogAndSchema][]internal.TableInfo) // Pre-populate the map of which schemas are in which catalogs includeSchema := depth == adbc.ObjectDepthAll || depth == adbc.ObjectDepthColumns var header, trailer metadata.MD info, err := c.cl.GetTables(ctx, &flightsql.GetTablesOpts{ DbSchemaFilterPattern: dbSchema, TableNameFilterPattern: tableName, TableTypes: tableType, IncludeSchema: includeSchema, }, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) if err != nil { return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetTables)") } expectedSchema := schema_ref.Tables if includeSchema { expectedSchema = schema_ref.TablesWithIncludedSchema } header = metadata.MD{} trailer = metadata.MD{} rdr, err := c.readInfo(ctx, expectedSchema, info, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { return nil, adbcFromFlightStatus(err, "GetObjects(GetTables)") } defer rdr.Release() for rdr.Next() { // Nullable catalog := rdr.Record().Column(0).(*array.String) dbSchema := rdr.Record().Column(1).(*array.String) // Non-nullable tableName := rdr.Record().Column(2).(*array.String) tableType := rdr.Record().Column(3).(*array.String) for i := 0; i < catalog.Len(); i++ { catalogName := "" dbSchemaName := "" if !catalog.IsNull(i) { catalogName = string([]byte(catalog.Value(i))) } if !dbSchema.IsNull(i) { dbSchemaName = string([]byte(dbSchema.Value(i))) } key := internal.CatalogAndSchema{ Catalog: catalogName, Schema: dbSchemaName, } var schema *arrow.Schema if includeSchema { reader, err := ipc.NewReader(bytes.NewReader(rdr.Record().Column(4).(*array.Binary).Value(i))) if err != nil { return nil, adbc.Error{ Msg: err.Error(), Code: adbc.StatusInternal, } } schema = reader.Schema() reader.Release() } result[key] = append(result[key], internal.TableInfo{ Name: string([]byte(tableName.Value(i))), TableType: string([]byte(tableType.Value(i))), Schema: schema, }) } } if err := checkContext(rdr.Err(), ctx); err != nil { return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetTables)") } return } func (c *connectionImpl) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (*arrow.Schema, error) { opts := &flightsql.GetTablesOpts{ Catalog: catalog, DbSchemaFilterPattern: dbSchema, TableNameFilterPattern: &tableName, IncludeSchema: true, } ctx = metadata.NewOutgoingContext(ctx, c.hdrs) var header, trailer metadata.MD info, err := c.cl.GetTables(ctx, opts, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetTableSchema(GetTables)") } header = metadata.MD{} trailer = metadata.MD{} rdr, err := doGet(ctx, c.cl, info.Endpoint[0], c.clientCache, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetTableSchema(DoGet)") } defer rdr.Release() rec, err := rdr.Read() if err != nil { if err == io.EOF { return nil, adbc.Error{ Msg: "No table found", Code: adbc.StatusNotFound, } } return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetTableSchema(DoGet)") } numRows := rec.NumRows() switch { case numRows == 0: return nil, adbc.Error{ Code: adbc.StatusNotFound, } case numRows > math.MaxInt32: return nil, adbc.Error{ Msg: "[Flight SQL] GetTableSchema cannot handle tables with number of rows > 2^31 - 1", Code: adbc.StatusNotImplemented, } } var s *arrow.Schema for i := 0; i < int(numRows); i++ { currentTableName := rec.Column(2).(*array.String).Value(i) if currentTableName == tableName { // returned schema should be // 0: catalog_name: utf8 // 1: db_schema_name: utf8 // 2: table_name: utf8 not null // 3: table_type: utf8 not null // 4: table_schema: bytes not null schemaBytes := rec.Column(4).(*array.Binary).Value(i) s, err = flight.DeserializeSchema(schemaBytes, c.db.Alloc) if err != nil { return nil, adbcFromFlightStatus(err, "GetTableSchema") } return s, nil } } return s, adbc.Error{ Msg: "[Flight SQL] GetTableSchema could not find a table with a matching schema", Code: adbc.StatusNotFound, } } // GetTableTypes returns a list of the table types in the database. // // The result is an arrow dataset with the following schema: // // Field Name | Field Type // ----------------|-------------- // table_type | utf8 not null func (c *connectionImpl) GetTableTypes(ctx context.Context) (array.RecordReader, error) { ctx = metadata.NewOutgoingContext(ctx, c.hdrs) var header, trailer metadata.MD info, err := c.cl.GetTableTypes(ctx, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetTableTypes") } return newRecordReader(ctx, c.db.Alloc, c.cl, info, c.clientCache, 5) } // Commit commits any pending transactions on this connection, it should // only be used if autocommit is disabled. // // Behavior is undefined if this is mixed with SQL transaction statements. // When not supported, the convention is that it should act as if autocommit // is enabled and return INVALID_STATE errors. func (c *connectionImpl) Commit(ctx context.Context) error { ctx = metadata.NewOutgoingContext(ctx, c.hdrs) var header, trailer metadata.MD err := c.txn.Commit(ctx, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { return adbcFromFlightStatusWithDetails(err, header, trailer, "Commit") } header = metadata.MD{} trailer = metadata.MD{} c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { return adbcFromFlightStatusWithDetails(err, header, trailer, "BeginTransaction") } return nil } // Rollback rolls back any pending transactions. Only used if autocommit // is disabled. // // Behavior is undefined if this is mixed with SQL transaction statements. // When not supported, the convention is that it should act as if autocommit // is enabled and return INVALID_STATE errors. func (c *connectionImpl) Rollback(ctx context.Context) error { ctx = metadata.NewOutgoingContext(ctx, c.hdrs) var header, trailer metadata.MD err := c.txn.Rollback(ctx, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { return adbcFromFlightStatusWithDetails(err, header, trailer, "Rollback") } header = metadata.MD{} trailer = metadata.MD{} c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { return adbcFromFlightStatusWithDetails(err, header, trailer, "BeginTransaction") } return nil } // NewStatement initializes a new statement object tied to this connection func (c *connectionImpl) NewStatement() (adbc.Statement, error) { return &statement{ alloc: c.db.Alloc, clientCache: c.clientCache, hdrs: c.hdrs.Copy(), queueSize: 5, timeouts: c.timeouts, cnxn: c, }, nil } func (c *connectionImpl) execute(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.FlightInfo, error) { if c.txn != nil { return c.txn.Execute(ctx, query, opts...) } return c.cl.Execute(ctx, query, opts...) } func (c *connectionImpl) executeSchema(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.SchemaResult, error) { if c.txn != nil { return c.txn.GetExecuteSchema(ctx, query, opts...) } return c.cl.GetExecuteSchema(ctx, query, opts...) } func (c *connectionImpl) executeSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.FlightInfo, error) { if c.txn != nil { return c.txn.ExecuteSubstrait(ctx, plan, opts...) } return c.cl.ExecuteSubstrait(ctx, plan, opts...) } func (c *connectionImpl) executeSubstraitSchema(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.SchemaResult, error) { if c.txn != nil { return c.txn.GetExecuteSubstraitSchema(ctx, plan, opts...) } return c.cl.GetExecuteSubstraitSchema(ctx, plan, opts...) } func (c *connectionImpl) executeUpdate(ctx context.Context, query string, opts ...grpc.CallOption) (n int64, err error) { if c.txn != nil { return c.txn.ExecuteUpdate(ctx, query, opts...) } return c.cl.ExecuteUpdate(ctx, query, opts...) } func (c *connectionImpl) executeSubstraitUpdate(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (n int64, err error) { if c.txn != nil { return c.txn.ExecuteSubstraitUpdate(ctx, plan, opts...) } return c.cl.ExecuteSubstraitUpdate(ctx, plan, opts...) } func (c *connectionImpl) poll(ctx context.Context, query string, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) { if c.txn != nil { return c.txn.ExecutePoll(ctx, query, retryDescriptor, opts...) } return c.cl.ExecutePoll(ctx, query, retryDescriptor, opts...) } func (c *connectionImpl) pollSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) { if c.txn != nil { return c.txn.ExecuteSubstraitPoll(ctx, plan, retryDescriptor, opts...) } return c.cl.ExecuteSubstraitPoll(ctx, plan, retryDescriptor, opts...) } func (c *connectionImpl) prepare(ctx context.Context, query string, opts ...grpc.CallOption) (*flightsql.PreparedStatement, error) { if c.txn != nil { return c.txn.Prepare(ctx, query, opts...) } return c.cl.Prepare(ctx, query, opts...) } func (c *connectionImpl) prepareSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flightsql.PreparedStatement, error) { if c.txn != nil { return c.txn.PrepareSubstrait(ctx, plan, opts...) } return c.cl.PrepareSubstrait(ctx, plan, opts...) } // Close closes this connection and releases any associated resources. func (c *connectionImpl) Close() error { if c.cl == nil { return adbc.Error{ Msg: "[Flight SQL Connection] trying to close already closed connection", Code: adbc.StatusInvalidState, } } ctx := metadata.NewOutgoingContext(context.Background(), c.hdrs) var header, trailer metadata.MD _, err := c.cl.CloseSession(ctx, &flight.CloseSessionRequest{}, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) if err != nil { grpcStatus := grpcstatus.Convert(err) // Ignore unimplemented if grpcStatus.Code() != grpccodes.Unimplemented { // Ignore the error since server may not support it and may not properly return UNIMPLEMENTED // TODO(https://github.com/apache/arrow-adbc/issues/1243): log a proper warning c.db.Logger.Debug("failed to close session", "error", err.Error()) } } err = c.cl.Close() c.cl = nil return adbcFromFlightStatus(err, "Close") } // ReadPartition constructs a statement for a partition of a query. The // results can then be read independently using the returned RecordReader. // // A partition can be retrieved by using ExecutePartitions on a statement. func (c *connectionImpl) ReadPartition(ctx context.Context, serializedPartition []byte) (rdr array.RecordReader, err error) { var info flight.FlightInfo if err := proto.Unmarshal(serializedPartition, &info); err != nil { return nil, adbc.Error{ Msg: err.Error(), Code: adbc.StatusInvalidArgument, } } // The driver only ever returns one endpoint. if len(info.Endpoint) != 1 { return nil, adbc.Error{ Msg: fmt.Sprintf("Invalid partition: expected 1 endpoint, got %d", len(info.Endpoint)), Code: adbc.StatusInvalidArgument, } } ctx = metadata.NewOutgoingContext(ctx, c.hdrs) rdr, err = doGet(ctx, c.cl, info.Endpoint[0], c.clientCache, c.timeouts) if err != nil { return nil, adbcFromFlightStatus(err, "ReadPartition(DoGet)") } return rdr, nil } var ( _ adbc.PostInitOptions = (*connectionImpl)(nil) )