exponential/helpers/grpc/grpc.go (96 lines of code) (raw):

/* Package gRPC provides an exponential.ErrTransformer that can be used to detect non-retriable errors for gRPC calls. There is no direct support for gRPC streaming in this package. Example using just defaults: // This will retry any grpc error codes that are considered retriable. grpcErrTransform, _ := grpc.New() // Uses defaults backoff := exponential.WithErrTransformer(grpcErrTransform) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) req := &pb.HelloRequest{Name: "John"} var resp *pb.HelloReply{} err := backoff.Retry( ctx, func(ctx context.Context, r Record) error { var err error resp, err = client.SayHello(ctx, req) return err }, ) cancel() Example setting an extra code for retries: // The same as above, except we will retry on codes.DataLoss. grpcErrTransform, err := grpc.New(WithExtraCodes(codes.DataLoss)) if err != nil { // Handle error } ... // The rest is the same Example with custom message inspection: // We are going to provide a function that can inspect a proto.Message when // the client did not send an error, but there was an error sent back from the server // in the response. respHasErr := func (msg proto.Message) error { r := msg.(*pb.HelloReply) if r.Error != "" { if r.PermanentErr { // This will stop retries. return fmt.Errorf("%s: %w", r.Error, errors.ErrPermanent) } // We can still retry. 
return fmt.Errorf("%s", r.Error) } return nil } grpcErrTransform, err := grpc.New(WithProtoToErr(respHasErr)) if err != nil { // Handle error } backoff := exponential.WithErrTransformer(grpcErrTransform) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) req := &pb.HelloRequest{Name: "John"} var resp *pb.HelloReply{} err := backoff.Retry( ctx, func(ctx context.Context, r Record) error { a, err := grpcErrTransform.RespToErr(client.SayHello(ctx, req)) // <- Notice the call wrapper if err != nil { return err } resp = a.(*pb.HelloReply) return nil }, ) cancel() */ package grpc import ( "fmt" "reflect" "github.com/Azure/retry/internal/errors" "google.golang.org/protobuf/proto" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) /* Transformer provides an ErrTransformer method that can be used to detect non-retriable errors. The following codes are retriable: Canceled, DeadlineExceeded, Unknown, Internal, Unavailable, ResourceExhausted. Any other code is not. */ type Transformer struct { extras map[codes.Code]bool protosToErrs []ProtoToErr } // Option is an option for the New() constructor. type Option func(t *Transformer) error // WithExtraCodes defines extra grpc status codes that are considered retriable. func WithExtraCodes(extras ...codes.Code) Option { return func(t *Transformer) error { for _, code := range extras { t.extras[code] = true } return nil } } // ProtoToErr inspects a protocol buffer message and determines if the call was really an error. // If it was not, this returns nil. type ProtoToErr func(msg proto.Message) error // WithProtoToErrs pass functions that look at protocol buffer message responses to determine if // the message actually indicates an error. func WithProtoToErrs(protosToErrs ...ProtoToErr) Option { return func(t *Transformer) error { t.protosToErrs = protosToErrs return nil } } // New returns a new Transformer. This implements exponential.ErrTransformer with the method ErrTransformer. 
// You can add other codes that are retriable by passing them as arguments. This list of retriable codes // are listed on Transformer. func New(options ...Option) (*Transformer, error) { t := &Transformer{ extras: map[codes.Code]bool{}, } for _, o := range options { if err := o(t); err != nil { return nil, err } } return t, nil } // ErrTransformer returns a transformer that can be used to detect non-retriable errors. // If it is non-retriable it will wrap the error with errors.ErrPermanent. func (t *Transformer) ErrTransformer(err error) error { is, code := t.isGRPCErr(err) if !is { return err } if t.isGRPCPermanent(code) { return fmt.Errorf("%w: %w", err, errors.ErrPermanent) } return err } // isGRPCErr returns true if the error is a gRPC error and the gRPC code. func (t *Transformer) isGRPCErr(err error) (bool, codes.Code) { // The gRPC status package is actually a wrapper around an internal status package. While Status is exposed // through this package, the Error type is not. So there is no great way to know if // we have a grpc Error type. That is unless we want to use the compiler linkname directive to get // at the internal status package. So instead we look to see if codes.Unknown is returned, which // is what happens when we have a non-gRPC error given to code. But since a person can set that too, // we look to see if the error has a GRPCStatus method. If it does, then it is a gRPC error. // The tests should protect us in case they change the internal Error type to remove GRPCStatus. code := status.Code(err) switch code { case codes.Unknown: // We look to see if the error has a GRPCStatus method. If it does, then it is a gRPC error. // This is not the greatest, but it is the best we can do without using the compiler directive. if _, ok := reflect.TypeOf(err).MethodByName("GRPCStatus"); ok { return true, code } return false, code case codes.OK: return false, code } return true, code } // grpcRetriable is a list of grpc status codes that are retriable. 
// Extra codes registered through WithExtraCodes are tracked per Transformer and checked separately.
var grpcRetriable = map[codes.Code]bool{
	codes.Canceled:          true,
	codes.DeadlineExceeded:  true,
	codes.Unknown:           true,
	codes.Internal:          true,
	codes.Unavailable:       true,
	codes.ResourceExhausted: true,
}

// isGRPCPermanent returns true if code is neither one of the default retriable codes nor one of
// the extra codes registered with WithExtraCodes.
func (t *Transformer) isGRPCPermanent(code codes.Code) bool {
	return !grpcRetriable[code] && !t.extras[code]
}

// RespToErr takes the proto.Message and error returned by a protocol buffer client call method
// and returns the response and an error.
//
// If err != nil, or no ProtoToErr functions were registered with WithProtoToErrs, the inputs are
// returned as-is. Otherwise each ProtoToErr is run against r in the order it was registered:
//   - the first error wrapping errors.ErrPermanent is returned immediately;
//   - otherwise the first non-nil (retriable) error is returned;
//   - if every function returns nil, the returned error is nil.
//
// The response r is always returned unchanged.
func (t *Transformer) RespToErr(r proto.Message, err error) (proto.Message, error) {
	if err != nil || len(t.protosToErrs) == 0 {
		return r, err
	}

	// firstErr remembers the first retriable error so that a later ProtoToErr returning nil
	// cannot erase it; we keep scanning only to find a permanent error, which takes precedence.
	var firstErr error
	for _, respToErr := range t.protosToErrs {
		e := respToErr(r)
		if e == nil {
			continue
		}
		if errors.Is(e, errors.ErrPermanent) {
			return r, e
		}
		if firstErr == nil {
			firstErr = e
		}
	}
	return r, firstErr
}