arrow/flight/flightsql/example/sqlite_tables_schema_batch_reader.go (154 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. //go:build go1.18 // +build go1.18 package example import ( "context" "database/sql" "strings" "sync/atomic" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/flight" "github.com/apache/arrow-go/v18/arrow/flight/flightsql" "github.com/apache/arrow-go/v18/arrow/internal/debug" "github.com/apache/arrow-go/v18/arrow/memory" sqlite3 "modernc.org/sqlite/lib" ) type SqliteTablesSchemaBatchReader struct { refCount atomic.Int64 mem memory.Allocator ctx context.Context rdr array.RecordReader stmt *sql.Stmt schemaBldr *array.BinaryBuilder record arrow.Record err error } func NewSqliteTablesSchemaBatchReader(ctx context.Context, mem memory.Allocator, rdr array.RecordReader, db *sql.DB, mainQuery string) (*SqliteTablesSchemaBatchReader, error) { schemaQuery := `SELECT table_name, name, type, [notnull] FROM pragma_table_info(table_name) JOIN (` + mainQuery + `) WHERE table_name = ?` stmt, err := db.PrepareContext(ctx, schemaQuery) if err != nil { rdr.Release() return nil, err } stsbr := &SqliteTablesSchemaBatchReader{ ctx: ctx, rdr: rdr, stmt: stmt, mem: mem, schemaBldr: array.NewBinaryBuilder(mem, arrow.BinaryTypes.Binary), } stsbr.refCount.Add(1) return stsbr, nil } func (s *SqliteTablesSchemaBatchReader) Err() error { return s.err } func (s *SqliteTablesSchemaBatchReader) Retain() { s.refCount.Add(1) } func (s *SqliteTablesSchemaBatchReader) Release() { debug.Assert(s.refCount.Load() > 0, "too many releases") if s.refCount.Add(-1) == 0 { s.rdr.Release() s.stmt.Close() s.schemaBldr.Release() if s.record != nil { s.record.Release() s.record = nil } } } func (s *SqliteTablesSchemaBatchReader) Schema() *arrow.Schema { fields := append(s.rdr.Schema().Fields(), arrow.Field{Name: "table_schema", Type: arrow.BinaryTypes.Binary}) return arrow.NewSchema(fields, nil) } func (s *SqliteTablesSchemaBatchReader) Record() arrow.Record { return s.record } func getSqlTypeFromTypeName(sqltype string) int { if sqltype == "" { return sqlite3.SQLITE_NULL } sqltype = strings.ToLower(sqltype) if strings.HasPrefix(sqltype, "varchar") || strings.HasPrefix(sqltype, "char") { return sqlite3.SQLITE_TEXT } switch sqltype { case "int", "integer": return sqlite3.SQLITE_INTEGER case "real": return sqlite3.SQLITE_FLOAT case "blob": return sqlite3.SQLITE_BLOB case "text", "date": return sqlite3.SQLITE_TEXT default: return sqlite3.SQLITE_NULL } } func getPrecisionFromCol(sqltype int) int { switch sqltype { case sqlite3.SQLITE_INTEGER: return 10 case sqlite3.SQLITE_FLOAT: return 15 } return 0 } func getColumnMetadata(bldr *flightsql.ColumnMetadataBuilder, sqltype int, table string) arrow.Metadata { defer bldr.Clear() bldr.Scale(15).IsReadOnly(false).IsAutoIncrement(false) if table != "" { bldr.TableName(table) } switch sqltype { case sqlite3.SQLITE_TEXT, sqlite3.SQLITE_BLOB: default: bldr.Precision(int32(getPrecisionFromCol(sqltype))) } return bldr.Metadata() } func (s *SqliteTablesSchemaBatchReader) Next() bool { if s.record != nil { s.record.Release() s.record = nil } if !s.rdr.Next() { return false } rec := s.rdr.Record() tableNameArr := rec.Column(rec.Schema().FieldIndices("table_name")[0]).(*array.String) bldr := flightsql.NewColumnMetadataBuilder() columnFields := make([]arrow.Field, 0) for i := 0; i < tableNameArr.Len(); i++ { table := tableNameArr.Value(i) rows, err := s.stmt.QueryContext(s.ctx, table) if err != nil { s.err = err return false } var tableName, name, typ string var nn int for rows.Next() { if err := rows.Scan(&tableName, &name, &typ, &nn); err != nil { rows.Close() s.err = err return false } columnFields = append(columnFields, arrow.Field{ Name: name, Type: getArrowTypeFromString(typ), Nullable: nn == 0, Metadata: getColumnMetadata(bldr, getSqlTypeFromTypeName(typ), tableName), }) } rows.Close() if rows.Err() != nil { s.err = rows.Err() return false } val := flight.SerializeSchema(arrow.NewSchema(columnFields, nil), s.mem) s.schemaBldr.Append(val) columnFields = columnFields[:0] } schemaCol := s.schemaBldr.NewArray() defer schemaCol.Release() s.record = array.NewRecord(s.Schema(), append(rec.Columns(), schemaCol), rec.NumRows()) return true }