v2/lockrenewer.go

package shuttle

import (
	"context"
	"errors"
	"fmt"
	"sync/atomic"
	"time"

	"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/trace"

	"github.com/Azure/go-shuttle/v2/metrics/processor"
)

// LockRenewer abstracts the Service Bus receiver client to expose only lock renewal.
type LockRenewer interface {
	RenewMessageLock(ctx context.Context, message *azservicebus.ReceivedMessage, options *azservicebus.RenewMessageLockOptions) error
}

// LockRenewalOptions configures the lock renewal.
type LockRenewalOptions struct {
	// Interval defines the frequency at which we renew the lock on the message. Defaults to 10 seconds.
	Interval *time.Duration
	// LockRenewalTimeout is the timeout value used on the context when sending a RenewMessageLock() request.
	// Defaults to 5 seconds if not set or set to 0. Defaults to the lock expiry time if set to a negative value.
	LockRenewalTimeout *time.Duration
	// CancelMessageContextOnStop will cancel the downstream message context when the renewal handler is stopped.
	// Defaults to true.
	CancelMessageContextOnStop *bool
	// MetricRecorder allows passing a custom metric recorder for the LockRenewer.
	// Defaults to the processor.Metric instance.
	MetricRecorder processor.Recorder
}

// NewRenewLockHandler returns a middleware handler that will renew the lock on the message at the specified interval.
func NewRenewLockHandler(options *LockRenewalOptions, handler Handler) HandlerFunc {
	interval := 10 * time.Second
	lockRenewalTimeout := 5 * time.Second
	cancelMessageContextOnStop := true
	metricRecorder := processor.Metric
	if options != nil {
		if options.Interval != nil {
			interval = *options.Interval
		}
		if options.LockRenewalTimeout != nil && *options.LockRenewalTimeout != 0 {
			lockRenewalTimeout = *options.LockRenewalTimeout
		}
		if options.CancelMessageContextOnStop != nil {
			cancelMessageContextOnStop = *options.CancelMessageContextOnStop
		}
		if options.MetricRecorder != nil {
			metricRecorder = options.MetricRecorder
		}
	}
	return func(ctx context.Context, settler MessageSettler, message *azservicebus.ReceivedMessage) {
		plr := &peekLockRenewer{
			next:                   handler,
			lockRenewer:            settler,
			renewalInterval:        &interval,
			renewalTimeout:         &lockRenewalTimeout,
			metrics:                metricRecorder,
			cancelMessageCtxOnStop: cancelMessageContextOnStop,
			stopped:                make(chan struct{}, 1), // buffered channel to ensure we are not blocking
		}
		renewalCtx, cancel := context.WithCancel(ctx)
		plr.cancelMessageCtx = cancel
		go plr.startPeriodicRenewal(renewalCtx, message)
		handler.Handle(renewalCtx, settler, message)
		plr.stop(renewalCtx)
	}
}

// NewLockRenewalHandler returns a middleware handler that will renew the lock on the message at the specified interval.
//
// Deprecated: use NewRenewLockHandler instead.
func NewLockRenewalHandler(lockRenewer LockRenewer, options *LockRenewalOptions, handler Handler) HandlerFunc {
	return NewRenewLockHandler(options, handler)
}
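// Illustrative usage sketch, not part of the original file: wiring the renewal
// middleware around a message handler. The 30-second interval, 5-second timeout,
// and the settlement call are assumptions for the example; CompleteMessage is
// assumed to follow the MessageSettler signature mirroring the azservicebus receiver.
func exampleRenewLockHandlerUsage() Handler {
	interval := 30 * time.Second
	renewTimeout := 5 * time.Second
	return NewRenewLockHandler(&LockRenewalOptions{
		Interval:           &interval,
		LockRenewalTimeout: &renewTimeout,
	}, HandlerFunc(func(ctx context.Context, settler MessageSettler, message *azservicebus.ReceivedMessage) {
		// the lock is renewed in the background every 30 seconds until this handler returns
		if err := settler.CompleteMessage(ctx, message, nil); err != nil {
			getLogger(ctx).Error(fmt.Sprintf("failed to complete message: %s", err))
		}
	}))
}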
// peekLockRenewer starts a background goroutine that renews the message lock at the given interval until stop() is called
// or until the passed-in context is canceled.
// It is a pass-through handler if the renewalInterval is nil.
type peekLockRenewer struct {
	next                   Handler
	lockRenewer            LockRenewer
	renewalInterval        *time.Duration
	renewalTimeout         *time.Duration
	metrics                processor.Recorder
	alive                  atomic.Bool
	cancelMessageCtxOnStop bool
	cancelMessageCtx       func()
	// the stopped channel allows short-circuiting the renewal loop
	// when we are already waiting on the select.
	// the channel is needed in addition to the boolean
	// to cover the case where we might have finished handling the message and called stop on the renewer
	// before the renewal goroutine had a chance to start.
	stopped chan struct{}
}

// stop stops the renewal loop. If LockRenewalOptions.CancelMessageContextOnStop is set to true, it also cancels the message context.
func (plr *peekLockRenewer) stop(ctx context.Context) {
	logger := getLogger(ctx)
	plr.alive.Store(false)
	// don't send the stop signal to the loop if there is already one in the channel
	if len(plr.stopped) == 0 {
		plr.stopped <- struct{}{}
	}
	if plr.cancelMessageCtxOnStop {
		logger.Info("canceling message context")
		plr.cancelMessageCtx()
	}
	logger.Info("stopped periodic renewal")
}

func (plr *peekLockRenewer) isPermanent(err error) bool {
	var sbErr *azservicebus.Error
	if errors.As(err, &sbErr) {
		// once the lock is lost, the renewal cannot succeed.
		return sbErr.Code == azservicebus.CodeLockLost ||
			sbErr.Code == azservicebus.CodeUnauthorizedAccess
	}
	return false
}

func (plr *peekLockRenewer) startPeriodicRenewal(ctx context.Context, message *azservicebus.ReceivedMessage) {
	logger := getLogger(ctx)
	count := 0
	span := trace.SpanFromContext(ctx)
	for plr.alive.Store(true); plr.alive.Load(); {
		select {
		case <-time.After(*plr.renewalInterval):
			if !plr.alive.Load() {
				return
			}
			count++
			err := plr.renewMessageLock(ctx, message, nil)
			if err != nil {
				// The context is canceled when the message handler returns from the processor.
				// This can happen if we already entered the interval case when the message processing completes.
				// The best we can do is log and retry on the next tick. The SDK already retries operations on recoverable network errors.
				if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
					// if the error is a context error,
					// we stop and let the next loop iteration handle the exit.
					plr.stop(ctx)
					continue
				}
				plr.metrics.IncMessageLockRenewedFailure(message)
				span.RecordError(fmt.Errorf("failed to renew lock: %w", err))
				// on error, we continue to the next loop iteration.
				// if the context is Done, we will enter the ctx.Done() case and exit the renewal.
				// if the error is identified as permanent, we stop the renewal.
				// if the error is anything else, we keep trying the renewal.
				if plr.isPermanent(err) {
					logger.Error(fmt.Sprintf("stopping periodic renewal for message: %s", message.MessageID))
					plr.stop(ctx)
				}
				continue
			}
			span.AddEvent("message lock renewed", trace.WithAttributes(attribute.Int("count", count)))
			plr.metrics.IncMessageLockRenewedSuccess(message)
		case <-ctx.Done():
			logger.Info("context done: stopping periodic renewal")
			span.AddEvent("context done: stopping message lock renewal")
			err := ctx.Err()
			if errors.Is(err, context.DeadlineExceeded) {
				span.RecordError(err)
				plr.metrics.IncMessageDeadlineReachedCount(message)
			}
			plr.stop(ctx)
		case <-plr.stopped:
			if plr.alive.Load() {
				logger.Info("stop signal received: exiting periodic renewal")
				plr.alive.Store(false)
			}
		}
	}
}
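// Illustrative sketch, not part of the original file: how isPermanent drives the
// loop above. A lost lock or an unauthorized-access error stops the renewal for
// good; any other error is retried on the next interval tick.
func exampleErrorClassification(plr *peekLockRenewer) {
	lockLost := &azservicebus.Error{Code: azservicebus.CodeLockLost}
	fmt.Println(plr.isPermanent(lockLost))                                // true: renewal stops
	fmt.Println(plr.isPermanent(errors.New("transient network failure"))) // false: renewal keeps retrying
}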
func (plr *peekLockRenewer) renewMessageLock(ctx context.Context, message *azservicebus.ReceivedMessage, options *azservicebus.RenewMessageLockOptions) error {
	span := trace.SpanFromContext(ctx)
	lockLostErr := &azservicebus.Error{Code: azservicebus.CodeLockLost}
	if message.LockedUntil == nil || time.Until(*message.LockedUntil) < 0 {
		// if the lock doesn't exist or is already expired, we should not attempt to renew it.
		return lockLostErr
	}
	renewalTimeout := time.Until(*message.LockedUntil)
	if *plr.renewalTimeout > 0 {
		renewalTimeout = *plr.renewalTimeout
	}
	// we should keep retrying until lock expiry or until message context is done
	for time.Until(*message.LockedUntil) > 0 && ctx.Err() == nil {
		var renewErr error
		func() {
			getLogger(ctx).Info(fmt.Sprintf("renewing lock with timeout: %s", renewalTimeout))
			renewalCtx, cancel := context.WithTimeout(ctx, renewalTimeout)
			defer cancel()
			renewErr = plr.lockRenewer.RenewMessageLock(renewalCtx, message, options)
		}()
		if renewErr != nil {
			getLogger(ctx).Error(fmt.Sprintf("failed to renew lock: %s", renewErr))
			span.AddEvent("failed to renew message lock", trace.WithAttributes(attribute.String("errorDetails", renewErr.Error())))
		}
		// exit the renewal if no error or if we get any error other than context deadline exceeded
		if !errors.Is(renewErr, context.DeadlineExceeded) {
			return renewErr
		}
		// renewErr is context.DeadlineExceeded, increment metric and retry
		plr.metrics.IncMessageLockRenewalTimeoutCount(message)
	}
	// lock is expired or message context is done
	if ctx.Err() != nil {
		return ctx.Err()
	}
	return lockLostErr
}
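// Illustrative sketch, not part of the original file: a minimal fake satisfying
// the LockRenewer interface, e.g. for unit-testing the renewal loop without a
// live Service Bus receiver. The type name and fields are assumptions.
type fakeLockRenewer struct {
	renewCount atomic.Int32
	err        error // error to return from RenewMessageLock; nil simulates success
}

func (f *fakeLockRenewer) RenewMessageLock(_ context.Context, _ *azservicebus.ReceivedMessage, _ *azservicebus.RenewMessageLockOptions) error {
	f.renewCount.Add(1)
	return f.err
}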