// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package client

import (
	"context"
	"errors"
	"fmt"
	"io"

	"github.com/apache/arrow-go/v18/arrow"
	"github.com/apache/arrow-go/v18/arrow/array"
	"github.com/apache/spark-connect-go/v35/internal/generated"
	proto "github.com/apache/spark-connect-go/v35/internal/generated"
	"github.com/apache/spark-connect-go/v35/spark/client/base"
	"github.com/apache/spark-connect-go/v35/spark/client/options"
	"github.com/apache/spark-connect-go/v35/spark/mocks"
	"github.com/apache/spark-connect-go/v35/spark/sparkerrors"
	"github.com/apache/spark-connect-go/v35/spark/sql/types"
	"github.com/apache/spark-connect-go/v35/spark/sql/utils"
	"github.com/google/uuid"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
)
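
// sparkConnectClientImpl is the default implementation of base.SparkConnectClient.
// It bundles the underlying RPC client with the session ID and the gRPC metadata
// that accompany every request.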
type sparkConnectClientImpl struct {
client base.SparkConnectRPCClient
metadata metadata.MD
sessionId string
opts options.SparkClientOptions
}
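
// newExecutePlanRequest assembles an ExecutePlanRequest for the given plan, filling
// in the session ID, user context, a fresh operation ID, and the reattach options
// configured for this client.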
func (s *sparkConnectClientImpl) newExecutePlanRequest(plan *proto.Plan) *proto.ExecutePlanRequest {
	// Every new execution needs a new operation ID.
operationId := uuid.NewString()
return &proto.ExecutePlanRequest{
SessionId: s.sessionId,
Plan: plan,
UserContext: &proto.UserContext{
UserId: s.opts.UserId,
},
ClientType: &s.opts.UserAgent,
		// The operation ID is needed to be able to reattach to the execution.
OperationId: &operationId,
RequestOptions: []*proto.ExecutePlanRequest_RequestOption{
{
RequestOption: &proto.ExecutePlanRequest_RequestOption_ReattachOptions{
ReattachOptions: &proto.ReattachOptions{
Reattachable: s.opts.ReattachExecution,
},
},
},
},
}
}
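
// ExecuteCommand executes a plan that must contain a command, drains the response
// stream, and returns the resulting Arrow table, its schema, and any properties
// collected along the way (such as the result of a SQL command).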
func (s *sparkConnectClientImpl) ExecuteCommand(ctx context.Context, plan *proto.Plan) (arrow.Table, *types.StructType, map[string]any, error) {
	// Check that the supplied plan is actually a command.
	if plan.GetCommand() == nil {
		return nil, nil, nil, sparkerrors.WithType(
			fmt.Errorf("the supplied plan does not contain a command"), sparkerrors.ExecutionError)
	}
	request := s.newExecutePlanRequest(plan)
	// Attach the session metadata to the outgoing gRPC context.
	ctx = metadata.NewOutgoingContext(ctx, s.metadata)
c, err := s.client.ExecutePlan(ctx, request)
if err != nil {
return nil, nil, nil, sparkerrors.WithType(
fmt.Errorf("failed to call ExecutePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError)
}
respHandler := NewExecuteResponseStream(c, s.sessionId, *request.OperationId, s.opts)
schema, table, err := respHandler.ToTable()
if err != nil {
return nil, nil, nil, err
}
return table, schema, respHandler.Properties(), nil
}
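
// ExecutePlan executes the given plan and returns the raw response stream so that
// callers can consume the results incrementally.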
func (s *sparkConnectClientImpl) ExecutePlan(ctx context.Context, plan *proto.Plan) (base.ExecuteResponseStream, error) {
request := s.newExecutePlanRequest(plan)
	// Attach the session metadata to the outgoing gRPC context.
ctx = metadata.NewOutgoingContext(ctx, s.metadata)
c, err := s.client.ExecutePlan(ctx, request)
if err != nil {
return nil, sparkerrors.WithType(fmt.Errorf(
"failed to call ExecutePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError)
}
return NewExecuteResponseStream(c, s.sessionId, *request.OperationId, s.opts), nil
}

// newAnalyzePlanStub creates a new AnalyzePlanRequest pre-filled with the
// session ID, user context, and client type.
func (s *sparkConnectClientImpl) newAnalyzePlanStub() proto.AnalyzePlanRequest {
return proto.AnalyzePlanRequest{
SessionId: s.sessionId,
UserContext: &proto.UserContext{
UserId: s.opts.UserId,
},
ClientType: &s.opts.UserAgent,
}
}
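
// AnalyzePlan requests the schema analysis of the given plan and returns the raw
// AnalyzePlanResponse.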
func (s *sparkConnectClientImpl) AnalyzePlan(ctx context.Context, plan *proto.Plan) (*proto.AnalyzePlanResponse, error) {
request := s.newAnalyzePlanStub()
request.Analyze = &proto.AnalyzePlanRequest_Schema_{
Schema: &proto.AnalyzePlanRequest_Schema{
Plan: plan,
},
}
	// Attach the session metadata to the outgoing gRPC context.
ctx = metadata.NewOutgoingContext(ctx, s.metadata)
response, err := s.client.AnalyzePlan(ctx, &request)
if se := sparkerrors.FromRPCError(err); se != nil {
return nil, sparkerrors.WithType(se, sparkerrors.ExecutionError)
}
return response, nil
}
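
// Explain returns the server-side explanation of the given plan in the requested
// explain mode.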
func (s *sparkConnectClientImpl) Explain(ctx context.Context, plan *proto.Plan,
explainMode utils.ExplainMode,
) (*proto.AnalyzePlanResponse, error) {
	var mode proto.AnalyzePlanRequest_Explain_ExplainMode
	switch explainMode {
	case utils.ExplainModeExtended:
		mode = proto.AnalyzePlanRequest_Explain_EXPLAIN_MODE_EXTENDED
	case utils.ExplainModeSimple:
		mode = proto.AnalyzePlanRequest_Explain_EXPLAIN_MODE_SIMPLE
	case utils.ExplainModeCost:
		mode = proto.AnalyzePlanRequest_Explain_EXPLAIN_MODE_COST
	case utils.ExplainModeFormatted:
		mode = proto.AnalyzePlanRequest_Explain_EXPLAIN_MODE_FORMATTED
	case utils.ExplainModeCodegen:
		mode = proto.AnalyzePlanRequest_Explain_EXPLAIN_MODE_CODEGEN
	default:
		return nil, sparkerrors.WithType(fmt.Errorf("unsupported explain mode %v",
			explainMode), sparkerrors.InvalidArgumentError)
	}
request := s.newAnalyzePlanStub()
request.Analyze = &proto.AnalyzePlanRequest_Explain_{
Explain: &proto.AnalyzePlanRequest_Explain{
Plan: plan,
ExplainMode: mode,
},
}
	// Attach the session metadata to the outgoing gRPC context.
ctx = metadata.NewOutgoingContext(ctx, s.metadata)
response, err := s.client.AnalyzePlan(ctx, &request)
if se := sparkerrors.FromRPCError(err); se != nil {
return nil, sparkerrors.WithType(se, sparkerrors.ExecutionError)
}
return response, nil
}
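
// Persist caches the relation of the given plan using the supplied storage level.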
func (s *sparkConnectClientImpl) Persist(ctx context.Context, plan *proto.Plan, storageLevel utils.StorageLevel) error {
protoLevel := utils.ToProtoStorageLevel(storageLevel)
request := s.newAnalyzePlanStub()
request.Analyze = &proto.AnalyzePlanRequest_Persist_{
Persist: &proto.AnalyzePlanRequest_Persist{
Relation: plan.GetRoot(),
StorageLevel: protoLevel,
},
}
	// Attach the session metadata to the outgoing gRPC context.
ctx = metadata.NewOutgoingContext(ctx, s.metadata)
_, err := s.client.AnalyzePlan(ctx, &request)
if se := sparkerrors.FromRPCError(err); se != nil {
return sparkerrors.WithType(se, sparkerrors.ExecutionError)
}
return nil
}
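
// Unpersist removes the relation of the given plan from the cache.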
func (s *sparkConnectClientImpl) Unpersist(ctx context.Context, plan *proto.Plan) error {
request := s.newAnalyzePlanStub()
request.Analyze = &proto.AnalyzePlanRequest_Unpersist_{
Unpersist: &proto.AnalyzePlanRequest_Unpersist{
Relation: plan.GetRoot(),
},
}
	// Attach the session metadata to the outgoing gRPC context.
ctx = metadata.NewOutgoingContext(ctx, s.metadata)
_, err := s.client.AnalyzePlan(ctx, &request)
if se := sparkerrors.FromRPCError(err); se != nil {
return sparkerrors.WithType(se, sparkerrors.ExecutionError)
}
return nil
}
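
// GetStorageLevel returns the storage level with which the relation of the given
// plan is currently persisted.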
func (s *sparkConnectClientImpl) GetStorageLevel(ctx context.Context, plan *proto.Plan) (*utils.StorageLevel, error) {
request := s.newAnalyzePlanStub()
request.Analyze = &proto.AnalyzePlanRequest_GetStorageLevel_{
GetStorageLevel: &proto.AnalyzePlanRequest_GetStorageLevel{
Relation: plan.GetRoot(),
},
}
	// Attach the session metadata to the outgoing gRPC context.
ctx = metadata.NewOutgoingContext(ctx, s.metadata)
response, err := s.client.AnalyzePlan(ctx, &request)
if se := sparkerrors.FromRPCError(err); se != nil {
return nil, sparkerrors.WithType(se, sparkerrors.ExecutionError)
}
level := response.GetGetStorageLevel().StorageLevel
res := utils.FromProtoStorageLevel(level)
return &res, nil
}
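
// SparkVersion returns the version string of the connected Spark instance.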
func (s *sparkConnectClientImpl) SparkVersion(ctx context.Context) (string, error) {
request := s.newAnalyzePlanStub()
request.Analyze = &proto.AnalyzePlanRequest_SparkVersion_{
SparkVersion: &proto.AnalyzePlanRequest_SparkVersion{},
}
	// Attach the session metadata to the outgoing gRPC context.
ctx = metadata.NewOutgoingContext(ctx, s.metadata)
response, err := s.client.AnalyzePlan(ctx, &request)
if se := sparkerrors.FromRPCError(err); se != nil {
return "", sparkerrors.WithType(se, sparkerrors.ExecutionError)
}
return response.GetSparkVersion().Version, nil
}
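
// DDLParse parses a DDL-formatted schema string into a StructType.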
func (s *sparkConnectClientImpl) DDLParse(ctx context.Context, sql string) (*types.StructType, error) {
request := s.newAnalyzePlanStub()
request.Analyze = &proto.AnalyzePlanRequest_DdlParse{
DdlParse: &proto.AnalyzePlanRequest_DDLParse{
DdlString: sql,
},
}
	// Attach the session metadata to the outgoing gRPC context.
ctx = metadata.NewOutgoingContext(ctx, s.metadata)
response, err := s.client.AnalyzePlan(ctx, &request)
if se := sparkerrors.FromRPCError(err); se != nil {
return nil, sparkerrors.WithType(se, sparkerrors.ExecutionError)
}
return types.ConvertProtoDataTypeToStructType(response.GetDdlParse().Parsed)
}
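
// SameSemantics reports whether the two given plans are semantically equivalent.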
func (s *sparkConnectClientImpl) SameSemantics(ctx context.Context, plan1 *proto.Plan, plan2 *proto.Plan) (bool, error) {
request := s.newAnalyzePlanStub()
request.Analyze = &proto.AnalyzePlanRequest_SameSemantics_{
SameSemantics: &proto.AnalyzePlanRequest_SameSemantics{
TargetPlan: plan1,
OtherPlan: plan2,
},
}
	// Attach the session metadata to the outgoing gRPC context.
ctx = metadata.NewOutgoingContext(ctx, s.metadata)
response, err := s.client.AnalyzePlan(ctx, &request)
if se := sparkerrors.FromRPCError(err); se != nil {
return false, sparkerrors.WithType(se, sparkerrors.ExecutionError)
}
return response.GetSameSemantics().GetResult(), nil
}
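
// SemanticHash returns a hash of the semantics of the given plan; semantically
// equivalent plans yield the same hash.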
func (s *sparkConnectClientImpl) SemanticHash(ctx context.Context, plan *proto.Plan) (int32, error) {
request := s.newAnalyzePlanStub()
request.Analyze = &proto.AnalyzePlanRequest_SemanticHash_{
SemanticHash: &proto.AnalyzePlanRequest_SemanticHash{
Plan: plan,
},
}
	// Attach the session metadata to the outgoing gRPC context.
ctx = metadata.NewOutgoingContext(ctx, s.metadata)
response, err := s.client.AnalyzePlan(ctx, &request)
if se := sparkerrors.FromRPCError(err); se != nil {
return 0, sparkerrors.WithType(se, sparkerrors.ExecutionError)
}
return response.GetSemanticHash().GetResult(), nil
}
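
// Config executes a single configuration operation (for example get or set) and
// returns the raw ConfigResponse.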
func (s *sparkConnectClientImpl) Config(ctx context.Context,
operation *proto.ConfigRequest_Operation,
) (*generated.ConfigResponse, error) {
	request := &proto.ConfigRequest{
		SessionId: s.sessionId,
		Operation: operation,
		UserContext: &proto.UserContext{
			UserId: s.opts.UserId,
		},
		ClientType: &s.opts.UserAgent,
	}
	return s.client.Config(ctx, request)
}
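
// NewSparkExecutor creates a SparkConnectClient on top of the given gRPC connection.
// When reattachable execution is enabled in the options, the RPC client is wrapped
// in a retriable client that can resume interrupted streams.
//
// A rough usage sketch (the address, credentials, and session ID below are
// illustrative assumptions, not part of this package; insecure refers to
// google.golang.org/grpc/credentials/insecure):
//
//	conn, _ := grpc.Dial("localhost:15002",
//		grpc.WithTransportCredentials(insecure.NewCredentials()))
//	executor := NewSparkExecutor(conn, metadata.MD{}, uuid.NewString(),
//		options.DefaultSparkClientOptions)
//	version, _ := executor.SparkVersion(context.Background())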
func NewSparkExecutor(conn *grpc.ClientConn, md metadata.MD, sessionId string, opts options.SparkClientOptions) base.SparkConnectClient {
var client base.SparkConnectRPCClient
if opts.ReattachExecution {
client = NewRetriableSparkConnectClient(conn, sessionId, opts)
} else {
client = generated.NewSparkConnectServiceClient(conn)
}
return &sparkConnectClientImpl{
client: client,
metadata: md,
sessionId: sessionId,
opts: opts,
}
}

// NewSparkExecutorFromClient creates a new SparkConnectClient from an existing client and is mostly
// used in testing.
func NewSparkExecutorFromClient(client base.SparkConnectRPCClient, md metadata.MD, sessionId string) base.SparkConnectClient {
return &sparkConnectClientImpl{
client: client,
metadata: md,
sessionId: sessionId,
opts: options.DefaultSparkClientOptions,
}
}

// ExecutePlanClient is the wrapper around the result of the execution of a query plan using
// Spark Connect.
type ExecutePlanClient struct {
	// The gRPC stream to read the response messages.
responseStream generated.SparkConnectService_ExecutePlanClient
// The schema of the result of the operation.
schema *types.StructType
	// The sessionId is used to verify the server side session.
sessionId string
done bool
properties map[string]any
opts options.SparkClientOptions
}
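
// Properties returns the auxiliary values collected while consuming the response
// stream, e.g. the relation produced by a SQL command result.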
func (c *ExecutePlanClient) Properties() map[string]any {
return c.properties
}

// ToTable converts the result of the execution of a query plan to an Arrow Table.
func (c *ExecutePlanClient) ToTable() (*types.StructType, arrow.Table, error) {
	recordBatches := make([]arrow.Record, 0)
	var arrowSchema *arrow.Schema
	// Resetting the flag is explicitly needed when tracking re-attachable execution.
c.done = false
for {
resp, err := c.responseStream.Recv()
// EOF is received when the last message has been processed and the stream
// finished normally.
if errors.Is(err, io.EOF) {
break
}
		// Any error other than EOF indicates a real failure.
if se := sparkerrors.FromRPCError(err); se != nil {
return nil, nil, sparkerrors.WithType(se, sparkerrors.ExecutionError)
}
		// Process the message. Check that the server returned the session ID that
		// we were expecting and that it has not changed.
if resp.GetSessionId() != c.sessionId {
return c.schema, nil, sparkerrors.WithType(&sparkerrors.InvalidServerSideSessionDetailsError{
OwnSessionId: c.sessionId,
ReceivedSessionId: resp.GetSessionId(),
}, sparkerrors.InvalidServerSideSessionError)
}
// Check if the response has already the schema set and if yes, convert
// the proto DataType to a StructType.
if resp.Schema != nil {
c.schema, err = types.ConvertProtoDataTypeToStructType(resp.Schema)
if err != nil {
return nil, nil, sparkerrors.WithType(err, sparkerrors.ExecutionError)
}
}
switch x := resp.ResponseType.(type) {
case *proto.ExecutePlanResponse_SqlCommandResult_:
if val := x.SqlCommandResult.GetRelation(); val != nil {
c.properties["sql_command_result"] = val
}
case *proto.ExecutePlanResponse_ArrowBatch_:
			// Decode the Arrow IPC payload into a record batch using the current schema.
record, err := types.ReadArrowBatchToRecord(x.ArrowBatch.Data, c.schema)
if err != nil {
return nil, nil, err
}
arrowSchema = record.Schema()
record.Retain()
recordBatches = append(recordBatches, record)
case *proto.ExecutePlanResponse_ResultComplete_:
c.done = true
default:
// Explicitly ignore messages that we cannot process at the moment.
}
}
	// Check that the result is logically complete. The result might not be complete
	// because the server interrupts the connection after two minutes, in which case
	// we have to send a ReattachExecute request.
if c.opts.ReattachExecution && !c.done {
return nil, nil, sparkerrors.WithType(fmt.Errorf("the result is not complete"), sparkerrors.ExecutionError)
}
	// Return the schema and table.
	if arrowSchema == nil {
		return c.schema, nil, nil
	}
	return c.schema, array.NewTableFromRecords(arrowSchema, recordBatches), nil
}
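
// NewExecuteResponseStream wraps the given ExecutePlan response stream in an
// ExecutePlanClient. Note that the operation ID identifies the execution on the
// server but is not stored on the returned handler.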
func NewExecuteResponseStream(
responseClient proto.SparkConnectService_ExecutePlanClient,
sessionId string,
operationId string,
opts options.SparkClientOptions,
) base.ExecuteResponseStream {
return &ExecutePlanClient{
responseStream: responseClient,
sessionId: sessionId,
done: false,
properties: make(map[string]any),
opts: opts,
}
}
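
// NewTestConnectClientFromResponses creates a SparkConnectClient backed by a mock
// proto client that replays the given responses; it is intended for tests.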
func NewTestConnectClientFromResponses(sessionId string, r ...*mocks.MockResponse) base.SparkConnectClient {
protoClient := mocks.NewProtoClientMock(r...)
stream := NewExecuteResponseStream(protoClient, sessionId, uuid.NewString(), options.DefaultSparkClientOptions)
return &mocks.TestExecutor{
Client: stream,
}
}
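
// NewTestConnectClientWithImmediateError creates a SparkConnectClient whose
// executor fails immediately with the given error; it is intended for tests.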
func NewTestConnectClientWithImmediateError(sessionId string, err error) base.SparkConnectClient {
stream := NewExecuteResponseStream(nil, sessionId, uuid.NewString(), options.DefaultSparkClientOptions)
return &mocks.TestExecutor{
Client: stream,
Err: err,
}
}