internal/grpc/middleware/limithandler/middleware.go
package limithandler

import (
	"context"
	"fmt"

	"github.com/prometheus/client_golang/prometheus"
	"gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/config"
	"gitlab.com/gitlab-org/gitaly/v16/internal/grpc/middleware/requestinfohandler"
	"gitlab.com/gitlab-org/gitaly/v16/internal/limiter"
	"google.golang.org/grpc"
)

// GetLockKey is a function that derives the lock key of an RPC invocation from its context.
type GetLockKey func(context.Context) string

// LimitConcurrencyByRepo implements GetLockKey by using the repository's relative path as the lock key.
func LimitConcurrencyByRepo(ctx context.Context) string {
	if info := requestinfohandler.Extract(ctx); info != nil {
		return info.Repository.GetRelativePath()
	}
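	// An empty lock key tells the interceptors to skip limiting for this call.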
return ""
}
// LimiterMiddleware contains rate limiter state
type LimiterMiddleware struct {
methodLimiters map[string]limiter.Limiter
getLockKey GetLockKey
requestsDroppedMetric *prometheus.CounterVec
collect func(metrics chan<- prometheus.Metric)
}
// New creates a new middleware that limits requests. SetupFunc sets up the
// middleware with a specific kind of limiter.
func New(cfg config.Cfg, getLockKey GetLockKey, setupMiddleware SetupFunc) *LimiterMiddleware {
middleware := &LimiterMiddleware{
getLockKey: getLockKey,
requestsDroppedMetric: prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "gitaly_requests_dropped_total",
Help: "Number of requests dropped from the queue",
},
[]string{
"system",
"grpc_service",
"grpc_method",
"reason",
},
),
}
setupMiddleware(cfg, middleware)
return middleware
}
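
// Illustrative wiring sketch (not part of this file's API): one way the
// constructor and the interceptors below could be attached to a gRPC server.
// The cfg value and the use of chained interceptors are assumptions of the
// example.
//
//	limits, setup := WithConcurrencyLimiters(cfg)
//	lh := New(cfg, LimitConcurrencyByRepo, setup)
//	srv := grpc.NewServer(
//		grpc.ChainUnaryInterceptor(lh.UnaryInterceptor()),
//		grpc.ChainStreamInterceptor(lh.StreamInterceptor()),
//	)
//	_ = limits // per-RPC adaptive limits, e.g. to feed an adaptive calculator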

// Describe is used to describe Prometheus metrics.
func (c *LimiterMiddleware) Describe(descs chan<- *prometheus.Desc) {
	prometheus.DescribeByCollect(c, descs)
}

// Collect is used to collect Prometheus metrics.
func (c *LimiterMiddleware) Collect(metrics chan<- prometheus.Metric) {
	c.requestsDroppedMetric.Collect(metrics)
	if c.collect != nil {
		c.collect(metrics)
	}
}

// UnaryInterceptor returns a unary server interceptor that enforces the configured limits.
func (c *LimiterMiddleware) UnaryInterceptor() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		lockKey := c.getLockKey(ctx)
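		// Calls without a lock key (for example RPCs that do not target a repository) bypass limiting.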
if lockKey == "" {
return handler(ctx, req)
}
limiter := c.methodLimiters[info.FullMethod]
if limiter == nil {
// No concurrency limiting
return handler(ctx, req)
}
return limiter.Limit(ctx, lockKey, func() (interface{}, error) {
return handler(ctx, req)
})
}
}

// StreamInterceptor returns a stream server interceptor that enforces the configured limits.
func (c *LimiterMiddleware) StreamInterceptor() grpc.StreamServerInterceptor {
	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
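		// Wrap the stream so that limiting is applied in RecvMsg: for streams the
		// lock key (e.g. the target repository) is only known once the first
		// request message has been received.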
		wrapper := &wrappedStream{stream, info, c, true}
		return handler(srv, wrapper)
	}
}

// SetupFunc sets up a middleware for limiting requests.
type SetupFunc func(cfg config.Cfg, middleware *LimiterMiddleware)

type wrappedStream struct {
	grpc.ServerStream
	info              *grpc.StreamServerInfo
	limiterMiddleware *LimiterMiddleware
	initial           bool
}

func (w *wrappedStream) RecvMsg(m interface{}) error {
	if err := w.ServerStream.RecvMsg(m); err != nil {
		return err
	}
	// Only perform limiting on the first request of a stream
	if !w.initial {
		return nil
	}
	w.initial = false
	ctx := w.Context()
	lockKey := w.limiterMiddleware.getLockKey(ctx)
	if lockKey == "" {
		return nil
	}
	limiter := w.limiterMiddleware.methodLimiters[w.info.FullMethod]
	if limiter == nil {
		// No concurrency limiting
		return nil
	}
	ready := make(chan struct{})
	errs := make(chan error, 1)
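	// Acquire a slot in the limiter asynchronously: the callback signals readiness
	// once the slot is granted and then holds the slot until the stream's context
	// is done, so the whole stream counts against the concurrency limit. Limit
	// errors (for example when the request is dropped from the queue) are
	// reported via errs.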
	go func() {
		if _, err := limiter.Limit(ctx, lockKey, func() (interface{}, error) {
			close(ready)
			<-ctx.Done()
			return nil, nil
		}); err != nil {
			errs <- err
		}
	}()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-ready:
		// It's our turn!
		return nil
	case err := <-errs:
		return err
	}
}

// WithConcurrencyLimiters sets up middleware to limit the concurrency of
// requests based on RPC and repository.
func WithConcurrencyLimiters(cfg config.Cfg) (map[string]*limiter.AdaptiveLimit, SetupFunc) {
	perRPCLimits := map[string]*limiter.AdaptiveLimit{}
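	// Each configured RPC gets an adaptive limit object: truly adaptive when
	// configured as such, otherwise pinned to the static per-repository limit.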
	for _, concurrency := range cfg.Concurrency {
		limitName := fmt.Sprintf("perRPC%s", concurrency.RPC)
		if concurrency.Adaptive {
			perRPCLimits[concurrency.RPC] = limiter.NewAdaptiveLimit(limitName, limiter.AdaptiveSetting{
				Initial:       concurrency.InitialLimit,
				Max:           concurrency.MaxLimit,
				Min:           concurrency.MinLimit,
				BackoffFactor: limiter.DefaultBackoffFactor,
			})
		} else {
			perRPCLimits[concurrency.RPC] = limiter.NewAdaptiveLimit(limitName, limiter.AdaptiveSetting{
				Initial: concurrency.MaxPerRepo,
			})
		}
	}
	return perRPCLimits, func(cfg config.Cfg, middleware *LimiterMiddleware) {
		acquiringSecondsMetric := prometheus.NewHistogramVec(
			prometheus.HistogramOpts{
				Namespace: "gitaly",
				Subsystem: "concurrency_limiting",
				Name:      "acquiring_seconds",
				Help:      "Histogram of time calls are rate limited (in seconds)",
				Buckets:   cfg.Prometheus.GRPCLatencyBuckets,
			},
			[]string{"system", "grpc_service", "grpc_method"},
		)
		inProgressMetric := prometheus.NewGaugeVec(
			prometheus.GaugeOpts{
				Namespace: "gitaly",
				Subsystem: "concurrency_limiting",
				Name:      "in_progress",
				Help:      "Gauge of number of concurrent in-progress calls",
			},
			[]string{"system", "grpc_service", "grpc_method"},
		)
		queuedMetric := prometheus.NewGaugeVec(
			prometheus.GaugeOpts{
				Namespace: "gitaly",
				Subsystem: "concurrency_limiting",
				Name:      "queued",
				Help:      "Gauge of number of queued calls",
			},
			[]string{"system", "grpc_service", "grpc_method"},
		)
		middleware.collect = func(metrics chan<- prometheus.Metric) {
			acquiringSecondsMetric.Collect(metrics)
			inProgressMetric.Collect(metrics)
			queuedMetric.Collect(metrics)
		}
		result := make(map[string]limiter.Limiter)
		for _, concurrency := range cfg.Concurrency {
			result[concurrency.RPC] = limiter.NewConcurrencyLimiter(
				perRPCLimits[concurrency.RPC],
				concurrency.MaxQueueSize,
				concurrency.MaxQueueWait.Duration(),
				limiter.NewPerRPCPromMonitor(
					"gitaly", concurrency.RPC,
					queuedMetric, inProgressMetric, acquiringSecondsMetric, middleware.requestsDroppedMetric,
				),
			)
		}
		// Set default for ReplicateRepository.
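		// If the configuration does not specify a limit for it, fall back to a
		// static per-repository limit of 1 with the queue size and queue wait
		// left at their zero values.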
		replicateRepositoryFullMethod := "/gitaly.RepositoryService/ReplicateRepository"
		if _, ok := result[replicateRepositoryFullMethod]; !ok {
			result[replicateRepositoryFullMethod] = limiter.NewConcurrencyLimiter(
				limiter.NewAdaptiveLimit("staticLimit", limiter.AdaptiveSetting{Initial: 1}),
				0,
				0,
				limiter.NewPerRPCPromMonitor(
					"gitaly", replicateRepositoryFullMethod,
					queuedMetric, inProgressMetric, acquiringSecondsMetric, middleware.requestsDroppedMetric,
				),
			)
		}
		middleware.methodLimiters = result
	}
}
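
// Illustrative configuration sketch (type and field names assumed from their
// use above, not confirmed against the config package): a static concurrency
// limit for one RPC, passed to WithConcurrencyLimiters before wiring up the
// middleware.
//
//	cfg := config.Cfg{
//		Concurrency: []config.Concurrency{
//			{
//				RPC:          "/gitaly.SmartHTTPService/PostUploadPack",
//				MaxPerRepo:   20,
//				MaxQueueSize: 100,
//			},
//		},
//	}
//	_, setup := WithConcurrencyLimiters(cfg)
//	_ = New(cfg, LimitConcurrencyByRepo, setup)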