common/asyncworkflow/queue/consumer/default_consumer.go (232 lines of code) (raw):

// The MIT License (MIT) // Copyright (c) 2017-2020 Uber Technologies Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in all // copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. package consumer import ( "context" "encoding/json" "fmt" "sort" "sync" "time" "go.uber.org/yarpc" "github.com/uber/cadence/.gen/go/shared" "github.com/uber/cadence/.gen/go/sqlblobs" "github.com/uber/cadence/client/frontend" "github.com/uber/cadence/common" "github.com/uber/cadence/common/backoff" "github.com/uber/cadence/common/codec" "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/log/tag" "github.com/uber/cadence/common/messaging" "github.com/uber/cadence/common/metrics" "github.com/uber/cadence/common/types" ) const ( defaultShutdownTimeout = 5 * time.Second defaultStartWFTimeout = 3 * time.Second defaultConcurrency = 100 ) type DefaultConsumer struct { queueID string innerConsumer messaging.Consumer logger log.Logger scope metrics.Scope frontendClient frontend.Client ctx context.Context cancelFn context.CancelFunc wg sync.WaitGroup shutdownTimeout time.Duration startWFTimeout time.Duration msgDecoder codec.BinaryEncoder concurrency int } type Option func(*DefaultConsumer) func WithConcurrency(concurrency int) Option { return func(c *DefaultConsumer) { c.concurrency = concurrency } } func New( queueID string, innerConsumer messaging.Consumer, logger log.Logger, metricsClient metrics.Client, frontendClient frontend.Client, options ...Option, ) *DefaultConsumer { ctx, cancelFn := context.WithCancel(context.Background()) c := &DefaultConsumer{ queueID: queueID, innerConsumer: innerConsumer, logger: logger.WithTags(tag.AsyncWFQueueID(queueID)), scope: metricsClient.Scope(metrics.AsyncWorkflowConsumerScope), frontendClient: frontendClient, ctx: ctx, cancelFn: cancelFn, shutdownTimeout: defaultShutdownTimeout, startWFTimeout: defaultStartWFTimeout, msgDecoder: codec.NewThriftRWEncoder(), concurrency: defaultConcurrency, } for _, opt := range options { opt(c) } return c } func (c *DefaultConsumer) Start() error { if err := c.innerConsumer.Start(); err != nil { return err } for i := 0; i < c.concurrency; i++ { c.wg.Add(1) go c.runProcessLoop() c.logger.Info("Started process loop", tag.Counter(i)) } c.logger.Info("Started consumer", tag.Dynamic("concurrency", c.concurrency)) return nil } func (c *DefaultConsumer) Stop() { c.logger.Info("Stopping consumer") c.cancelFn() c.wg.Wait() if !common.AwaitWaitGroup(&c.wg, c.shutdownTimeout) { c.logger.Warn("Consumer timed out on shutdown", tag.Dynamic("timeout", c.shutdownTimeout)) return } c.innerConsumer.Stop() c.logger.Info("Stopped consumer") } func (c *DefaultConsumer) runProcessLoop() { defer c.wg.Done() for { select { case msg, ok := <-c.innerConsumer.Messages(): if !ok { c.logger.Info("Consumer channel closed") return } c.processMessage(msg) case <-c.ctx.Done(): c.logger.Info("Consumer context done so terminating loop") return } } } func (c *DefaultConsumer) processMessage(msg messaging.Message) { logger := c.logger.WithTags(tag.Dynamic("partition", msg.Partition()), tag.Dynamic("offset", msg.Offset())) logger.Debug("Received message") sw := c.scope.StartTimer(metrics.AsyncWorkflowProcessMsgLatency) defer sw.Stop() var request sqlblobs.AsyncRequestMessage if err := c.msgDecoder.Decode(msg.Value(), &request); err != nil { logger.Error("Failed to decode message", tag.Error(err)) c.scope.IncCounter(metrics.AsyncWorkflowFailureCorruptMsgCount) if err := msg.Nack(); err != nil { logger.Error("Failed to nack message", tag.Error(err)) } return } if err := c.processRequest(logger, &request); err != nil { logger.Error("Failed to process message", tag.Error(err)) if err := msg.Nack(); err != nil { logger.Error("Failed to nack message", tag.Error(err)) } return } if err := msg.Ack(); err != nil { logger.Error("Failed to ack message", tag.Error(err)) } logger.Debug("Processed message successfully") } func (c *DefaultConsumer) processRequest(logger log.Logger, request *sqlblobs.AsyncRequestMessage) error { scope := c.scope.Tagged(metrics.AsyncWFRequestTypeTag(request.GetType().String())) switch request.GetType() { case sqlblobs.AsyncRequestTypeStartWorkflowExecutionAsyncRequest: startWFReq, err := decodeStartWorkflowRequest(request.GetPayload(), request.GetEncoding()) if err != nil { scope.IncCounter(metrics.AsyncWorkflowFailureCorruptMsgCount) return err } yarpcCallOpts := getYARPCOptions(request.GetHeader()) scope := scope.Tagged(metrics.DomainTag(startWFReq.GetDomain())) var resp *types.StartWorkflowExecutionResponse op := func() error { ctx, cancel := context.WithTimeout(c.ctx, c.startWFTimeout) defer cancel() resp, err = c.frontendClient.StartWorkflowExecution(ctx, startWFReq, yarpcCallOpts...) return err } if err := callFrontendWithRetries(c.ctx, op); err != nil { scope.IncCounter(metrics.AsyncWorkflowFailureByFrontendCount) return fmt.Errorf("start workflow execution failed after all attempts: %w", err) } scope.IncCounter(metrics.AsyncWorkflowSuccessCount) logger.Info("StartWorkflowExecution succeeded", tag.WorkflowID(startWFReq.GetWorkflowID()), tag.WorkflowRunID(resp.GetRunID())) case sqlblobs.AsyncRequestTypeSignalWithStartWorkflowExecutionAsyncRequest: startWFReq, err := decodeSignalWithStartWorkflowRequest(request.GetPayload(), request.GetEncoding()) if err != nil { c.scope.IncCounter(metrics.AsyncWorkflowFailureCorruptMsgCount) return err } yarpcCallOpts := getYARPCOptions(request.GetHeader()) scope := c.scope.Tagged(metrics.DomainTag(startWFReq.GetDomain())) var resp *types.StartWorkflowExecutionResponse op := func() error { ctx, cancel := context.WithTimeout(c.ctx, c.startWFTimeout) defer cancel() resp, err = c.frontendClient.SignalWithStartWorkflowExecution(ctx, startWFReq, yarpcCallOpts...) return err } if err := callFrontendWithRetries(c.ctx, op); err != nil { scope.IncCounter(metrics.AsyncWorkflowFailureByFrontendCount) return fmt.Errorf("signal with start workflow execution failed after all attempts: %w", err) } scope.IncCounter(metrics.AsyncWorkflowSuccessCount) logger.Info("SignalWithStartWorkflowExecution succeeded", tag.WorkflowID(startWFReq.GetWorkflowID()), tag.WorkflowRunID(resp.GetRunID())) default: c.scope.IncCounter(metrics.AsyncWorkflowFailureCorruptMsgCount) return &UnsupportedRequestType{Type: request.GetType()} } return nil } func callFrontendWithRetries(ctx context.Context, op func() error) error { throttleRetry := backoff.NewThrottleRetry( backoff.WithRetryPolicy(common.CreateFrontendServiceRetryPolicy()), backoff.WithRetryableError(common.IsServiceTransientError), ) return throttleRetry.Do(ctx, op) } func getYARPCOptions(header *shared.Header) []yarpc.CallOption { if header == nil || header.GetFields() == nil { return nil } // sort the header fields to make the tests deterministic fields := header.GetFields() sortedKeys := make([]string, 0, len(fields)) for k := range fields { sortedKeys = append(sortedKeys, k) } sort.Strings(sortedKeys) var opts []yarpc.CallOption for _, k := range sortedKeys { opts = append(opts, yarpc.WithHeader(k, string(fields[k]))) } return opts } func decodeStartWorkflowRequest(payload []byte, encoding string) (*types.StartWorkflowExecutionRequest, error) { if encoding != string(common.EncodingTypeJSON) { return nil, &UnsupportedEncoding{EncodingType: encoding} } var startRequest types.StartWorkflowExecutionAsyncRequest if err := json.Unmarshal(payload, &startRequest); err != nil { return nil, err } return startRequest.StartWorkflowExecutionRequest, nil } func decodeSignalWithStartWorkflowRequest(payload []byte, encoding string) (*types.SignalWithStartWorkflowExecutionRequest, error) { if encoding != string(common.EncodingTypeJSON) { return nil, &UnsupportedEncoding{EncodingType: encoding} } var startRequest types.SignalWithStartWorkflowExecutionAsyncRequest if err := json.Unmarshal(payload, &startRequest); err != nil { return nil, err } return startRequest.SignalWithStartWorkflowExecutionRequest, nil }