client/sql/sparksession.go

//
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
//	http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sql

import (
	"context"
	"errors"
	"fmt"
	"io"

	"github.com/apache/spark-connect-go/v_3_4/client/channel"
	proto "github.com/apache/spark-connect-go/v_3_4/internal/generated"
	"github.com/google/uuid"
	"google.golang.org/grpc/metadata"
)

// SparkSession is the entry point for building a Spark Connect session.
var SparkSession sparkSessionBuilderEntrypoint

type sparkSession interface {
	Read() DataFrameReader
	Sql(query string) (DataFrame, error)
	Stop() error
}

type sparkSessionBuilderEntrypoint struct {
	Builder SparkSessionBuilder
}

type SparkSessionBuilder struct {
	connectionString string
}

// Remote returns a copy of the builder with the given connection string set.
func (s SparkSessionBuilder) Remote(connectionString string) SparkSessionBuilder {
	b := s
	b.connectionString = connectionString
	return b
}

// Build parses the connection string, dials the Spark Connect server, and
// returns a new session with a freshly generated session ID.
func (s SparkSessionBuilder) Build() (sparkSession, error) {
	cb, err := channel.NewBuilder(s.connectionString)
	if err != nil {
		return nil, fmt.Errorf("failed to parse connection string %s: %w", s.connectionString, err)
	}

	conn, err := cb.Build()
	if err != nil {
		return nil, fmt.Errorf("failed to connect to remote %s: %w", s.connectionString, err)
	}

	// Copy the channel headers into gRPC metadata so they are sent with every request.
	meta := metadata.MD{}
	for k, v := range cb.Headers {
		meta[k] = append(meta[k], v)
	}

	client := proto.NewSparkConnectServiceClient(conn)
	return &sparkSessionImpl{
		sessionId: uuid.NewString(),
		client:    client,
		metadata:  meta,
	}, nil
}

type sparkSessionImpl struct {
	sessionId string
	client    proto.SparkConnectServiceClient
	metadata  metadata.MD
}

func (s *sparkSessionImpl) Read() DataFrameReader {
	return &dataFrameReaderImpl{
		sparkSession: s,
	}
}

// Sql submits the query as a SQL command and blocks until the server streams
// back a SqlCommandResult, which is wrapped in a DataFrame.
func (s *sparkSessionImpl) Sql(query string) (DataFrame, error) {
	plan := &proto.Plan{
		OpType: &proto.Plan_Command{
			Command: &proto.Command{
				CommandType: &proto.Command_SqlCommand{
					SqlCommand: &proto.SqlCommand{
						Sql: query,
					},
				},
			},
		},
	}
	responseClient, err := s.executePlan(plan)
	if err != nil {
		return nil, fmt.Errorf("failed to execute sql: %s: %w", query, err)
	}
	// Read responses until one carries a SqlCommandResult; a stream that ends
	// before that surfaces as a Recv error.
	for {
		response, err := responseClient.Recv()
		if err != nil {
			return nil, fmt.Errorf("failed to receive ExecutePlan response: %w", err)
		}
		sqlCommandResult := response.GetSqlCommandResult()
		if sqlCommandResult == nil {
			continue
		}
		return &dataFrameImpl{
			sparkSession: s,
			relation:     sqlCommandResult.GetRelation(),
		}, nil
	}
}

func (s *sparkSessionImpl) Stop() error {
	return nil
}
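// Usage sketch (illustrative, not part of the original file): how a caller
// outside this package might build a session and issue a query. The
// "sc://localhost:15002" connection string and the table name "my_table" are
// placeholder assumptions; the accepted format is whatever channel.NewBuilder
// parses.
//
//	session, err := sql.SparkSession.Builder.Remote("sc://localhost:15002").Build()
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer session.Stop()
//
//	df, err := session.Sql("SELECT * FROM my_table")
//	if err != nil {
//		log.Fatal(err)
//	}
//	_ = df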
// executePlan sends an ExecutePlanRequest for the given plan and returns the
// streaming response client.
func (s *sparkSessionImpl) executePlan(plan *proto.Plan) (proto.SparkConnectService_ExecutePlanClient, error) {
	request := proto.ExecutePlanRequest{
		SessionId: s.sessionId,
		Plan:      plan,
		UserContext: &proto.UserContext{
			UserId: "na",
		},
	}
	// Attach the session metadata (e.g. channel headers) to the outgoing context.
	ctx := metadata.NewOutgoingContext(context.Background(), s.metadata)
	executePlanClient, err := s.client.ExecutePlan(ctx, &request)
	if err != nil {
		return nil, fmt.Errorf("failed to call ExecutePlan in session %s: %w", s.sessionId, err)
	}
	return executePlanClient, nil
}

// analyzePlan asks the server for the schema of the given plan.
func (s *sparkSessionImpl) analyzePlan(plan *proto.Plan) (*proto.AnalyzePlanResponse, error) {
	request := proto.AnalyzePlanRequest{
		SessionId: s.sessionId,
		Analyze: &proto.AnalyzePlanRequest_Schema_{
			Schema: &proto.AnalyzePlanRequest_Schema{
				Plan: plan,
			},
		},
		UserContext: &proto.UserContext{
			UserId: "na",
		},
	}
	// Attach the session metadata (e.g. channel headers) to the outgoing context.
	ctx := metadata.NewOutgoingContext(context.Background(), s.metadata)
	response, err := s.client.AnalyzePlan(ctx, &request)
	if err != nil {
		return nil, fmt.Errorf("failed to call AnalyzePlan in session %s: %w", s.sessionId, err)
	}
	return response, nil
}

// consumeExecutePlanClient reads the gRPC stream returned by the Spark Connect
// driver to completion, discarding the responses as long as no error occurs.
// Draining the stream is what makes the driver carry out side-effecting plans
// such as saving a DataFrame; if the stream is left unconsumed, the driver
// never actually performs the save.
func consumeExecutePlanClient(responseClient proto.SparkConnectService_ExecutePlanClient) error {
	for {
		_, err := responseClient.Recv()
		if err != nil {
			if errors.Is(err, io.EOF) {
				return nil
			}
			return fmt.Errorf("failed to receive plan execution response: %w", err)
		}
	}
}
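// Drain-pattern sketch (illustrative, not part of the original file): for a
// plan executed purely for its side effect, such as a DataFrame save, a caller
// inside this package would pair executePlan with consumeExecutePlanClient so
// the driver actually performs the write. The plan value is assumed to be
// built elsewhere.
//
//	responseClient, err := s.executePlan(plan)
//	if err != nil {
//		return fmt.Errorf("failed to execute plan: %w", err)
//	}
//	return consumeExecutePlanClient(responseClient)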