utils/dedup/request_cache.go (140 lines of code) (raw):
// Copyright (c) 2016-2019 Uber Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dedup
import (
"errors"
"sync"
"time"
"github.com/andres-erbsen/clock"
)
// RequestCacheConfig defines RequestCache configuration.
//
// Unset (zero) fields are replaced with defaults by applyDefaults.
type RequestCacheConfig struct {
	// NotFoundTTL is how long an error matched by the cache's not-found
	// ErrorMatcher stays cached.
	NotFoundTTL time.Duration `yaml:"not_found_ttl"`

	// ErrorTTL is how long any other error stays cached.
	ErrorTTL time.Duration `yaml:"error_ttl"`

	// CleanupInterval is the minimum time between sweeps of expired errors.
	CleanupInterval time.Duration `yaml:"cleanup_interval"`

	// NumWorkers bounds how many requests may run concurrently.
	NumWorkers int `yaml:"num_workers"`

	// BusyTimeout is how long to wait for a free worker before giving up.
	BusyTimeout time.Duration `yaml:"busy_timeout"`
}

// applyDefaults fills in defaults for unset fields. A zero duration counts as
// unset. NumWorkers is additionally defaulted when negative: it is used as a
// channel capacity, and a negative capacity panics at construction time.
func (c *RequestCacheConfig) applyDefaults() {
	// TODO(codyg): If the cached error TTL is lower than the interval in which
	// clients are polling a 202 endpoint, then it is possible that the client
	// will never hit the actual error because it expires in between requests.
	if c.NotFoundTTL == 0 {
		c.NotFoundTTL = 15 * time.Second
	}
	if c.ErrorTTL == 0 {
		c.ErrorTTL = 15 * time.Second
	}
	if c.CleanupInterval == 0 {
		c.CleanupInterval = 5 * time.Second
	}
	if c.NumWorkers <= 0 {
		c.NumWorkers = 10000
	}
	if c.BusyTimeout == 0 {
		c.BusyTimeout = 5 * time.Second
	}
}
// ErrRequestPending is returned when a request is already running under the
// requested id.
var ErrRequestPending = errors.New("request pending")

// ErrWorkersBusy is returned when no worker became available to run a request.
var ErrWorkersBusy = errors.New("no workers available to handle request")
// cachedError pairs an error with the instant at which it stops being valid.
type cachedError struct {
	err       error
	expiresAt time.Time
}

// expired reports whether now is strictly later than the expiry instant; an
// entry is still valid at exactly expiresAt.
func (e *cachedError) expired(now time.Time) bool {
	deadline := e.expiresAt
	return deadline.Before(now)
}
type (
	// Request is a unit of work to run; a non-nil return value is the
	// request's failure.
	Request func() error

	// ErrorMatcher classifies an error, letting RequestCache recognize
	// user-defined error categories.
	ErrorMatcher func(error) bool
)
// RequestCache tracks pending requests and caches errors for configurable TTLs.
// It is used to prevent request duplication and DDOS-ing external components.
// Each request is represented by an arbitrary id string determined by the user.
//
// Instances are created with NewRequestCache, which initializes the maps and
// the worker channel.
type RequestCache struct {
// config holds the settings after defaults have been applied.
config RequestCacheConfig
// clk is the time source used for expiry, cleanup, and the busy timeout.
clk clock.Clock
mu sync.Mutex // Protects access to the following fields:
// pending holds the ids that are reserved or currently running.
pending map[string]bool
// errors holds the most recent cached error per id with its expiry time.
errors map[string]*cachedError
// lastClean is when expired entries were last swept from errors.
lastClean time.Time
// isNotFound selects NotFoundTTL over ErrorTTL when caching an error.
isNotFound ErrorMatcher
// numWorkers is a buffered channel with capacity config.NumWorkers. One
// element is sent per running request and received when it finishes. It is
// sent to and received from without holding mu.
numWorkers chan struct{}
}
// NewRequestCache builds a RequestCache from config, after filling in defaults
// for unset fields, using clk as its time source. The returned cache treats no
// error as "not found" until SetNotFound is called.
func NewRequestCache(config RequestCacheConfig, clk clock.Clock) *RequestCache {
	config.applyDefaults()

	rc := new(RequestCache)
	rc.config = config
	rc.clk = clk
	rc.pending = make(map[string]bool)
	rc.errors = make(map[string]*cachedError)
	rc.lastClean = clk.Now()
	rc.isNotFound = func(error) bool { return false }
	rc.numWorkers = make(chan struct{}, config.NumWorkers)
	return rc
}
// SetNotFound sets the ErrorMatcher for activating the configured NotFoundTTL
// for errors returned by Request functions.
//
// A nil matcher resets the cache to matching no errors. Storing nil directly
// would make the next failing request call a nil function while the cache
// lock is held, panicking the worker goroutine.
func (c *RequestCache) SetNotFound(m ErrorMatcher) {
	if m == nil {
		m = func(error) bool { return false }
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	c.isNotFound = m
}
// Start runs r in a new goroutine under the given id and returns immediately.
// Any error r returns is cached for the configured TTL.
//
// Start returns ErrRequestPending if a request is already in flight for id, a
// still-valid cached error for id if one exists, or ErrWorkersBusy if no worker
// slot frees up in time. In every one of those cases r is not run.
func (c *RequestCache) Start(id string, r Request) error {
	err := c.reserve(id)
	if err != nil {
		return err
	}

	err = c.reserveWorker()
	if err != nil {
		// Undo the reservation so a later Start for this id can proceed.
		c.release(id)
		return err
	}

	go func() {
		defer c.releaseWorker()
		c.run(id, r)
	}()
	return nil
}
// reserve marks id as pending so that no other request may start under it.
//
// It returns ErrRequestPending if id is already pending, or the cached error
// for id if one exists and has not expired. On success it returns nil and the
// caller owns the reservation until release or error is called for id.
//
// As a side effect, expired cached errors are swept whenever more than
// CleanupInterval has passed since the previous sweep.
func (c *RequestCache) reserve(id string) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	// Read the clock once so the sweep and the lookup below judge expiry
	// against the same instant, and so the clock is not re-read for every
	// map entry while the lock is held.
	now := c.clk.Now()

	// Periodically remove expired errors.
	if now.Sub(c.lastClean) > c.config.CleanupInterval {
		for key, cerr := range c.errors {
			if cerr.expired(now) {
				delete(c.errors, key)
			}
		}
		c.lastClean = now
	}

	if c.pending[id] {
		return ErrRequestPending
	}
	if cerr, ok := c.errors[id]; ok && !cerr.expired(now) {
		return cerr.err
	}
	c.pending[id] = true
	return nil
}
// run executes r and then clears id's pending state: a nil result simply
// releases the id, while a non-nil result is cached via error.
func (c *RequestCache) run(id string, r Request) {
	err := r()
	if err == nil {
		c.release(id)
		return
	}
	c.error(id, err)
}
// release removes id from the pending set without recording any error.
func (c *RequestCache) release(id string) {
	c.mu.Lock()
	delete(c.pending, id)
	c.mu.Unlock()
}
// error removes id from the pending set and caches err under id. The entry
// expires after NotFoundTTL when the not-found matcher accepts err, and after
// ErrorTTL otherwise.
func (c *RequestCache) error(id string, err error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	ttl := c.config.ErrorTTL
	if c.isNotFound(err) {
		ttl = c.config.NotFoundTTL
	}

	delete(c.pending, id)
	expiry := c.clk.Now().Add(ttl)
	c.errors[id] = &cachedError{err: err, expiresAt: expiry}
}
// reserveWorker claims a worker slot by sending on numWorkers. It waits up to
// BusyTimeout for a slot and returns ErrWorkersBusy if none frees up in time.
func (c *RequestCache) reserveWorker() error {
	// Fast path: take a free slot without arming a timer. This avoids
	// allocating a BusyTimeout timer on every call when capacity is available,
	// and guarantees a free slot always wins even if the timeout would fire
	// immediately.
	select {
	case c.numWorkers <- struct{}{}:
		return nil
	default:
	}

	// Slow path: all slots are taken, so wait for one to be released.
	select {
	case c.numWorkers <- struct{}{}:
		return nil
	case <-c.clk.After(c.config.BusyTimeout):
		return ErrWorkersBusy
	}
}
// releaseWorker frees a worker slot by receiving one element from numWorkers.
// The receive blocks while the channel is empty, so it must be paired with a
// prior successful reserveWorker.
func (c *RequestCache) releaseWorker() {
<-c.numWorkers
}