client/sql/dataframe.go

//
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements.  See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License.  You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sql

import (
	"bytes"
	"errors"
	"fmt"
	"io"

	"github.com/apache/arrow/go/v12/arrow"
	"github.com/apache/arrow/go/v12/arrow/array"
	"github.com/apache/arrow/go/v12/arrow/ipc"

	proto "github.com/apache/spark-connect-go/v_3_4/internal/generated"
)

// DataFrame represents a distributed collection of data rows.
type DataFrame interface {
	// Show prints out data frame data.
	Show(numRows int, truncate bool) error
	// Schema returns the schema for the current data frame.
	Schema() (*StructType, error)
	// Collect returns the data rows of the current data frame.
	Collect() ([]Row, error)
	// Write returns a data frame writer, which can be used to save the data frame to supported storage.
	Write() DataFrameWriter
	// CreateTempView creates or replaces a temporary view.
	CreateTempView(viewName string, replace bool, global bool) error
	// Repartition re-partitions a data frame.
	Repartition(numPartitions int, columns []string) (DataFrame, error)
	// RepartitionByRange re-partitions a data frame by range partition.
	RepartitionByRange(numPartitions int, columns []RangePartitionColumn) (DataFrame, error)
}

// RangePartitionColumn describes a column used for range partitioning and its sort direction.
type RangePartitionColumn struct {
	Name       string
	Descending bool
}

// dataFrameImpl is an implementation of the DataFrame interface.
type dataFrameImpl struct {
	sparkSession *sparkSessionImpl
	relation     *proto.Relation // TODO change to proto.Plan?
}
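// exampleShowAndCollect is a hypothetical usage sketch, not part of the
// upstream API: it assumes a DataFrame already obtained elsewhere (for
// example from a SparkSession SQL query) and demonstrates the Show and
// Collect methods declared above.
func exampleShowAndCollect(df DataFrame) error {
	// Print up to 10 rows, truncating long cell values to 20 characters.
	if err := df.Show(10, true); err != nil {
		return err
	}
	// Collect pulls all rows to the client; print each row's values.
	rows, err := df.Collect()
	if err != nil {
		return err
	}
	for _, row := range rows {
		values, err := row.Values()
		if err != nil {
			return err
		}
		fmt.Println(values...)
	}
	return nil
}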
// Show prints up to numRows rows of the data frame. When truncate is true,
// cell values longer than 20 characters are truncated.
func (df *dataFrameImpl) Show(numRows int, truncate bool) error {
	truncateValue := 0
	if truncate {
		truncateValue = 20
	}
	vertical := false

	plan := &proto.Plan{
		OpType: &proto.Plan_Root{
			Root: &proto.Relation{
				Common: &proto.RelationCommon{
					PlanId: newPlanId(),
				},
				RelType: &proto.Relation_ShowString{
					ShowString: &proto.ShowString{
						Input:    df.relation,
						NumRows:  int32(numRows),
						Truncate: int32(truncateValue),
						Vertical: vertical,
					},
				},
			},
		},
	}

	responseClient, err := df.sparkSession.executePlan(plan)
	if err != nil {
		return fmt.Errorf("failed to show dataframe: %w", err)
	}

	for {
		response, err := responseClient.Recv()
		if err != nil {
			if errors.Is(err, io.EOF) {
				// The stream ended without delivering an Arrow batch.
				return fmt.Errorf("did not get arrow batch in response")
			}
			return fmt.Errorf("failed to receive show response: %w", err)
		}
		arrowBatch := response.GetArrowBatch()
		if arrowBatch == nil {
			continue
		}
		return showArrowBatch(arrowBatch)
	}
}

func (df *dataFrameImpl) Schema() (*StructType, error) {
	response, err := df.sparkSession.analyzePlan(df.createPlan())
	if err != nil {
		return nil, fmt.Errorf("failed to analyze plan: %w", err)
	}

	responseSchema := response.GetSchema().Schema
	result := convertProtoDataTypeToStructType(responseSchema)
	return result, nil
}

func (df *dataFrameImpl) Collect() ([]Row, error) {
	responseClient, err := df.sparkSession.executePlan(df.createPlan())
	if err != nil {
		return nil, fmt.Errorf("failed to execute plan: %w", err)
	}

	var schema *StructType
	var allRows []Row

	for {
		response, err := responseClient.Recv()
		if err != nil {
			if errors.Is(err, io.EOF) {
				return allRows, nil
			}
			return nil, fmt.Errorf("failed to receive plan execution response: %w", err)
		}

		// A response message may carry the schema, an Arrow batch, or neither.
		dataType := response.GetSchema()
		if dataType != nil {
			schema = convertProtoDataTypeToStructType(dataType)
			continue
		}

		arrowBatch := response.GetArrowBatch()
		if arrowBatch == nil {
			continue
		}

		rowBatch, err := readArrowBatchData(arrowBatch.Data, schema)
		if err != nil {
			return nil, err
		}
		if allRows == nil {
			allRows = make([]Row, 0, len(rowBatch))
		}
		allRows = append(allRows, rowBatch...)
	}
}
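// exampleInspectSchema is a hypothetical usage sketch, not part of the
// upstream API: it assumes a DataFrame obtained elsewhere and prints the
// name and Go type of each field returned by Schema.
func exampleInspectSchema(df DataFrame) error {
	schema, err := df.Schema()
	if err != nil {
		return err
	}
	for _, field := range schema.Fields {
		// %T prints the concrete DataType implementation, e.g. sql.StringType.
		fmt.Printf("%s: %T\n", field.Name, field.DataType)
	}
	return nil
}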
func (df *dataFrameImpl) Write() DataFrameWriter {
	writer := dataFrameWriterImpl{
		sparkSession: df.sparkSession,
		relation:     df.relation,
	}
	return &writer
}

func (df *dataFrameImpl) CreateTempView(viewName string, replace bool, global bool) error {
	plan := &proto.Plan{
		OpType: &proto.Plan_Command{
			Command: &proto.Command{
				CommandType: &proto.Command_CreateDataframeView{
					CreateDataframeView: &proto.CreateDataFrameViewCommand{
						Input:    df.relation,
						Name:     viewName,
						Replace:  replace,
						IsGlobal: global,
					},
				},
			},
		},
	}

	responseClient, err := df.sparkSession.executePlan(plan)
	if err != nil {
		return fmt.Errorf("failed to create temp view %s: %w", viewName, err)
	}

	return consumeExecutePlanClient(responseClient)
}

func (df *dataFrameImpl) Repartition(numPartitions int, columns []string) (DataFrame, error) {
	var partitionExpressions []*proto.Expression
	if columns != nil {
		partitionExpressions = make([]*proto.Expression, 0, len(columns))
		for _, c := range columns {
			expr := &proto.Expression{
				ExprType: &proto.Expression_UnresolvedAttribute_{
					UnresolvedAttribute: &proto.Expression_UnresolvedAttribute{
						UnparsedIdentifier: c,
					},
				},
			}
			partitionExpressions = append(partitionExpressions, expr)
		}
	}
	return df.repartitionByExpressions(numPartitions, partitionExpressions)
}

func (df *dataFrameImpl) RepartitionByRange(numPartitions int, columns []RangePartitionColumn) (DataFrame, error) {
	var partitionExpressions []*proto.Expression
	if columns != nil {
		partitionExpressions = make([]*proto.Expression, 0, len(columns))
		for _, c := range columns {
			columnExpr := &proto.Expression{
				ExprType: &proto.Expression_UnresolvedAttribute_{
					UnresolvedAttribute: &proto.Expression_UnresolvedAttribute{
						UnparsedIdentifier: c.Name,
					},
				},
			}
			direction := proto.Expression_SortOrder_SORT_DIRECTION_ASCENDING
			if c.Descending {
				direction = proto.Expression_SortOrder_SORT_DIRECTION_DESCENDING
			}
			sortExpr := &proto.Expression{
				ExprType: &proto.Expression_SortOrder_{
					SortOrder: &proto.Expression_SortOrder{
						Child:     columnExpr,
						Direction: direction,
					},
				},
			}
			partitionExpressions = append(partitionExpressions, sortExpr)
		}
	}
	return df.repartitionByExpressions(numPartitions, partitionExpressions)
}

func (df *dataFrameImpl) createPlan() *proto.Plan {
	return &proto.Plan{
		OpType: &proto.Plan_Root{
			Root: &proto.Relation{
				Common: &proto.RelationCommon{
					PlanId: newPlanId(),
				},
				RelType: df.relation.RelType,
			},
		},
	}
}

func (df *dataFrameImpl) repartitionByExpressions(numPartitions int, partitionExpressions []*proto.Expression) (DataFrame, error) {
	// Only set the optional NumPartitions field when a non-zero value was requested.
	var numPartitionsPointerValue *int32
	if numPartitions != 0 {
		int32Value := int32(numPartitions)
		numPartitionsPointerValue = &int32Value
	}
	newRelation := &proto.Relation{
		Common: &proto.RelationCommon{
			PlanId: newPlanId(),
		},
		RelType: &proto.Relation_RepartitionByExpression{
			RepartitionByExpression: &proto.RepartitionByExpression{
				Input:          df.relation,
				NumPartitions:  numPartitionsPointerValue,
				PartitionExprs: partitionExpressions,
			},
		},
	}
	return &dataFrameImpl{
		sparkSession: df.sparkSession,
		relation:     newRelation,
	}, nil
}
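// exampleRepartition is a hypothetical usage sketch, not part of the
// upstream API: it assumes a DataFrame with "color" and "count" columns and
// shows hash repartitioning followed by range repartitioning with a
// descending sort column.
func exampleRepartition(df DataFrame) (DataFrame, error) {
	// Hash-partition into 8 partitions by the "color" column.
	byColor, err := df.Repartition(8, []string{"color"})
	if err != nil {
		return nil, err
	}
	// Range-partition the result into 4 partitions by "count", descending.
	return byColor.RepartitionByRange(4, []RangePartitionColumn{
		{Name: "count", Descending: true},
	})
}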
func showArrowBatch(arrowBatch *proto.ExecutePlanResponse_ArrowBatch) error {
	return showArrowBatchData(arrowBatch.Data)
}

func showArrowBatchData(data []byte) error {
	rows, err := readArrowBatchData(data, nil)
	if err != nil {
		return err
	}
	for _, row := range rows {
		values, err := row.Values()
		if err != nil {
			return fmt.Errorf("failed to get values in the row: %w", err)
		}
		fmt.Println(values...)
	}
	return nil
}

func readArrowBatchData(data []byte, schema *StructType) ([]Row, error) {
	reader := bytes.NewReader(data)
	arrowReader, err := ipc.NewReader(reader)
	if err != nil {
		return nil, fmt.Errorf("failed to create arrow reader: %w", err)
	}
	defer arrowReader.Release()

	var rows []Row
	for {
		record, err := arrowReader.Read()
		if err != nil {
			if errors.Is(err, io.EOF) {
				return rows, nil
			}
			return nil, fmt.Errorf("failed to read arrow: %w", err)
		}

		values, err := readArrowRecord(record)
		if err != nil {
			return nil, err
		}

		numRows := int(record.NumRows())
		if rows == nil {
			rows = make([]Row, 0, numRows)
		}
		for _, v := range values {
			row := &GenericRowWithSchema{
				schema: schema,
				values: v,
			}
			rows = append(rows, row)
		}

		hasNext := arrowReader.Next()
		if !hasNext {
			break
		}
	}
	return rows, nil
}

// readArrowRecord reads all values from an Arrow record and returns them as row-major [][]any.
func readArrowRecord(record arrow.Record) ([][]any, error) {
	numRows := record.NumRows()
	numColumns := int(record.NumCols())

	values := make([][]any, numRows)
	for i := range values {
		values[i] = make([]any, numColumns)
	}

	for columnIndex := 0; columnIndex < numColumns; columnIndex++ {
		err := readArrowRecordColumn(record, columnIndex, values)
		if err != nil {
			return nil, err
		}
	}
	return values, nil
}

// readArrowRecordColumn reads all values in a column and stores them in values.
func readArrowRecordColumn(record arrow.Record, columnIndex int, values [][]any) error {
	numRows := int(record.NumRows())
	columnData := record.Column(columnIndex).Data()
	dataTypeId := columnData.DataType().ID()
	switch dataTypeId {
	case arrow.BOOL:
		vector := array.NewBooleanData(columnData)
		for rowIndex := 0; rowIndex < numRows; rowIndex++ {
			values[rowIndex][columnIndex] = vector.Value(rowIndex)
		}
	case arrow.INT8:
		vector := array.NewInt8Data(columnData)
		for rowIndex := 0; rowIndex < numRows; rowIndex++ {
			values[rowIndex][columnIndex] = vector.Value(rowIndex)
		}
	case arrow.INT16:
		vector := array.NewInt16Data(columnData)
		for rowIndex := 0; rowIndex < numRows; rowIndex++ {
			values[rowIndex][columnIndex] = vector.Value(rowIndex)
		}
	case arrow.INT32:
		vector := array.NewInt32Data(columnData)
		for rowIndex := 0; rowIndex < numRows; rowIndex++ {
			values[rowIndex][columnIndex] = vector.Value(rowIndex)
		}
	case arrow.INT64:
		vector := array.NewInt64Data(columnData)
		for rowIndex := 0; rowIndex < numRows; rowIndex++ {
			values[rowIndex][columnIndex] = vector.Value(rowIndex)
		}
	case arrow.FLOAT16:
		vector := array.NewFloat16Data(columnData)
		for rowIndex := 0; rowIndex < numRows; rowIndex++ {
			values[rowIndex][columnIndex] = vector.Value(rowIndex)
		}
	case arrow.FLOAT32:
		vector := array.NewFloat32Data(columnData)
		for rowIndex := 0; rowIndex < numRows; rowIndex++ {
			values[rowIndex][columnIndex] = vector.Value(rowIndex)
		}
	case arrow.FLOAT64:
		vector := array.NewFloat64Data(columnData)
		for rowIndex := 0; rowIndex < numRows; rowIndex++ {
			values[rowIndex][columnIndex] = vector.Value(rowIndex)
		}
	case arrow.DECIMAL128: // arrow.DECIMAL is an alias for arrow.DECIMAL128
		vector := array.NewDecimal128Data(columnData)
		for rowIndex := 0; rowIndex < numRows; rowIndex++ {
			values[rowIndex][columnIndex] = vector.Value(rowIndex)
		}
	case arrow.DECIMAL256:
		vector := array.NewDecimal256Data(columnData)
		for rowIndex := 0; rowIndex < numRows; rowIndex++ {
			values[rowIndex][columnIndex] = vector.Value(rowIndex)
		}
	case arrow.STRING:
		vector := array.NewStringData(columnData)
		for rowIndex := 0; rowIndex < numRows; rowIndex++ {
			values[rowIndex][columnIndex] = vector.Value(rowIndex)
		}
	case arrow.BINARY:
		vector := array.NewBinaryData(columnData)
		for rowIndex := 0; rowIndex < numRows; rowIndex++ {
			values[rowIndex][columnIndex] = vector.Value(rowIndex)
		}
	case arrow.TIMESTAMP:
		vector := array.NewTimestampData(columnData)
		for rowIndex := 0; rowIndex < numRows; rowIndex++ {
			values[rowIndex][columnIndex] = vector.Value(rowIndex)
		}
	case arrow.DATE64:
		vector := array.NewDate64Data(columnData)
		for rowIndex := 0; rowIndex < numRows; rowIndex++ {
			values[rowIndex][columnIndex] = vector.Value(rowIndex)
		}
	default:
		return fmt.Errorf("unsupported arrow data type %s in column %d", dataTypeId.String(), columnIndex)
	}
	return nil
}
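// exampleDumpRecord is a hypothetical helper, not part of the upstream API:
// it shows how readArrowRecord turns a column-oriented arrow.Record into the
// row-major [][]any from which the Row values above are built.
func exampleDumpRecord(record arrow.Record) error {
	values, err := readArrowRecord(record)
	if err != nil {
		return err
	}
	for _, rowValues := range values {
		fmt.Println(rowValues...)
	}
	return nil
}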
func convertProtoDataTypeToStructType(input *proto.DataType) *StructType {
	dataTypeStruct := input.GetStruct()
	if dataTypeStruct == nil {
		panic("dataType.GetStruct() is nil")
	}
	return &StructType{
		Fields: convertProtoStructFields(dataTypeStruct.Fields),
	}
}

func convertProtoStructFields(input []*proto.DataType_StructField) []StructField {
	result := make([]StructField, len(input))
	for i, f := range input {
		result[i] = convertProtoStructField(f)
	}
	return result
}

func convertProtoStructField(field *proto.DataType_StructField) StructField {
	return StructField{
		Name:     field.Name,
		DataType: convertProtoDataTypeToDataType(field.DataType),
	}
}

// convertProtoDataTypeToDataType converts protobuf data type to Spark connect sql data type
func convertProtoDataTypeToDataType(input *proto.DataType) DataType {
	switch v := input.GetKind().(type) {
	case *proto.DataType_Boolean_:
		return BooleanType{}
	case *proto.DataType_Byte_:
		return ByteType{}
	case *proto.DataType_Short_:
		return ShortType{}
	case *proto.DataType_Integer_:
		return IntegerType{}
	case *proto.DataType_Long_:
		return LongType{}
	case *proto.DataType_Float_:
		return FloatType{}
	case *proto.DataType_Double_:
		return DoubleType{}
	case *proto.DataType_Decimal_:
		return DecimalType{}
	case *proto.DataType_String_:
		return StringType{}
	case *proto.DataType_Binary_:
		return BinaryType{}
	case *proto.DataType_Timestamp_:
		return TimestampType{}
	case *proto.DataType_TimestampNtz:
		return TimestampNtzType{}
	case *proto.DataType_Date_:
		return DateType{}
	default:
		return UnsupportedType{
			TypeInfo: v,
		}
	}
}
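// exampleManualSchema is a hypothetical sketch, not part of the upstream
// API: it builds by hand the same kind of StructType that
// convertProtoDataTypeToStructType produces for a two-column schema, to
// illustrate the shape of the converted result.
func exampleManualSchema() *StructType {
	return &StructType{
		Fields: []StructField{
			{Name: "color", DataType: StringType{}},
			{Name: "count", DataType: LongType{}},
		},
	}
}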