internal/ratelimiter/ratelimiter.go (149 lines of code) (raw):
package ratelimiter
import (
"crypto/tls"
"net"
"net/http"
"net/netip"
"time"
"github.com/3th1nk/cidr"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/time/rate"
"gitlab.com/gitlab-org/gitlab-pages/internal/lru"
"gitlab.com/gitlab-org/gitlab-pages/internal/request"
)
const (
// based on an avg ~4,000 unique IPs per minute
// https://log.gprd.gitlab.net/app/lens#/edit/f7110d00-2013-11ec-8c8e-ed83b5469915?_g=h@e78830b
DefaultSourceIPCacheSize = 5000
// we have less than 4000 different hosts per minute
// https://log.gprd.gitlab.net/app/dashboards#/view/d52ab740-61a4-11ec-b20d-65f14d890d9b?_a=(viewMode:edit)&_g=h@42b0d52
DefaultDomainCacheSize = 4000
)
// Option function to configure a RateLimiter
type Option func(*RateLimiter)
// KeyFunc returns unique identifier for the subject of rate limit(e.g. client IP or domain)
type KeyFunc func(*http.Request) string
// TLSKeyFunc is used by GetCertificateMiddleware to identify the subject of rate limit (client IP or SNI servername)
type TLSKeyFunc func(*tls.ClientHelloInfo) string
// RateLimiter holds an LRU cache of elements to be rate limited.
// It uses "golang.org/x/time/rate" as its Token Bucket rate limiter per source IP entry.
// See example https://www.fatalerrors.org/a/design-and-implementation-of-time-rate-limiter-for-golang-standard-library.html
// It also holds a now function that can be mocked in unit tests.
type RateLimiter struct {
name string
now func() time.Time
keyFunc KeyFunc
tlsKeyFunc TLSKeyFunc
limitPerSecond float64
burstSize int
blockedCount *prometheus.GaugeVec
cache *lru.Cache
cacheOptions []lru.Option
closeConnection bool
bypassCIDRs []cidr.CIDR
}
// New creates a new RateLimiter with default values that can be configured via Option functions
func New(name string, opts ...Option) *RateLimiter {
rl := &RateLimiter{
name: name,
now: time.Now,
keyFunc: request.GetIPV4orIPV6PrefixWithoutPort,
}
for _, opt := range opts {
opt(rl)
}
if rl.limitPerSecond > 0.0 {
rl.cache = lru.New(name, rl.cacheOptions...)
}
return rl
}
// Stop stops the RateLimiter and releases any resources associated with it.
// It should be called when the RateLimiter is no longer needed to ensure proper cleanup.
// This is used by tests
func (rl *RateLimiter) Stop() {
if rl.cache != nil {
rl.cache.Stop()
}
}
// WithNow replaces the RateLimiter now function
func WithNow(now func() time.Time) Option {
return func(rl *RateLimiter) {
rl.now = now
}
}
// WithLimitPerSecond allows configuring limit per second for RateLimiter
func WithLimitPerSecond(limit float64) Option {
return func(rl *RateLimiter) {
rl.limitPerSecond = limit
}
}
// WithBurstSize configures burst per keyFunc value for the RateLimiter
func WithBurstSize(burst int) Option {
return func(rl *RateLimiter) {
rl.burstSize = burst
}
}
// WithBlockedCountMetric configures metric reporting how many requests were blocked
func WithBlockedCountMetric(m *prometheus.GaugeVec) Option {
return func(rl *RateLimiter) {
rl.blockedCount = m
}
}
// WithCacheMaxSize configures cache size for ratelimiter
func WithCacheMaxSize(size int64) Option {
return func(rl *RateLimiter) {
rl.cacheOptions = append(rl.cacheOptions, lru.WithMaxSize(size))
}
}
// WithCachedEntriesMetric configures metric reporting how many keys are currently stored in
// the rate-limiter cache
func WithCachedEntriesMetric(m *prometheus.GaugeVec) Option {
return func(rl *RateLimiter) {
rl.cacheOptions = append(rl.cacheOptions, lru.WithCachedEntriesMetric(m))
}
}
// WithCachedRequestsMetric configures metric for how many times we access cache
func WithCachedRequestsMetric(m *prometheus.CounterVec) Option {
return func(rl *RateLimiter) {
rl.cacheOptions = append(rl.cacheOptions, lru.WithCachedRequestsMetric(m))
}
}
// WithKeyFunc configures keyFunc
func WithKeyFunc(f KeyFunc) Option {
return func(rl *RateLimiter) {
rl.keyFunc = f
}
}
// WithCloseConnection configures closeConnection bool whether to send connection close header or not
func WithCloseConnection(closeCon bool) Option {
return func(rl *RateLimiter) {
rl.closeConnection = closeCon
}
}
// WithBypassCIDRs configures the RateLimiter with bypass CIDRs
func WithBypassCIDRs(bypassCIDRs []cidr.CIDR) Option {
return func(rl *RateLimiter) {
if len(bypassCIDRs) > 0 {
rl.bypassCIDRs = bypassCIDRs
}
}
}
func TLSHostnameKey(info *tls.ClientHelloInfo) string {
return info.ServerName
}
func TLSClientIPKey(info *tls.ClientHelloInfo) string {
remoteAddr := info.Conn.RemoteAddr().String()
remoteAddr, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return remoteAddr
}
return request.GetIPV4orIPV6Prefix(remoteAddr)
}
func WithTLSKeyFunc(keyFunc TLSKeyFunc) Option {
return func(rl *RateLimiter) {
rl.tlsKeyFunc = keyFunc
}
}
func (rl *RateLimiter) limiter(key string) *rate.Limiter {
limiterI, _ := rl.cache.FindOrFetch(key, key, func() (interface{}, error) {
return rate.NewLimiter(rate.Limit(rl.limitPerSecond), rl.burstSize), nil
})
return limiterI.(*rate.Limiter)
}
func (rl *RateLimiter) shouldBypassRateLimit(key string) bool {
// key is either the full IPv4 address or the /64 prefix of the IPv6 address
prefix, err := netip.ParsePrefix(key)
if err == nil {
key = prefix.Addr().String()
}
for _, c := range rl.bypassCIDRs {
if c.Contains(key) {
return true
}
}
return false
}
// requestAllowed checks if request is within the rate-limit
func (rl *RateLimiter) requestAllowed(r *http.Request) bool {
rateLimitedKey := rl.keyFunc(r)
return rl.allowed(rateLimitedKey, request.GetIPV4orIPV6PrefixWithoutPort(r))
}
func (rl *RateLimiter) allowed(rateLimitedKey string, sourceIP string) bool {
limiter := rl.limiter(rateLimitedKey)
if rl.shouldBypassRateLimit(sourceIP) {
return true
}
// AllowN allows us to use the rl.now function, so we can test this more easily.
return limiter.AllowN(rl.now(), 1)
}