registry/handlers/ratelimiter.go (295 lines of code) (raw):

package handlers

import (
	"context"
	"encoding/base64"
	"errors"
	"fmt"
	"net"
	"net/http"
	"net/netip"
	"slices"
	"sort"
	"strings"
	"time"

	"github.com/docker/distribution/configuration"
	dcontext "github.com/docker/distribution/context"
	"github.com/docker/distribution/registry/api/errcode"
	"github.com/docker/distribution/registry/ratelimiter"
	"github.com/hashicorp/go-multierror"
	"github.com/redis/go-redis/v9"
	"github.com/sirupsen/logrus"
)

const (
	// componentKey/componentName tag rate limiter log entries.
	componentKey  = "component"
	componentName = "registry.rate_limiter"

	// matchTypeIP is the only supported match type. matchTypeIPKey is the key
	// template for it; %s receives the base64-encoded IPv4 address or IPv6
	// /64 prefix of the client.
	matchTypeIP    = "ip"
	matchTypeIPKey = `registry:api:{rate-limit:ip:%s}`

	// Response headers set by the rate limiter middleware.
	headerXRateLimitRemaining  = "X-RateLimit-Remaining"
	headerXRateLimitWarning    = "X-RateLimit-Warning"
	headerXRateLimitPercentage = "X-RateLimit-Percentage"
	headerRetryAfter           = "Retry-After"

	// IPV6PrefixLength is the prefix length, in bits, that IPv6 client
	// addresses are truncated to before being used as a rate limit key.
	IPV6PrefixLength = 64
)

// Accepted values for the corresponding limiter configuration fields.
var (
	validMatchTypes  = []string{matchTypeIP}
	validWarnActions = []string{"none", "log"}
	validHardActions = []string{"none", "log", "block"}
	validPeriods     = []string{"second", "minute", "hour"}
)

// configureRateLimiters validates the rate limiter configuration and populates
// app.rateLimiters with one limiter per configured entry, ordered by ascending
// precedence. It returns the validation error, if any, and leaves
// app.rateLimiters untouched when the feature is disabled or no limiter is
// configured.
func (app *App) configureRateLimiters(redisClient redis.UniversalClient, config *configuration.RateLimiter) error {
	l := dcontext.GetLogger(app.Context)

	orderedLimiters, err := parseLimitersConfig(config)
	if err != nil {
		return err
	}
	if len(orderedLimiters) == 0 {
		return nil
	}

	app.rateLimiters = make([]RateLimiter, 0, len(orderedLimiters))
	for _, orderedLimiter := range orderedLimiters {
		cfg := orderedLimiter.Limiter
		limiter := ratelimiter.New(redisClient, cfg)

		// Log the configured limiter name; the range index is not the name.
		l.WithFields(logrus.Fields{
			"name":           orderedLimiter.Name,
			"description":    cfg.Description,
			"log_only":       cfg.LogOnly,
			"match_type":     cfg.Match.Type,
			"rate":           cfg.Limit.Rate,
			"burst":          cfg.Limit.Burst,
			"period":         cfg.Limit.Period,
			"warn_threshold": cfg.Action.WarnThreshold,
			"warn_action":    cfg.Action.WarnAction,
			"hard_action":    cfg.Action.HardAction,
		}).Info("Configured rate limiter")

		app.rateLimiters = append(app.rateLimiters, limiter)
	}

	return nil
}

// OrderedLimiter is a helper struct to sort limiters by precedence.
type OrderedLimiter struct {
	Name    string
	Limiter *configuration.Limiter
}

// parseLimitersConfig validates every configured limiter and returns them
// sorted by ascending precedence. Limiters with equal precedence keep their
// configuration order.
//
// It returns (nil, nil) when rateLimiterCfg is nil or disabled, and a
// *multierror.Error aggregating all validation failures when at least one
// limiter is invalid. Each returned entry points at its own normalized copy of
// the limiter configuration; the input is not modified.
func parseLimitersConfig(rateLimiterCfg *configuration.RateLimiter) ([]OrderedLimiter, error) {
	if rateLimiterCfg == nil || !rateLimiterCfg.Enabled {
		return nil, nil
	}

	orderedLimiters := make([]OrderedLimiter, 0, len(rateLimiterCfg.Limiters))
	mError := new(multierror.Error)

	for _, limiterConfig := range rateLimiterCfg.Limiters {
		// Take an explicit per-iteration copy: validateLimiter normalizes the
		// config in place and we keep a pointer to it. Pointing at the range
		// variable itself would alias every entry on toolchains older than
		// Go 1.22.
		limiterCfg := limiterConfig

		if err := validateLimiter(&limiterCfg); err != nil {
			mError = multierror.Append(mError, err)
			continue
		}

		orderedLimiters = append(orderedLimiters, OrderedLimiter{
			Name:    limiterCfg.Name,
			Limiter: &limiterCfg,
		})
	}

	if err := mError.ErrorOrNil(); err != nil {
		return nil, err
	}

	// Sort limiters by precedence in ascending order. The stable sort makes
	// the order of equal-precedence limiters deterministic.
	sort.SliceStable(orderedLimiters, func(i, j int) bool {
		return orderedLimiters[i].Limiter.Precedence < orderedLimiters[j].Limiter.Precedence
	})

	return orderedLimiters, nil
}

// validateLimiter checks a single limiter configuration and normalizes it in
// place: Match.Type is lower-cased and trimmed, and Limit.PeriodDuration is
// derived from Limit.Period when the period is valid. All violations are
// collected and returned together as a *multierror.Error; nil means valid.
func validateLimiter(c *configuration.Limiter) error {
	mError := new(multierror.Error)

	if c.Name == "" {
		mError = multierror.Append(mError, fmt.Errorf("limiter name cannot be empty"))
	}

	if c.Precedence <= 0 {
		mError = multierror.Append(mError, fmt.Errorf("limiter precedence must be a positive integer"))
	}

	c.Match.Type = strings.TrimSpace(strings.ToLower(c.Match.Type))
	if !slices.Contains(validMatchTypes, c.Match.Type) {
		mError = multierror.Append(mError, fmt.Errorf("match.type must be one of: %+v", validMatchTypes))
	}

	if c.Limit.Rate <= 0 {
		mError = multierror.Append(mError, fmt.Errorf("rate must be a positive integer"))
	}

	if c.Limit.Burst <= 0 {
		mError = multierror.Append(mError, fmt.Errorf("burst must be a positive integer"))
	}

	if !slices.Contains(validPeriods, c.Limit.Period) {
		mError = multierror.Append(mError, fmt.Errorf("period must be one of: %+v", validPeriods))
	} else {
		// The period is known to be one of validPeriods here, so no default
		// branch is needed.
		switch c.Limit.Period {
		case "second":
			c.Limit.PeriodDuration = time.Second
		case "minute":
			c.Limit.PeriodDuration = time.Minute
		case "hour":
			c.Limit.PeriodDuration = time.Hour
		}
	}

	if !slices.Contains(validWarnActions, c.Action.WarnAction) {
		mError = multierror.Append(mError, fmt.Errorf("action.warn_action must be one of: %+v", validWarnActions))
	}

	if c.Action.WarnThreshold < 0.0 || c.Action.WarnThreshold > 1.0 {
		mError = multierror.Append(mError, fmt.Errorf("action.warn_threshold must be between 0.0 and 1.0"))
	}

	if !slices.Contains(validHardActions, c.Action.HardAction) {
		mError = multierror.Append(mError, fmt.Errorf("action.hard_action must be one of: %+v", validHardActions))
	}

	return mError.ErrorOrNil()
}

// RateLimiter represents a rate limiter that can be used to control the rate of requests.
type RateLimiter interface {
	// Allowed evaluates the limiter for the given key and returns the
	// resulting state, or an error if the limiter could not be evaluated.
	Allowed(ctx context.Context, key string) (*ratelimiter.Result, error)
	// Config returns the configuration this limiter was built from.
	Config() *configuration.Limiter
}

// rateLimiterMiddleware wraps next with rate limiting. When rate limiting is
// disabled or no limiter is configured the request is passed straight through.
// Otherwise every limiter is evaluated in order and the first one that blocks
// the request ends processing; next is only called if none blocks.
func (app *App) rateLimiterMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !app.Config.RateLimiter.Enabled || len(app.rateLimiters) == 0 {
			next.ServeHTTP(w, r)
			return
		}

		l := dcontext.GetLogger(r.Context())
		ctx := app.context(w, r)

		// Process each limiter in order of precedence.
		for _, limiter := range app.rateLimiters {
			if processLimiter(ctx, w, r, limiter, l) {
				// Request was blocked, don't continue to next limiter or handler.
				return
			}
		}

		next.ServeHTTP(w, r)
	})
}

// processLimiter evaluates a single rate limiter for a request and sets the
// rate limit response headers. It returns true if a response has already been
// written and the request must not be processed further: either the limiter
// returned an error, or the limit was exceeded and the limiter is configured
// to block (hard action "block" and not log-only).
func processLimiter(ctx *Context, w http.ResponseWriter, r *http.Request, limiter RateLimiter, l dcontext.Logger) bool {
	cfg := limiter.Config()

	// Get the rate limit key based on match type.
	key, ok := getRateLimitKey(r, cfg.Match.Type, l)
	if !ok {
		return false // Skip this limiter.
	}

	// Check if request is allowed by this limiter.
	result, err := limiter.Allowed(r.Context(), key)
	if err != nil {
		serveErrorJSON(w, err, ctx, l)
		return true // Block the request on error.
	}

	// Always set basic rate limit headers.
	w.Header().Set(headerXRateLimitRemaining, fmt.Sprintf("%d", result.Remaining))

	// Only advertise Retry-After when there is something to wait for.
	if result.RetryAfter > 0 {
		w.Header().Set(headerRetryAfter, rateLimitRetryAfterValue(result.RetryAfter))
	}

	// Set warning headers when the warning threshold is reached.
	if thresholdReached, usagePercentage := checkWarningThreshold(result, cfg, l); thresholdReached {
		w.Header().Set(headerXRateLimitWarning, "true")
		w.Header().Set(headerXRateLimitPercentage, fmt.Sprintf("%0.2f", usagePercentage*100))
	}

	// Check if rate limit exceeded.
	if result.Allowed <= 0 {
		logRateLimitedRequest(ctx, r, "request blocked: rate limit exceeded", cfg, result)

		if !cfg.LogOnly && cfg.Action.HardAction == "block" {
			blockRateLimitedRequest(w, r, result, cfg.Match.Type, ctx, l)
			return true // Request blocked.
		}
	}

	return false // Continue to next limiter.
}

// rateLimitRetryAfterValue formats d as a Retry-After header value. The header
// carries a non-negative integer number of seconds, so the duration is rounded
// up to the next whole second (never telling a client to retry too early) and
// non-positive durations yield "0".
func rateLimitRetryAfterValue(d time.Duration) string {
	if d <= 0 {
		return "0"
	}

	secs := int64(d / time.Second)
	if d%time.Second != 0 {
		secs++
	}

	return fmt.Sprintf("%d", secs)
}

// getRateLimitKey determines the key to use for rate limiting based on match
// type. The boolean is false when the match type is not supported, in which
// case a warning is logged and the limiter should be skipped.
func getRateLimitKey(r *http.Request, matchType string, l dcontext.Logger) (string, bool) {
	switch matchType {
	case matchTypeIP:
		return fmt.Sprintf(matchTypeIPKey, encodeIPBase64(GetIPV4orIPV6Prefix(r.RemoteAddr))), true
	default:
		l.Warnf("rate_limiter unsupported match type: %s, skipping", matchType)
		return "", false
	}
}

// checkWarningThreshold computes how much of the limit has been used and
// reports whether the configured warning threshold has been reached, logging
// according to the configured warn action when it has.
//
// It returns whether the threshold was reached and the usage as a fraction
// (1.0 meaning the limit is exhausted). With a threshold of 0 or less nothing
// is logged and the usage is reported as 0; the boolean is then true only for
// requests that are over the limit, so that headers still get set for them.
func checkWarningThreshold(result *ratelimiter.Result, cfg *configuration.Limiter, l dcontext.Logger) (bool, float64) {
	warnThreshold := cfg.Action.WarnThreshold

	// Special case for threshold 0.0 - no warnings needed.
	if warnThreshold <= 0 {
		return result.Allowed <= 0, 0
	}

	var usagePercentage float64
	switch {
	case result.Allowed <= 0:
		// Already over limit.
		usagePercentage = 1.0
	case result.Remaining <= 0:
		// Still allowed but nothing remaining: treat the regular rate as
		// fully used and derive the used share of the burst from
		// result.Allowed.
		burstCapacity := float64(cfg.Limit.Burst)
		regularCapacity := float64(cfg.Limit.Rate)
		totalCapacity := regularCapacity + burstCapacity

		burstUsed := burstCapacity - float64(result.Allowed)
		totalUsed := regularCapacity + burstUsed
		usagePercentage = totalUsed / totalCapacity
	default:
		// Normal case - remaining > 0. Higher values mean more usage.
		usagePercentage = 1.0 - (float64(result.Remaining) / float64(cfg.Limit.Burst))
	}

	if usagePercentage < warnThreshold {
		return false, usagePercentage
	}

	logger := l.WithFields(logrus.Fields{
		"name":           cfg.Name,
		"description":    cfg.Description,
		"warn_threshold": warnThreshold,
		"usage":          usagePercentage,
	})

	// Only the "log" action warns; "none" and anything else is debug-only.
	if cfg.Action.WarnAction == "log" {
		logger.Warn("rate_limiter reached threshold")
	} else {
		logger.Debug("rate_limiter reached threshold but no action will be taken")
	}

	return true, usagePercentage
}

// blockRateLimitedRequest handles blocking a request that exceeded rate
// limits: it sets the remaining and Retry-After headers and serves a
// "too many requests" error whose detail describes the client IP (or IPv6
// prefix), the match type, the retry delay and the remaining allowance.
func blockRateLimitedRequest(w http.ResponseWriter, r *http.Request, result *ratelimiter.Result, matchType string, ctx *Context, l dcontext.Logger) {
	w.Header().Set(headerXRateLimitRemaining, fmt.Sprintf("%d", result.Remaining))
	w.Header().Set(headerRetryAfter, rateLimitRetryAfterValue(result.RetryAfter))

	detail := map[string]string{
		"ip":          GetIPV4orIPV6Prefix(r.RemoteAddr),
		"limit":       matchType,
		"retry_after": result.RetryAfter.String(),
		"remaining":   fmt.Sprintf("%d", result.Remaining),
	}

	serveErrorJSON(w, errcode.ErrorCodeTooManyRequests.WithDetail(detail), ctx, l)
}

// serveErrorJSON handles serving an error response as JSON. Errors that are
// not already an errcode.Error are converted with errcode.FromUnknownError.
// A failure to write the response is logged, not returned.
func serveErrorJSON(w http.ResponseWriter, err error, ctx *Context, l dcontext.Logger) {
	var errorToServe errcode.Error
	if !errors.As(err, &errorToServe) {
		errorToServe = errcode.FromUnknownError(err)
	}

	if err := errcode.ServeJSON(w, errorToServe); err != nil {
		l.Errorf("error serving error json: %v (from %v)", err, ctx.Errors)
	}
}

// logRateLimitedRequest logs, at info level, a request that exceeded a rate
// limit. The request path is logged instead of the full URL to reduce log
// size; retry_after tells when the limit resets and action tells whether the
// limiter blocks or only logs.
func logRateLimitedRequest(ctx *Context, r *http.Request, msg string, limiterConfig *configuration.Limiter, result *ratelimiter.Result) {
	dcontext.GetLogger(ctx).WithFields(logrus.Fields{
		componentKey:  componentName,
		"name":        limiterConfig.Name,
		"log_only":    limiterConfig.LogOnly,
		"method":      r.Method,
		"path":        r.URL.Path,
		"source_ip":   GetIPV4orIPV6Prefix(r.RemoteAddr),
		"retry_after": result.RetryAfter,
		"action":      limiterConfig.Action.HardAction,
	}).Info(msg)
}

// GetIPV4orIPV6Prefix returns either the full IPv4 address or the /64 prefix
// of the IPv6 address from the provided remote address, without the port.
//
// For IPv4 it returns the address as given. IPv4-mapped IPv6 addresses
// (::ffff:a.b.c.d) are treated as IPv4 and returned in dotted form; taking
// their /64 prefix would collapse every such client into the single ::/64
// bucket. For any other IPv6 address it returns the /64 prefix in CIDR
// notation. Input that cannot be parsed as an IP address is returned with only
// the port (if any) removed.
func GetIPV4orIPV6Prefix(remoteAddr string) string {
	remoteIP, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		// No port present (or not host:port at all): use the input as is.
		remoteIP = remoteAddr
	}

	addr, err := netip.ParseAddr(remoteIP)
	if err != nil {
		return remoteIP
	}

	switch {
	case addr.Is4In6():
		return addr.Unmap().String()
	case addr.Is4():
		return remoteIP
	case addr.Is6():
		ipv6Prefix, err := addr.Prefix(IPV6PrefixLength)
		if err != nil {
			return remoteIP
		}
		return ipv6Prefix.String()
	}

	return remoteIP
}

// encodeIPBase64 returns the standard base64 encoding of ip.
func encodeIPBase64(ip string) string {
	return base64.StdEncoding.EncodeToString([]byte(ip))
}