// service/matching/handler.go

// Copyright (c) 2017 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. //go:generate mockgen -package $GOPACKAGE -source $GOFILE -destination handler_mock.go -package matching github.com/uber/cadence/service/matching Handler //go:generate gowrap gen -g -p . 
-i Handler -t ../templates/grpc.tmpl -o ./grpc_handler_generated.go -v handler=GRPC -v package=matchingv1 -v path=github.com/uber/cadence/.gen/proto/matching/v1 -v prefix=Matching //go:generate gowrap gen -g -p ../../.gen/go/matching/matchingserviceserver -i Interface -t ../templates/thrift.tmpl -o ./thrift_handler_generated.go -v handler=Thrift -v prefix=Matching package matching import ( "context" "sync" "time" "github.com/uber/cadence/common" "github.com/uber/cadence/common/cache" "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/metrics" "github.com/uber/cadence/common/quotas" "github.com/uber/cadence/common/types" ) var _ Handler = (*handlerImpl)(nil) type ( // Handler interface for matching service Handler interface { common.Daemon Health(context.Context) (*types.HealthStatus, error) AddActivityTask(context.Context, *types.AddActivityTaskRequest) error AddDecisionTask(context.Context, *types.AddDecisionTaskRequest) error CancelOutstandingPoll(context.Context, *types.CancelOutstandingPollRequest) error DescribeTaskList(context.Context, *types.MatchingDescribeTaskListRequest) (*types.DescribeTaskListResponse, error) ListTaskListPartitions(context.Context, *types.MatchingListTaskListPartitionsRequest) (*types.ListTaskListPartitionsResponse, error) GetTaskListsByDomain(context.Context, *types.GetTaskListsByDomainRequest) (*types.GetTaskListsByDomainResponse, error) PollForActivityTask(context.Context, *types.MatchingPollForActivityTaskRequest) (*types.PollForActivityTaskResponse, error) PollForDecisionTask(context.Context, *types.MatchingPollForDecisionTaskRequest) (*types.MatchingPollForDecisionTaskResponse, error) QueryWorkflow(context.Context, *types.MatchingQueryWorkflowRequest) (*types.QueryWorkflowResponse, error) RespondQueryTaskCompleted(context.Context, *types.MatchingRespondQueryTaskCompletedRequest) error } // handlerImpl is an implementation for matching service independent of wire protocol handlerImpl struct { engine Engine 
metricsClient metrics.Client startWG sync.WaitGroup userRateLimiter quotas.Policy workerRateLimiter quotas.Policy logger log.Logger throttledLogger log.Logger domainCache cache.DomainCache } ) var ( errMatchingHostThrottle = &types.ServiceBusyError{Message: "Matching host rps exceeded"} ) // NewHandler creates a thrift handler for the matching service func NewHandler( engine Engine, config *Config, domainCache cache.DomainCache, metricsClient metrics.Client, logger log.Logger, throttledLogger log.Logger, ) Handler { handler := &handlerImpl{ metricsClient: metricsClient, userRateLimiter: quotas.NewMultiStageRateLimiter( quotas.NewDynamicRateLimiter(config.UserRPS.AsFloat64()), quotas.NewCollection(quotas.NewFallbackDynamicRateLimiterFactory( config.DomainUserRPS, config.UserRPS, )), ), workerRateLimiter: quotas.NewMultiStageRateLimiter( quotas.NewDynamicRateLimiter(config.WorkerRPS.AsFloat64()), quotas.NewCollection(quotas.NewFallbackDynamicRateLimiterFactory( config.DomainWorkerRPS, config.WorkerRPS, )), ), engine: engine, logger: logger, throttledLogger: throttledLogger, domainCache: domainCache, } // prevent us from trying to serve requests before matching engine is started and ready handler.startWG.Add(1) return handler } // Start starts the handler func (h *handlerImpl) Start() { h.startWG.Done() } // Stop stops the handler func (h *handlerImpl) Stop() { h.engine.Stop() } // Health is for health check func (h *handlerImpl) Health(ctx context.Context) (*types.HealthStatus, error) { h.startWG.Wait() h.logger.Debug("Matching service health check endpoint reached.") hs := &types.HealthStatus{Ok: true, Msg: "matching good"} return hs, nil } func (h *handlerImpl) newHandlerContext( ctx context.Context, domainName string, taskList *types.TaskList, scope int, ) *handlerContext { return newHandlerContext( ctx, domainName, taskList, h.metricsClient, scope, h.logger, ) } // AddActivityTask - adds an activity task. 
func (h *handlerImpl) AddActivityTask( ctx context.Context, request *types.AddActivityTaskRequest, ) (retError error) { defer func() { log.CapturePanic(recover(), h.logger, &retError) }() startT := time.Now() domainName := h.domainName(request.GetDomainUUID()) hCtx := h.newHandlerContext( ctx, domainName, request.GetTaskList(), metrics.MatchingAddActivityTaskScope, ) sw := hCtx.startProfiling(&h.startWG) defer sw.Stop() if request.GetForwardedFrom() != "" { hCtx.scope.IncCounter(metrics.ForwardedPerTaskListCounter) } if ok := h.workerRateLimiter.Allow(quotas.Info{Domain: domainName}); !ok { return hCtx.handleErr(errMatchingHostThrottle) } syncMatch, err := h.engine.AddActivityTask(hCtx, request) if syncMatch { hCtx.scope.RecordTimer(metrics.SyncMatchLatencyPerTaskList, time.Since(startT)) } return hCtx.handleErr(err) } // AddDecisionTask - adds a decision task. func (h *handlerImpl) AddDecisionTask( ctx context.Context, request *types.AddDecisionTaskRequest, ) (retError error) { defer func() { log.CapturePanic(recover(), h.logger, &retError) }() startT := time.Now() domainName := h.domainName(request.GetDomainUUID()) hCtx := h.newHandlerContext( ctx, domainName, request.GetTaskList(), metrics.MatchingAddDecisionTaskScope, ) sw := hCtx.startProfiling(&h.startWG) defer sw.Stop() if request.GetForwardedFrom() != "" { hCtx.scope.IncCounter(metrics.ForwardedPerTaskListCounter) } if ok := h.workerRateLimiter.Allow(quotas.Info{Domain: domainName}); !ok { return hCtx.handleErr(errMatchingHostThrottle) } syncMatch, err := h.engine.AddDecisionTask(hCtx, request) if syncMatch { hCtx.scope.RecordTimer(metrics.SyncMatchLatencyPerTaskList, time.Since(startT)) } return hCtx.handleErr(err) } // PollForActivityTask - long poll for an activity task. 
func (h *handlerImpl) PollForActivityTask( ctx context.Context, request *types.MatchingPollForActivityTaskRequest, ) (resp *types.PollForActivityTaskResponse, retError error) { defer func() { log.CapturePanic(recover(), h.logger, &retError) }() domainName := h.domainName(request.GetDomainUUID()) hCtx := h.newHandlerContext( ctx, domainName, request.GetPollRequest().GetTaskList(), metrics.MatchingPollForActivityTaskScope, ) sw := hCtx.startProfiling(&h.startWG) defer sw.Stop() if request.GetForwardedFrom() != "" { hCtx.scope.IncCounter(metrics.ForwardedPerTaskListCounter) } if ok := h.workerRateLimiter.Allow(quotas.Info{Domain: domainName}); !ok { return nil, hCtx.handleErr(errMatchingHostThrottle) } if _, err := common.ValidateLongPollContextTimeoutIsSet(ctx, "PollForActivityTask", h.throttledLogger, ); err != nil { return nil, hCtx.handleErr(err) } response, err := h.engine.PollForActivityTask(hCtx, request) return response, hCtx.handleErr(err) } // PollForDecisionTask - long poll for a decision task. 
func (h *handlerImpl) PollForDecisionTask( ctx context.Context, request *types.MatchingPollForDecisionTaskRequest, ) (resp *types.MatchingPollForDecisionTaskResponse, retError error) { defer func() { log.CapturePanic(recover(), h.logger, &retError) }() domainName := h.domainName(request.GetDomainUUID()) hCtx := h.newHandlerContext( ctx, domainName, request.GetPollRequest().GetTaskList(), metrics.MatchingPollForDecisionTaskScope, ) sw := hCtx.startProfiling(&h.startWG) defer sw.Stop() if request.GetForwardedFrom() != "" { hCtx.scope.IncCounter(metrics.ForwardedPerTaskListCounter) } if ok := h.workerRateLimiter.Allow(quotas.Info{Domain: domainName}); !ok { return nil, hCtx.handleErr(errMatchingHostThrottle) } if _, err := common.ValidateLongPollContextTimeoutIsSet( ctx, "PollForDecisionTask", h.throttledLogger, ); err != nil { return nil, hCtx.handleErr(err) } response, err := h.engine.PollForDecisionTask(hCtx, request) return response, hCtx.handleErr(err) } // QueryWorkflow queries a given workflow synchronously and return the query result. 
func (h *handlerImpl) QueryWorkflow( ctx context.Context, request *types.MatchingQueryWorkflowRequest, ) (resp *types.QueryWorkflowResponse, retError error) { defer func() { log.CapturePanic(recover(), h.logger, &retError) }() domainName := h.domainName(request.GetDomainUUID()) hCtx := h.newHandlerContext( ctx, domainName, request.GetTaskList(), metrics.MatchingQueryWorkflowScope, ) sw := hCtx.startProfiling(&h.startWG) defer sw.Stop() if request.GetForwardedFrom() != "" { hCtx.scope.IncCounter(metrics.ForwardedPerTaskListCounter) } if ok := h.userRateLimiter.Allow(quotas.Info{Domain: domainName}); !ok { return nil, hCtx.handleErr(errMatchingHostThrottle) } response, err := h.engine.QueryWorkflow(hCtx, request) return response, hCtx.handleErr(err) } // RespondQueryTaskCompleted responds a query task completed func (h *handlerImpl) RespondQueryTaskCompleted( ctx context.Context, request *types.MatchingRespondQueryTaskCompletedRequest, ) (retError error) { defer func() { log.CapturePanic(recover(), h.logger, &retError) }() domainName := h.domainName(request.GetDomainUUID()) hCtx := h.newHandlerContext( ctx, domainName, request.GetTaskList(), metrics.MatchingRespondQueryTaskCompletedScope, ) sw := hCtx.startProfiling(&h.startWG) defer sw.Stop() // Count the request in the RPS, but we still accept it even if RPS is exceeded h.workerRateLimiter.Allow(quotas.Info{Domain: domainName}) err := h.engine.RespondQueryTaskCompleted(hCtx, request) return hCtx.handleErr(err) } // CancelOutstandingPoll is used to cancel outstanding pollers func (h *handlerImpl) CancelOutstandingPoll(ctx context.Context, request *types.CancelOutstandingPollRequest) (retError error) { defer func() { log.CapturePanic(recover(), h.logger, &retError) }() domainName := h.domainName(request.GetDomainUUID()) hCtx := h.newHandlerContext( ctx, domainName, request.GetTaskList(), metrics.MatchingCancelOutstandingPollScope, ) sw := hCtx.startProfiling(&h.startWG) defer sw.Stop() // Count the request in the 
RPS, but we still accept it even if RPS is exceeded h.workerRateLimiter.Allow(quotas.Info{Domain: domainName}) err := h.engine.CancelOutstandingPoll(hCtx, request) return hCtx.handleErr(err) } // DescribeTaskList returns information about the target tasklist, right now this API returns the // pollers which polled this tasklist in last few minutes. If includeTaskListStatus field is true, // it will also return status of tasklist's ackManager (readLevel, ackLevel, backlogCountHint and taskIDBlock). func (h *handlerImpl) DescribeTaskList( ctx context.Context, request *types.MatchingDescribeTaskListRequest, ) (resp *types.DescribeTaskListResponse, retError error) { defer func() { log.CapturePanic(recover(), h.logger, &retError) }() domainName := h.domainName(request.GetDomainUUID()) hCtx := h.newHandlerContext( ctx, domainName, request.GetDescRequest().GetTaskList(), metrics.MatchingDescribeTaskListScope, ) sw := hCtx.startProfiling(&h.startWG) defer sw.Stop() if ok := h.userRateLimiter.Allow(quotas.Info{Domain: domainName}); !ok { return nil, hCtx.handleErr(errMatchingHostThrottle) } response, err := h.engine.DescribeTaskList(hCtx, request) return response, hCtx.handleErr(err) } // ListTaskListPartitions returns information about partitions for a taskList func (h *handlerImpl) ListTaskListPartitions( ctx context.Context, request *types.MatchingListTaskListPartitionsRequest, ) (resp *types.ListTaskListPartitionsResponse, retError error) { defer func() { log.CapturePanic(recover(), h.logger, &retError) }() hCtx := newHandlerContext( ctx, request.GetDomain(), request.GetTaskList(), h.metricsClient, metrics.MatchingListTaskListPartitionsScope, h.logger, ) sw := hCtx.startProfiling(&h.startWG) defer sw.Stop() if ok := h.userRateLimiter.Allow(quotas.Info{Domain: request.GetDomain()}); !ok { return nil, hCtx.handleErr(errMatchingHostThrottle) } response, err := h.engine.ListTaskListPartitions(hCtx, request) return response, hCtx.handleErr(err) } // GetTaskListsByDomain 
returns information about partitions for a taskList func (h *handlerImpl) GetTaskListsByDomain( ctx context.Context, request *types.GetTaskListsByDomainRequest, ) (resp *types.GetTaskListsByDomainResponse, retError error) { defer func() { log.CapturePanic(recover(), h.logger, &retError) }() hCtx := newHandlerContext( ctx, request.GetDomain(), nil, h.metricsClient, metrics.MatchingGetTaskListsByDomainScope, h.logger, ) sw := hCtx.startProfiling(&h.startWG) defer sw.Stop() if ok := h.userRateLimiter.Allow(quotas.Info{Domain: request.GetDomain()}); !ok { return nil, hCtx.handleErr(errMatchingHostThrottle) } response, err := h.engine.GetTaskListsByDomain(hCtx, request) return response, hCtx.handleErr(err) } func (h *handlerImpl) domainName(id string) string { domainName, err := h.domainCache.GetDomainName(id) if err != nil { return "" } return domainName }