spark/sql/types/arrow.go

// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements.  See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License.  You may obtain a copy of the License at
//
//	http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package types

import (
	"bytes"
	"fmt"

	"github.com/apache/arrow-go/v18/arrow"
	"github.com/apache/arrow-go/v18/arrow/array"
	"github.com/apache/arrow-go/v18/arrow/ipc"
	// memory is only used by the illustrative example functions below
	// (editor's addition; not part of the original import set).
	"github.com/apache/arrow-go/v18/arrow/memory"

	proto "github.com/apache/spark-connect-go/v35/internal/generated"
	"github.com/apache/spark-connect-go/v35/spark/sparkerrors"
)

// ReadArrowTableToRows converts an Arrow table into a slice of Rows, one Row
// per table row, with each column value addressable by its field name.
func ReadArrowTableToRows(table arrow.Table) ([]Row, error) {
	result := make([]Row, table.NumRows())

	// For each column in the table, read the data and convert it to an array of any.
	cols := make([][]any, table.NumCols())
	for i := 0; i < int(table.NumCols()); i++ {
		chunkedColumn := table.Column(i).Data()
		column, err := readChunkedColumn(chunkedColumn)
		if err != nil {
			return nil, err
		}
		cols[i] = column
	}

	// Create a list of field names for the rows.
	fieldNames := make([]string, table.NumCols())
	for i, field := range table.Schema().Fields() {
		fieldNames[i] = field.Name
	}

	// Create the rows:
	for i := 0; i < int(table.NumRows()); i++ {
		row := make([]any, table.NumCols())
		for j := 0; j < int(table.NumCols()); j++ {
			row[j] = cols[j][i]
		}
		r := &rowImpl{
			values:  row,
			offsets: make(map[string]int),
		}
		for j, fieldName := range fieldNames {
			r.offsets[fieldName] = j
		}
		result[i] = r
	}
	return result, nil
}
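// exampleTableToRows is an editor's sketch, not part of the original file: it
// shows the intended round trip through ReadArrowTableToRows on a minimal
// in-memory table. The column name "id" and the sample values are assumptions
// chosen purely for illustration.
func exampleTableToRows() ([]Row, error) {
	pool := memory.NewGoAllocator()
	schema := arrow.NewSchema([]arrow.Field{
		{Name: "id", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
	}, nil)

	// Build a single Int64 column with two values and one null; the null
	// surfaces as an untyped nil in the resulting row.
	b := array.NewInt64Builder(pool)
	defer b.Release()
	b.AppendValues([]int64{1, 2}, nil)
	b.AppendNull()
	col := b.NewArray()
	defer col.Release()

	rec := array.NewRecord(schema, []arrow.Array{col}, 3)
	defer rec.Release()
	tbl := array.NewTableFromRecords(schema, []arrow.Record{rec})
	defer tbl.Release()

	// rows[0] holds int64(1) under field name "id"; rows[2] holds nil.
	return ReadArrowTableToRows(tbl)
}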
// readArrayData converts a single Arrow array into a []any, preserving nulls
// as untyped nil entries.
func readArrayData(t arrow.Type, data arrow.ArrayData) ([]any, error) {
	buf := make([]any, 0)
	// Switch over the type t and append the values to buf.
	switch t {
	case arrow.BOOL:
		data := array.NewBooleanData(data)
		for i := 0; i < data.Len(); i++ {
			if data.IsNull(i) {
				buf = append(buf, nil)
			} else {
				buf = append(buf, data.Value(i))
			}
		}
	case arrow.INT8:
		data := array.NewInt8Data(data)
		for i := 0; i < data.Len(); i++ {
			if data.IsNull(i) {
				buf = append(buf, nil)
			} else {
				buf = append(buf, data.Value(i))
			}
		}
	case arrow.INT16:
		data := array.NewInt16Data(data)
		for i := 0; i < data.Len(); i++ {
			if data.IsNull(i) {
				buf = append(buf, nil)
			} else {
				buf = append(buf, data.Value(i))
			}
		}
	case arrow.INT32:
		data := array.NewInt32Data(data)
		for i := 0; i < data.Len(); i++ {
			if data.IsNull(i) {
				buf = append(buf, nil)
			} else {
				buf = append(buf, data.Value(i))
			}
		}
	case arrow.INT64:
		data := array.NewInt64Data(data)
		for i := 0; i < data.Len(); i++ {
			if data.IsNull(i) {
				buf = append(buf, nil)
			} else {
				buf = append(buf, data.Value(i))
			}
		}
	case arrow.FLOAT16:
		data := array.NewFloat16Data(data)
		for i := 0; i < data.Len(); i++ {
			if data.IsNull(i) {
				buf = append(buf, nil)
			} else {
				buf = append(buf, data.Value(i))
			}
		}
	case arrow.FLOAT32:
		data := array.NewFloat32Data(data)
		for i := 0; i < data.Len(); i++ {
			if data.IsNull(i) {
				buf = append(buf, nil)
			} else {
				buf = append(buf, data.Value(i))
			}
		}
	case arrow.FLOAT64:
		data := array.NewFloat64Data(data)
		for i := 0; i < data.Len(); i++ {
			if data.IsNull(i) {
				buf = append(buf, nil)
			} else {
				buf = append(buf, data.Value(i))
			}
		}
	// arrow.DECIMAL is an alias for arrow.DECIMAL128, so a single case label
	// covers both. (The original label "arrow.DECIMAL | arrow.DECIMAL128" was
	// a bitwise OR of two identical constants, which only worked by accident.)
	case arrow.DECIMAL128:
		data := array.NewDecimal128Data(data)
		for i := 0; i < data.Len(); i++ {
			if data.IsNull(i) {
				buf = append(buf, nil)
			} else {
				buf = append(buf, data.Value(i))
			}
		}
	case arrow.DECIMAL256:
		data := array.NewDecimal256Data(data)
		for i := 0; i < data.Len(); i++ {
			if data.IsNull(i) {
				buf = append(buf, nil)
			} else {
				buf = append(buf, data.Value(i))
			}
		}
	case arrow.STRING:
		data := array.NewStringData(data)
		for i := 0; i < data.Len(); i++ {
			if data.IsNull(i) {
				buf = append(buf, nil)
			} else {
				buf = append(buf, data.Value(i))
			}
		}
	case arrow.BINARY:
		data := array.NewBinaryData(data)
		for i := 0; i < data.Len(); i++ {
			if data.IsNull(i) {
				buf = append(buf, nil)
			} else {
				buf = append(buf, data.Value(i))
			}
		}
	case arrow.TIMESTAMP:
		data := array.NewTimestampData(data)
		for i := 0; i < data.Len(); i++ {
			if data.IsNull(i) {
				buf = append(buf, nil)
			} else {
				buf = append(buf, data.Value(i))
			}
		}
	case arrow.DATE64:
		data := array.NewDate64Data(data)
		for i := 0; i < data.Len(); i++ {
			if data.IsNull(i) {
				buf = append(buf, nil)
			} else {
				buf = append(buf, data.Value(i))
			}
		}
	case arrow.DATE32:
		data := array.NewDate32Data(data)
		for i := 0; i < data.Len(); i++ {
			if data.IsNull(i) {
				buf = append(buf, nil)
			} else {
				buf = append(buf, data.Value(i))
			}
		}
	case arrow.LIST:
		data := array.NewListData(data)
		values := data.ListValues()
		res, err := readArrayData(values.DataType().ID(), values.Data())
		if err != nil {
			return nil, err
		}
		for i := 0; i < data.Len(); i++ {
			if data.IsNull(i) {
				buf = append(buf, nil)
				continue
			}
			start := data.Offsets()[i]
			end := data.Offsets()[i+1]
			// TODO: Unfortunately, each list ends up stored as a []any
			// rather than a slice of the element's concrete type.
			buf = append(buf, res[start:end])
		}
	case arrow.MAP:
		// For maps, the data is stored as a list of key-value pairs. To
		// extract the maps we follow the same approach as for lists, but
		// with two child arrays.
		data := array.NewMapData(data)
		keys := data.Keys()
		values := data.Items()
		keyValues, err := readArrayData(keys.DataType().ID(), keys.Data())
		if err != nil {
			return nil, err
		}
		valueValues, err := readArrayData(values.DataType().ID(), values.Data())
		if err != nil {
			return nil, err
		}
		for i := 0; i < data.Len(); i++ {
			if data.IsNull(i) {
				buf = append(buf, nil)
				continue
			}
			tmp := make(map[any]any)
			start := data.Offsets()[i]
			end := data.Offsets()[i+1]
			k := keyValues[start:end]
			v := valueValues[start:end]
			for j := 0; j < len(k); j++ {
				tmp[k[j]] = v[j]
			}
			buf = append(buf, tmp)
		}
	case arrow.STRUCT:
		data := array.NewStructData(data)
		schema := data.DataType().(*arrow.StructType)
		// Convert every child field once, up front. The original re-read the
		// whole field array for every row, which is quadratic in the row count.
		fieldValues := make([][]any, data.NumField())
		for j := 0; j < data.NumField(); j++ {
			field := data.Field(j)
			vals, err := readArrayData(field.DataType().ID(), field.Data())
			if err != nil {
				return nil, err
			}
			fieldValues[j] = vals
		}
		for i := 0; i < data.Len(); i++ {
			if data.IsNull(i) {
				buf = append(buf, nil)
				continue
			}
			tmp := make(map[string]any)
			for j := 0; j < data.NumField(); j++ {
				tmp[schema.Field(j).Name] = fieldValues[j][i]
			}
			buf = append(buf, tmp)
		}
	default:
		return nil, fmt.Errorf("unsupported arrow data type %s", t.String())
	}
	return buf, nil
}
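// exampleReadArrayData is an editor's sketch, not part of the original file:
// it feeds a small Int64 array through readArrayData to show how nulls come
// back as nil entries. The sample values are assumptions for illustration.
func exampleReadArrayData() ([]any, error) {
	pool := memory.NewGoAllocator()
	b := array.NewInt64Builder(pool)
	defer b.Release()
	b.Append(7)
	b.AppendNull()
	b.Append(9)
	arr := b.NewArray()
	defer arr.Release()

	// Expected result: []any{int64(7), nil, int64(9)}.
	return readArrayData(arr.DataType().ID(), arr.Data())
}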
// readChunkedColumn flattens all chunks of a chunked column into a single
// []any, in chunk order.
func readChunkedColumn(chunked *arrow.Chunked) ([]any, error) {
	buf := make([]any, 0)
	for _, chunk := range chunked.Chunks() {
		data := chunk.Data()
		t := data.DataType().ID()
		values, err := readArrayData(t, data)
		if err != nil {
			return nil, err
		}
		buf = append(buf, values...)
	}
	return buf, nil
}

// ReadArrowBatchToRecord decodes a serialized Arrow IPC stream into a single
// record batch. The caller is responsible for releasing the returned record.
func ReadArrowBatchToRecord(data []byte, schema *StructType) (arrow.Record, error) {
	reader := bytes.NewReader(data)
	arrowReader, err := ipc.NewReader(reader)
	if err != nil {
		return nil, sparkerrors.WithType(fmt.Errorf("failed to create arrow reader: %w", err), sparkerrors.ReadError)
	}
	defer arrowReader.Release()

	record, err := arrowReader.Read()
	if err != nil {
		return nil, sparkerrors.WithType(fmt.Errorf("failed to read arrow record: %w", err), sparkerrors.ReadError)
	}
	// Retain only after the error check; the original called Retain before
	// inspecting err, which panics on a nil record.
	record.Retain()
	return record, nil
}
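// exampleIPCRoundTrip is an editor's sketch, not part of the original file:
// it serializes one record batch to the Arrow IPC stream format and decodes
// it again with ReadArrowBatchToRecord. The schema and values are assumptions
// for illustration; the schema argument is passed as nil because the function
// above does not consult it.
func exampleIPCRoundTrip() (arrow.Record, error) {
	pool := memory.NewGoAllocator()
	schema := arrow.NewSchema([]arrow.Field{
		{Name: "name", Type: arrow.BinaryTypes.String},
	}, nil)

	b := array.NewStringBuilder(pool)
	defer b.Release()
	b.AppendValues([]string{"a", "b"}, nil)
	col := b.NewArray()
	defer col.Release()

	rec := array.NewRecord(schema, []arrow.Array{col}, 2)
	defer rec.Release()

	// Write the record into an in-memory IPC stream, then read it back.
	var buf bytes.Buffer
	w := ipc.NewWriter(&buf, ipc.WithSchema(schema))
	if err := w.Write(rec); err != nil {
		return nil, err
	}
	if err := w.Close(); err != nil {
		return nil, err
	}

	return ReadArrowBatchToRecord(buf.Bytes(), nil)
}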
// arrowStructToProtoStruct converts an Arrow struct type into the Spark
// Connect proto struct kind.
func arrowStructToProtoStruct(schema *arrow.StructType) *proto.DataType_Struct_ {
	fields := make([]*proto.DataType_StructField, schema.NumFields())
	for i, field := range schema.Fields() {
		fields[i] = &proto.DataType_StructField{
			Name:     field.Name,
			DataType: ArrowTypeToProto(field.Type),
		}
	}
	return &proto.DataType_Struct_{
		Struct: &proto.DataType_Struct{
			Fields: fields,
		},
	}
}

// ArrowTypeToProto maps an Arrow data type onto the closest Spark Connect
// proto data type; unknown types map to Unparsed.
func ArrowTypeToProto(dataType arrow.DataType) *proto.DataType {
	switch dataType.ID() {
	case arrow.BOOL:
		return &proto.DataType{Kind: &proto.DataType_Boolean_{}}
	case arrow.INT8:
		return &proto.DataType{Kind: &proto.DataType_Byte_{}}
	case arrow.INT16:
		return &proto.DataType{Kind: &proto.DataType_Short_{}}
	case arrow.INT32:
		return &proto.DataType{Kind: &proto.DataType_Integer_{}}
	case arrow.INT64:
		return &proto.DataType{Kind: &proto.DataType_Long_{}}
	case arrow.FLOAT16:
		// Spark has no half-precision type; widen to Float.
		return &proto.DataType{Kind: &proto.DataType_Float_{}}
	case arrow.FLOAT32:
		// Fixed: the original mapped FLOAT32 to Double, but Arrow's 32-bit
		// float corresponds to Spark's Float type.
		return &proto.DataType{Kind: &proto.DataType_Float_{}}
	case arrow.FLOAT64:
		return &proto.DataType{Kind: &proto.DataType_Double_{}}
	// arrow.DECIMAL is an alias for arrow.DECIMAL128, so one label suffices.
	case arrow.DECIMAL128:
		return &proto.DataType{Kind: &proto.DataType_Decimal_{}}
	case arrow.DECIMAL256:
		return &proto.DataType{Kind: &proto.DataType_Decimal_{}}
	case arrow.STRING:
		return &proto.DataType{Kind: &proto.DataType_String_{}}
	case arrow.BINARY:
		return &proto.DataType{Kind: &proto.DataType_Binary_{}}
	case arrow.TIMESTAMP:
		return &proto.DataType{Kind: &proto.DataType_Timestamp_{}}
	case arrow.DATE64:
		// Note: DATE32 currently falls through to Unparsed.
		return &proto.DataType{Kind: &proto.DataType_Date_{}}
	case arrow.LIST:
		return &proto.DataType{Kind: &proto.DataType_Array_{
			Array: &proto.DataType_Array{
				ElementType: ArrowTypeToProto(dataType.(*arrow.ListType).Elem()),
			},
		}}
	case arrow.STRUCT:
		return &proto.DataType{Kind: arrowStructToProtoStruct(dataType.(*arrow.StructType))}
	default:
		return &proto.DataType{Kind: &proto.DataType_Unparsed_{}}
	}
}

// ArrowSchemaToProto converts a full Arrow schema into the Spark Connect
// proto struct type describing it.
func ArrowSchemaToProto(schema *arrow.Schema) proto.DataType_Struct {
	fields := make([]*proto.DataType_StructField, schema.NumFields())
	for i, field := range schema.Fields() {
		fields[i] = &proto.DataType_StructField{
			Name:     field.Name,
			DataType: ArrowTypeToProto(field.Type),
		}
	}
	return proto.DataType_Struct{
		Fields: fields,
	}
}
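// exampleSchemaToProto is an editor's sketch, not part of the original file:
// it converts a nested Arrow schema to its Spark Connect proto form. The
// field names are assumptions for illustration.
func exampleSchemaToProto() proto.DataType_Struct {
	schema := arrow.NewSchema([]arrow.Field{
		{Name: "id", Type: arrow.PrimitiveTypes.Int64},
		{Name: "tags", Type: arrow.ListOf(arrow.BinaryTypes.String)},
		{Name: "point", Type: arrow.StructOf(
			arrow.Field{Name: "x", Type: arrow.PrimitiveTypes.Float64},
			arrow.Field{Name: "y", Type: arrow.PrimitiveTypes.Float64},
		)},
	}, nil)

	// "tags" becomes an Array of String; "point" becomes a nested Struct.
	return ArrowSchemaToProto(schema)
}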