spark/client/retry.go
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"context"
"errors"
"io"
"math/rand"
"strings"
"time"
"github.com/apache/spark-connect-go/v35/spark/client/base"
"github.com/apache/spark-connect-go/v35/spark/client/options"
"google.golang.org/grpc/metadata"
proto "github.com/apache/spark-connect-go/v35/internal/generated"
"github.com/apache/spark-connect-go/v35/spark/sparkerrors"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
)
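// RetryHandler is a predicate that reports whether a given error is retriable.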
type RetryHandler func(error) bool
// RetryPolicy defines the parameters for a retry policy. A policy is used to determine whether
// an error is retriable and how retries are handled: it defines how the client backs off in
// case of an error and how the retries are spread out over time.
type RetryPolicy struct {
// MaxRetries is the maximum number of retries for a single RPC.
MaxRetries int32
// InitialBackoff is the time to wait before the first retry.
InitialBackoff time.Duration
// MaxBackoff is the upper bound for the wait time between retries.
MaxBackoff time.Duration
// BackoffMultiplier is the factor by which the wait time grows after each attempt.
BackoffMultiplier float32
// Jitter is the maximum random delay added to a wait time.
Jitter time.Duration
// MinJitterThreshold is the wait time above which jitter is applied.
MinJitterThreshold time.Duration
// Name is a human-readable identifier for the policy.
Name string
// Handler decides whether a given error is retriable under this policy.
Handler RetryHandler
}
// DefaultRetryPolicy is the default retry policy used by the client. It retries on Unavailable
// errors and when the cursor has been disconnected (INVALID_CURSOR.DISCONNECTED). All other
// errors are considered not retriable.
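//
// As an illustration, the handler classifies plain gRPC status errors roughly as follows
// (a sketch, assuming sparkerrors.FromRPCError extracts the gRPC status code and message
// from the error; status refers to google.golang.org/grpc/status):
//
//	DefaultRetryPolicy.Handler(status.Error(codes.Unavailable, "server restarting"))  // true
//	DefaultRetryPolicy.Handler(status.Error(codes.InvalidArgument, "malformed plan")) // false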
var DefaultRetryPolicy = RetryPolicy{
MaxRetries: 15,
InitialBackoff: 50 * time.Millisecond,
MaxBackoff: 1 * time.Minute,
BackoffMultiplier: 4.0,
Jitter: 500 * time.Millisecond,
MinJitterThreshold: 2000 * time.Millisecond,
Name: "DefaultRetryPolicy",
Handler: func(e error) bool {
status := sparkerrors.FromRPCError(e)
switch status.Code {
case codes.Unavailable:
return true
case codes.Internal:
if strings.Contains(status.Message, "INVALID_CURSOR.DISCONNECTED") {
return true
}
}
return false
},
}
var TestingRetryPolicy = RetryPolicy{
MaxRetries: 5,
InitialBackoff: 0,
MaxBackoff: 1,
BackoffMultiplier: 2,
Jitter: 0,
MinJitterThreshold: 0,
Name: "TestingRetryPolicy",
Handler: func(e error) bool {
status := sparkerrors.FromRPCError(e)
switch status.Code {
case codes.Unavailable:
return true
case codes.Internal:
if strings.Contains(status.Message, "INVALID_CURSOR.DISCONNECTED") {
return true
}
}
return false
},
}
// DefaultRetryPolicyRegistry is the default set of retry policies used by the client. It contains
// all those policies that are enabled by default.
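//
// For illustration, a caller that wants an additional, more aggressive policy could build its
// own registry on top of the default one (a sketch; the extra policy below is hypothetical and
// not part of this package):
//
//	fastPolicy := RetryPolicy{
//		MaxRetries:        3,
//		InitialBackoff:    10 * time.Millisecond,
//		MaxBackoff:        time.Second,
//		BackoffMultiplier: 2.0,
//		Name:              "FastPolicy",
//		Handler: func(e error) bool {
//			return sparkerrors.FromRPCError(e).Code == codes.ResourceExhausted
//		},
//	}
//	registry := append([]RetryPolicy{fastPolicy}, DefaultRetryPolicyRegistry...)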
var DefaultRetryPolicyRegistry = []RetryPolicy{DefaultRetryPolicy}
// retryState is the current state of the retries for one particular RPC request. The retry
// state is independent of the retry policy.
type retryState struct {
retryCount int32
nextWait time.Duration
}
// nextAttempt calculates the next wait time for the next retry attempt. The function returns
// nil if the maximum number of retries has been exceeded, otherwise it returns the amount
// of time the caller should wait.
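//
// With DefaultRetryPolicy the successive base waits are 50ms, 200ms, 800ms, 3.2s, 12.8s, 51.2s
// and then 1min for every remaining attempt; random jitter of up to 500ms is added once the
// base wait exceeds the 2s threshold. A minimal sketch of how callers drive the state:
//
//	state := retryState{}
//	for {
//		wait := state.nextAttempt(DefaultRetryPolicy)
//		if wait == nil {
//			break // retry budget exhausted
//		}
//		time.Sleep(*wait)
//	}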
func (rs *retryState) nextAttempt(p RetryPolicy) *time.Duration {
if rs.retryCount >= p.MaxRetries {
return nil
}
// For the first retry pick the initial backoff of the matching policy.
if rs.retryCount == 0 {
rs.nextWait = p.InitialBackoff
}
// Adjust the retry count and calculate the next wait.
rs.retryCount++
wait := rs.nextWait
rs.nextWait = time.Duration(float32(rs.nextWait.Milliseconds())*p.BackoffMultiplier) * time.Millisecond
if rs.nextWait > p.MaxBackoff {
rs.nextWait = p.MaxBackoff
}
// Some policies define that jitter should only be applied after a particular threshold.
if wait > p.MinJitterThreshold {
wait += time.Duration(rand.Float32()*float32(p.Jitter.Milliseconds())) * time.Millisecond
}
return &wait
}
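// NewRetriableSparkConnectClient wraps the generated SparkConnectServiceClient for the given
// connection in a client that retries transparently according to DefaultRetryPolicyRegistry.
// A rough usage sketch (endpoint, session id and options value are placeholders supplied by
// the caller; insecure refers to google.golang.org/grpc/credentials/insecure):
//
//	conn, err := grpc.Dial("localhost:15002", grpc.WithTransportCredentials(insecure.NewCredentials()))
//	if err != nil {
//		// handle the dial error
//	}
//	client := NewRetriableSparkConnectClient(conn, sessionId, opts)
//	resp, err := client.Config(ctx, configRequest)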
func NewRetriableSparkConnectClient(conn *grpc.ClientConn, sessionId string,
opts options.SparkClientOptions,
) base.SparkConnectRPCClient {
innerClient := proto.NewSparkConnectServiceClient(conn)
return &retriableSparkConnectClient{
client: innerClient,
sessionId: sessionId,
retryPolicies: DefaultRetryPolicyRegistry,
options: opts,
}
}
// wrapRetriableCall wraps a call to a function that returns a result and an error. The call is
// retried according to the given retry policies. It returns the result on success, or an error
// once the retries are exceeded.
func wrapRetriableCall[Res rpcType](ctx context.Context, retryPolicies []RetryPolicy, in func(context.Context) (Res, error)) (Res, error) {
var lastErr error
var response Res
// Create the retry state for this wrapped call. The retry state captures the information about
// the wait time and how many retries to perform.
state := retryState{}
// As long as the error is retriable, we will retry the operation.
canRetry := true
for canRetry {
// Every loop iteration starts with being non-retriable.
canRetry = false
response, lastErr = in(ctx)
if lastErr != nil {
for _, h := range retryPolicies {
if h.Handler(lastErr) {
canRetry = true
wait := state.nextAttempt(h)
if wait != nil {
time.Sleep(*wait)
} else {
// If the retries are exceeded, simply return from here.
return nil, sparkerrors.WithType(lastErr, sparkerrors.RetriesExceeded)
}
// Breaks out of the retry handler loop.
break
}
}
} else {
// Exit loop if no error has been received.
return response, nil
}
}
// TODO: Should this simply return the original error?
return nil, sparkerrors.WithType(lastErr, sparkerrors.RetriesExceeded)
}
type rpcType interface {
*proto.AnalyzePlanResponse | *proto.ConfigResponse | *proto.ArtifactStatusesResponse |
*proto.InterruptResponse | *proto.ReleaseExecuteResponse | *proto.ExecutePlanResponse
}
// retriableSparkConnectClient wraps the SparkConnectServiceClient implementation to
// transparently handle retries.
type retriableSparkConnectClient struct {
client base.SparkConnectRPCClient
sessionId string
// Not yet used.
// serverSideSessionId string
retryPolicies []RetryPolicy
options options.SparkClientOptions
}
func (r *retriableSparkConnectClient) ExecutePlan(ctx context.Context, in *proto.ExecutePlanRequest,
opts ...grpc.CallOption,
) (proto.SparkConnectService_ExecutePlanClient, error) {
var lastErr error
var response proto.SparkConnectService_ExecutePlanClient
// Create the retry state for this wrapped call. The retry state captures the information about
// the wait time and how many retries to perform.
state := retryState{}
// As long as the error is retriable, we will retry the operation.
canRetry := true
for canRetry {
// Every loop iteration starts with being non-retriable.
canRetry = false
// Plain assignment (not :=) so the outer lastErr is not shadowed and the final error
// wrapping below reports the actual error.
response, lastErr = r.client.ExecutePlan(ctx, in, opts...)
if lastErr != nil {
for _, h := range r.retryPolicies {
if h.Handler(lastErr) {
canRetry = true
wait := state.nextAttempt(h)
if wait != nil {
time.Sleep(*wait)
} else {
// If the retries are exceeded, simply return from here.
return nil, sparkerrors.WithType(lastErr, sparkerrors.RetriesExceeded)
}
// Breaks out of the retry handler loop.
break
}
}
} else {
// Exit loop if no error has been received.
rc := retriableExecutePlanClient{
context: ctx,
retryContext: &retryContext{
stream: response,
client: r,
request: in,
resultComplete: false,
retryPolicies: r.retryPolicies,
},
}
return rc, nil
}
}
return nil, sparkerrors.WithType(lastErr, sparkerrors.RetriesExceeded)
}
func (r *retriableSparkConnectClient) AnalyzePlan(ctx context.Context, in *proto.AnalyzePlanRequest,
opts ...grpc.CallOption,
) (*proto.AnalyzePlanResponse, error) {
return wrapRetriableCall(ctx, r.retryPolicies, func(ctx2 context.Context) (*proto.AnalyzePlanResponse, error) {
return r.client.AnalyzePlan(ctx2, in, opts...)
})
}
func (r *retriableSparkConnectClient) Config(ctx context.Context, in *proto.ConfigRequest, opts ...grpc.CallOption) (*proto.ConfigResponse, error) {
return wrapRetriableCall(ctx, r.retryPolicies, func(ctx2 context.Context) (*proto.ConfigResponse, error) {
return r.client.Config(ctx2, in, opts...)
})
}
func (r *retriableSparkConnectClient) AddArtifacts(ctx context.Context, opts ...grpc.CallOption) (proto.SparkConnectService_AddArtifactsClient, error) {
var lastErr error
var response proto.SparkConnectService_AddArtifactsClient
// Create the retry state for this wrapped call. The retry state captures the information about
// the wait time and how many retries to perform.
state := retryState{}
// As long as the error is retriable, we will retry the operation.
canRetry := true
for canRetry {
// Every loop iteration starts with being non-retriable.
canRetry = false
// Plain assignment (not :=) so the outer lastErr is not shadowed.
response, lastErr = r.client.AddArtifacts(ctx, opts...)
if lastErr != nil {
for _, h := range r.retryPolicies {
if h.Handler(lastErr) {
canRetry = true
wait := state.nextAttempt(h)
if wait != nil {
time.Sleep(*wait)
} else {
// If the retries are exceeded, simply return from here.
return nil, sparkerrors.WithType(lastErr, sparkerrors.RetriesExceeded)
}
// Breaks out of the retry handler loop.
break
}
}
} else {
// Exit loop if no error has been received.
return response, nil
}
}
return nil, sparkerrors.WithType(lastErr, sparkerrors.RetriesExceeded)
}
func (r *retriableSparkConnectClient) ArtifactStatus(ctx context.Context,
in *proto.ArtifactStatusesRequest, opts ...grpc.CallOption,
) (*proto.ArtifactStatusesResponse, error) {
return wrapRetriableCall(ctx, r.retryPolicies, func(ctx2 context.Context) (
*proto.ArtifactStatusesResponse, error,
) {
return r.client.ArtifactStatus(ctx2, in, opts...)
})
}
func (r *retriableSparkConnectClient) Interrupt(ctx context.Context, in *proto.InterruptRequest,
opts ...grpc.CallOption,
) (*proto.InterruptResponse, error) {
return wrapRetriableCall(ctx, r.retryPolicies, func(ctx2 context.Context) (*proto.InterruptResponse, error) {
return r.client.Interrupt(ctx2, in, opts...)
})
}
func (r *retriableSparkConnectClient) ReattachExecute(ctx context.Context,
in *proto.ReattachExecuteRequest, opts ...grpc.CallOption,
) (proto.SparkConnectService_ReattachExecuteClient, error) {
var lastErr error
var response proto.SparkConnectService_ReattachExecuteClient
// Create the retry state for this wrapped call. The retry state captures the information about
// the wait time and how many retries to perform.
state := retryState{}
// As long as the error is retriable, we will retry the operation.
canRetry := true
for canRetry {
// Every loop iteration starts with being non-retriable.
canRetry = false
// Plain assignment (not :=) so the outer lastErr is not shadowed.
response, lastErr = r.client.ReattachExecute(ctx, in, opts...)
if lastErr != nil {
for _, h := range r.retryPolicies {
if h.Handler(lastErr) {
canRetry = true
wait := state.nextAttempt(h)
if wait != nil {
time.Sleep(*wait)
} else {
// If the retries are exceeded, simply return from here.
return nil, sparkerrors.WithType(lastErr, sparkerrors.RetriesExceeded)
}
// Breaks out of the retry handler loop.
break
}
}
} else {
// Exit loop if no error has been received.
// TODO: Re-attaching needs to be retriable as well.
return response, nil
}
}
return nil, sparkerrors.WithType(lastErr, sparkerrors.RetriesExceeded)
}
func (r *retriableSparkConnectClient) ReleaseExecute(ctx context.Context,
in *proto.ReleaseExecuteRequest, opts ...grpc.CallOption,
) (*proto.ReleaseExecuteResponse, error) {
return wrapRetriableCall(ctx, r.retryPolicies, func(ctx2 context.Context) (*proto.ReleaseExecuteResponse, error) {
return r.client.ReleaseExecute(ctx2, in, opts...)
})
}
type retryContext struct {
stream proto.SparkConnectService_ExecutePlanClient
client base.SparkConnectRPCClient
request *proto.ExecutePlanRequest
lastResponseId *string
resultComplete bool
retryPolicies []RetryPolicy
}
// retriableExecutePlanClient is a wrapper around the ExecutePlanClient that handles retries
// transparently. Since this type has to satisfy the ExecutePlanClient interface, we have to
// implement all methods of the interface and follow their method receiver pattern. As the main
// methods do not use a pointer receiver, we wrap the mutable part of the retry behavior in a
// separate struct.
//
// In addition, we capture the original Context of the caller that is passed to the interface.
// While storing a context in a struct is typically not a desired pattern, it is the only way to
// make sure the same context is used across the retrying wrapper and the underlying stream.
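//
// From the caller's perspective the wrapper behaves like any other ExecutePlan stream and is
// drained until io.EOF (a sketch):
//
//	stream, err := client.ExecutePlan(ctx, req)
//	if err != nil {
//		// handle the error
//	}
//	for {
//		resp, err := stream.Recv()
//		if errors.Is(err, io.EOF) {
//			break
//		}
//		if err != nil {
//			// handle the error
//		}
//		_ = resp // consume the response here
//	}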
type retriableExecutePlanClient struct {
retryContext *retryContext
context context.Context
}
func (r retriableExecutePlanClient) Recv() (*proto.ExecutePlanResponse, error) {
return wrapRetriableCall(r.context, r.retryContext.retryPolicies, func(ctx2 context.Context) (*proto.ExecutePlanResponse, error) {
resp, err := r.retryContext.stream.Recv()
// Success, simply return the result.
if err == nil {
r.retryContext.lastResponseId = &resp.ResponseId
return resp, nil
}
// Ignore successful closure.
if errors.Is(err, io.EOF) {
return nil, err
}
// At this point we assume the request has failed and distinguish two cases: First, if we have
// never received a response, we simply execute the same request again. Second, if we have
// already received responses, we send a reattach request with the same operation ID and the
// last response ID.
if r.retryContext.lastResponseId == nil {
// Send the request again.
rs, execErr := r.retryContext.client.ExecutePlan(ctx2, r.retryContext.request)
if execErr != nil {
return nil, execErr
}
switch stream := rs.(type) {
case retriableExecutePlanClient:
r.retryContext.stream = stream.retryContext.stream
default:
r.retryContext.stream = stream
}
return nil, err
} else {
// Send a reattach
req := &proto.ReattachExecuteRequest{
SessionId: r.retryContext.request.SessionId,
UserContext: r.retryContext.request.UserContext,
OperationId: *r.retryContext.request.OperationId,
LastResponseId: r.retryContext.lastResponseId,
}
re, execErr := r.retryContext.client.ReattachExecute(ctx2, req)
if execErr != nil {
return nil, execErr
}
switch stream := re.(type) {
case retriableExecutePlanClient:
r.retryContext.stream = stream.retryContext.stream
default:
r.retryContext.stream = stream
}
return nil, err
}
})
}
func (r retriableExecutePlanClient) Header() (metadata.MD, error) {
return r.retryContext.stream.Header()
}
func (r retriableExecutePlanClient) Trailer() metadata.MD {
return r.retryContext.stream.Trailer()
}
func (r retriableExecutePlanClient) CloseSend() error {
return r.retryContext.stream.CloseSend()
}
func (r retriableExecutePlanClient) Context() context.Context {
return r.retryContext.stream.Context()
}
func (r retriableExecutePlanClient) SendMsg(m any) error {
return r.retryContext.stream.SendMsg(m)
}
func (r retriableExecutePlanClient) RecvMsg(m any) error {
return r.retryContext.stream.RecvMsg(m)
}