internal/credsretriever/refreshing_cache.go (222 lines of code) (raw):
package credsretriever
import (
"context"
"fmt"
"math/rand"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"go.amzn.com/eks/eks-pod-identity-agent/internal/cache/expiring"
"go.amzn.com/eks/eks-pod-identity-agent/internal/cloud/eksauth"
"go.amzn.com/eks/eks-pod-identity-agent/internal/middleware/logger"
"go.amzn.com/eks/eks-pod-identity-agent/pkg/credentials"
"golang.org/x/time/rate"
)
// cachedCredentialRetriever is a credentials.CredentialRetriever that keeps the
// credentials obtained from a delegate in a cache keyed by service account
// token, and renews them through the cache's refresh callback.
type cachedCredentialRetriever struct {
// internalCache is where credentials are stored, it runs a janitor that evicts and refreshes
// entries once they expire. Key in the cache is the service token, values
// are of type cacheEntry.
internalCache *expiring.Cache[string, cacheEntry]
// internalActiveRequestCache tracks the active ongoing requests. Key in the cache is the service
// token, values are errors returned from the active requests. When a service token is in the
// internalActiveRequestCache, but not internalCache, it means an active request is ongoing,
// other requests to the same service token should wait for this active request.
internalActiveRequestCache *expiring.Cache[string, error]
// delegate is who we are actually getting the credentials from.
delegate credentials.CredentialRetriever
// credentialsRenewalTtl is the upper bound on the refresh TTL given to an
// entry when freshly fetched credentials are stored in the cache.
credentialsRenewalTtl time.Duration
// minCredentialTtl is the minimum remaining lifetime credentials need to have
// in order to store them and consider them valid, default is 15s.
minCredentialTtl time.Duration
// retryInterval is the least amount of time the cache will wait before
// retrying a renewal that did not succeed, default is 1m.
retryInterval time.Duration
// maxRetryJitter is the maximum amount of jitter that can be added on top of
// retryInterval when credentials are rescheduled for renewal.
maxRetryJitter time.Duration
// now is the internal clock that is used to get the current time. Useful for
// testing purposes.
now internalClock
// refreshRateLimiter slows down refreshes to avoid getting throttled by EKS Auth
// in case there is some sort of backlog of creds waiting to be refreshed.
refreshRateLimiter *rate.Limiter
}
// cacheEntry is the value stored in internalCache for a service account token.
type cacheEntry struct {
// requestLogCtx is the logging context built when the credentials were
// fetched; the renewal and eviction callbacks obtain their logger from it.
requestLogCtx context.Context
// originatingRequest is the request that produced these credentials. It is
// sent to the delegate again when the entry is renewed.
originatingRequest *credentials.EksCredentialsRequest
// credentials is the response returned by the delegate for originatingRequest.
credentials *credentials.EksCredentialsResponse
}
// internalClock is a function that returns the current time. The retriever is
// constructed with time.Now; having it as a field allows another clock to be
// substituted.
type internalClock func() time.Time
// Compile-time check that *cachedCredentialRetriever satisfies
// credentials.CredentialRetriever.
var _ credentials.CredentialRetriever = (*cachedCredentialRetriever)(nil)
// promCacheError counts errors hit while renewing cached credentials,
// partitioned by the "type" label.
var promCacheError = promauto.NewCounterVec(
	prometheus.CounterOpts{
		Name: "pod_identity_cache_errors",
		Help: "Removing credentials from cache, got non recoverable error",
	},
	[]string{"type"},
)

// promCacheState counts the outcomes of renewal attempts on cached
// credentials, partitioned by the "state" label.
var promCacheState = promauto.NewCounterVec(
	prometheus.CounterOpts{
		Name: "pod_identity_cache_state",
		Help: "The state of credential in cache",
	},
	[]string{"state"},
)
// Settings for callers that find another request for the same token already
// in flight.
const (
	// defaultActiveRequestRetries is the number of additional times the cache
	// is polled while an active request for the same token is outstanding.
	defaultActiveRequestRetries = 9
	// defaultActiveRequestWaitTime is the pause between two such polls.
	defaultActiveRequestWaitTime = 200 * time.Millisecond
)

// Defaults applied to the retriever and its credentials cache.
const (
	// defaultCleanupInterval sets how often we go over the cache to check if
	// there are expired credentials requiring renewal.
	defaultCleanupInterval = time.Minute
	// defaultMinCredentialTtl is the default for minCredentialTtl.
	defaultMinCredentialTtl = 15 * time.Second
	// defaultRetryInterval is the default for retryInterval.
	defaultRetryInterval = time.Minute
	// defaultMaxRetryJitter is the default for maxRetryJitter.
	defaultMaxRetryJitter = time.Minute
	// renewalTimeout bounds the context used for one renewal attempt.
	renewalTimeout = time.Minute
)
// CachedCredentialRetrieverOpts configures NewCachedCredentialRetriever.
type CachedCredentialRetrieverOpts struct {
// Delegate is the retriever asked for credentials when the cache cannot
// serve them. It is required: NewCachedCredentialRetriever panics when it
// is nil.
Delegate credentials.CredentialRetriever
// CredentialsRenewalTtl caps the refresh TTL assigned to freshly fetched
// credentials and is passed to the credentials cache on construction.
CredentialsRenewalTtl time.Duration
// MaxCacheSize is the size passed to the underlying LRU caches.
MaxCacheSize int
// RefreshQPS is used as both the rate and the burst of the limiter applied
// to renewals. NewCachedCredentialRetriever replaces values <= 0 with 3.
RefreshQPS int
// CleanupInterval is passed to the credentials cache on construction.
// NewCachedCredentialRetriever replaces values <= 0 with
// defaultCleanupInterval.
CleanupInterval time.Duration
}
// NewCachedCredentialRetriever builds a credential retriever that caches the
// credentials it gets from opts.Delegate for up to
// min(credentialsRenewalTtl, fetchedCredentialExpiration). Cached credentials
// keep being renewed until the association is removed and they are no longer
// needed.
//
// Non-positive CleanupInterval and RefreshQPS are replaced with
// defaultCleanupInterval and 3 respectively.
//
// It panics when opts.Delegate is nil, or when
// RefreshQPS * CredentialsRenewalTtl (in whole seconds) is smaller than half
// of MaxCacheSize.
func NewCachedCredentialRetriever(opts CachedCredentialRetrieverOpts) credentials.CredentialRetriever {
	if opts.Delegate == nil {
		panic("Delegate is not allowed to be empty")
	}

	if opts.CleanupInterval <= 0 {
		opts.CleanupInterval = defaultCleanupInterval
	}
	if opts.RefreshQPS <= 0 {
		opts.RefreshQPS = 3
	}

	// Number of refreshes the limiter lets through during one renewal TTL.
	renewalSeconds := opts.CredentialsRenewalTtl.Seconds()
	refreshesPerTtl := opts.RefreshQPS * int(renewalSeconds)
	if refreshesPerTtl >= opts.MaxCacheSize/2 {
		return newCachedCredentialRetriever(opts)
	}

	panic(fmt.Sprintf(
		"Refresh QPS is too small (%d) or credentials renewal to small (%0.2fs) to keep up with cache's size (%d)",
		opts.RefreshQPS, renewalSeconds, opts.MaxCacheSize))
}
// newCachedCredentialRetriever assembles the retriever from opts without
// validating them: it creates the credentials cache and the active-request
// cache, fills in the default TTL/retry settings, and registers the renewal
// and eviction callbacks on the credentials cache.
func newCachedCredentialRetriever(opts CachedCredentialRetrieverOpts) *cachedCredentialRetriever {
	qps := opts.RefreshQPS

	r := &cachedCredentialRetriever{
		delegate: opts.Delegate,
		internalCache: expiring.NewLru[string, cacheEntry](
			opts.MaxCacheSize, opts.CredentialsRenewalTtl, opts.CleanupInterval),
		internalActiveRequestCache: expiring.NewLru[string, error](opts.MaxCacheSize, 0, 0),
		credentialsRenewalTtl:      opts.CredentialsRenewalTtl,
		minCredentialTtl:           defaultMinCredentialTtl,
		retryInterval:              defaultRetryInterval,
		maxRetryJitter:             defaultMaxRetryJitter,
		now:                        time.Now,
		// The configured QPS doubles as the limiter's burst size.
		refreshRateLimiter: rate.NewLimiter(rate.Limit(qps), qps),
	}

	r.internalCache.OnRefresh(r.onCredentialRenewal)
	r.internalCache.OnEvicted(r.onCredentialEviction)
	return r
}
// GetIamCredentials returns credentials for request.ServiceAccountToken. A
// cached entry is returned when its credentials still have more than
// minCredentialTtl left; otherwise the credentials are fetched from the
// delegate and stored in the cache.
//
// While another call is already fetching credentials for the same token, this
// call polls the cache up to defaultActiveRequestRetries additional times,
// pausing defaultActiveRequestWaitTime between polls, and then falls through
// to calling the delegate itself. The pause is abandoned as soon as ctx is
// done, in which case an error wrapping ctx.Err() is returned.
//
// It returns an error when request is nil, when the service account token is
// empty, when ctx is done while waiting, or when the delegate call fails. The
// returned metadata is nil when the credentials are served from the cache.
func (r *cachedCredentialRetriever) GetIamCredentials(ctx context.Context,
	request *credentials.EksCredentialsRequest) (*credentials.EksCredentialsResponse, credentials.ResponseMetadata, error) {
	log := logger.FromContext(ctx)
	if request == nil {
		return nil, nil, fmt.Errorf("request to fetch credentials is empty, this is most likely a bug")
	}
	if request.ServiceAccountToken == "" {
		return nil, nil, fmt.Errorf("service account is empty, cannot fetch credentials without a valid one")
	}

	for i := 0; i <= defaultActiveRequestRetries; i++ {
		// Check if the request is in the cache, if it is, return it
		if val, ok := r.internalCache.Get(request.ServiceAccountToken); ok {
			if _, withinTtl := r.credentialsInEntryWithinValidTtl(val); withinTtl {
				log.WithField("cache-hit", 1).Tracef("Using cached credentials")
				return val.credentials, nil, nil
			}
			log.Info("Identified that entry in cache contains credentials with small ttl or invalid ttl, will be deleted")
			r.internalCache.Delete(request.ServiceAccountToken)
			break
		}

		if _, ok := r.internalActiveRequestCache.Get(request.ServiceAccountToken); !ok {
			// No active request, exit the loop to fetch from delegate
			break
		}
		if i > 0 {
			log.Infof("Waiting for active request with %v tries", i)
		}

		// Wait for the active request to finish caching into internalCache,
		// unless this was the last retry. The wait is interruptible so that a
		// cancelled caller does not keep blocking for the remaining retries.
		if i < defaultActiveRequestRetries {
			timer := time.NewTimer(defaultActiveRequestWaitTime)
			select {
			case <-ctx.Done():
				timer.Stop()
				return nil, nil, fmt.Errorf("waiting for active credentials request: %w", ctx.Err())
			case <-timer.C:
			}
		}
	}

	if _, ok := r.internalActiveRequestCache.Get(request.ServiceAccountToken); ok {
		log.Warnf("Failed to complete active request in %v tries", defaultActiveRequestRetries)
	}

	// Mark this token as having a request in flight for the duration of the
	// delegate call so that concurrent callers wait for the cached result.
	r.internalActiveRequestCache.Add(request.ServiceAccountToken, nil)
	defer r.internalActiveRequestCache.Delete(request.ServiceAccountToken)

	log.WithField("cache-hit", 0).Tracef("Could not find entry in cache, requesting creds from delegate")
	iamCredentials, metadata, err := r.callDelegateAndCache(ctx, request)
	if err != nil {
		return nil, nil, err
	}
	return iamCredentials.credentials, metadata, nil
}
// callDelegateAndCache fetches credentials for request from the delegate and,
// when they have more than minCredentialTtl of lifetime left, stores them in
// internalCache under the request's service account token. The entry is
// scheduled for refresh after the smaller of the credentials' remaining
// lifetime and credentialsRenewalTtl, and expires with the credentials.
//
// It returns the stored entry. The metadata result is always nil. An error is
// returned when the delegate fails or the fetched credentials do not have
// enough lifetime left; nothing is cached in that case.
func (r *cachedCredentialRetriever) callDelegateAndCache(ctx context.Context,
	request *credentials.EksCredentialsRequest) (cacheEntry, credentials.ResponseMetadata, error) {
	log := logger.FromContext(ctx)

	fetched, err := r.fetchCredentialsFromDelegate(ctx, request)
	if err != nil {
		return cacheEntry{}, nil, fmt.Errorf("error getting credentials to cache: %w", err)
	}

	remaining, usable := r.credentialsInEntryWithinValidTtl(fetched)
	if !usable {
		return cacheEntry{}, nil, fmt.Errorf("fetched credentials are expired or will expire within the next %0.2f seconds", remaining.Seconds())
	}

	// Refresh no later than the renewal TTL, and never after the credentials
	// themselves run out.
	refreshTtl := r.credentialsRenewalTtl
	if remaining < refreshTtl {
		refreshTtl = remaining
	}
	log.WithField("refreshTtl", refreshTtl).Infof("Storing creds in cache")

	// Another goroutine may have inserted or removed this key in the meantime;
	// that is fine, the cache is thread safe and this is a plain upsert.
	r.internalCache.SetWithRefreshExpire(request.ServiceAccountToken, fetched, refreshTtl, remaining)
	return fetched, nil, nil
}
// credentialsInEntryWithinValidTtl returns how long the credentials in
// newCacheEntry remain valid, measured against r.now(), and whether that
// remaining lifetime is strictly greater than r.minCredentialTtl.
func (r *cachedCredentialRetriever) credentialsInEntryWithinValidTtl(newCacheEntry cacheEntry) (time.Duration, bool) {
	expiresAt := newCacheEntry.credentials.Expiration.Time
	remaining := expiresAt.Sub(r.now())
	return remaining, remaining > r.minCredentialTtl
}
// fetchCredentialsFromDelegate asks the delegate for credentials and wraps
// the result in a cacheEntry together with the originating request and a
// logging context derived from ctx via logger.CloneToNewIfPresent.
//
// When the delegate reports metadata, its association id is added to that
// logging context under "association-id". The metadata may be nil (this
// package's own retriever returns nil metadata for cached credentials), in
// which case the field is left out instead of dereferencing nil.
//
// The delegate's error is returned unchanged.
func (r *cachedCredentialRetriever) fetchCredentialsFromDelegate(ctx context.Context,
	request *credentials.EksCredentialsRequest) (cacheEntry, error) {
	iamCredentials, metadata, err := r.delegate.GetIamCredentials(ctx, request)
	if err != nil {
		return cacheEntry{}, err
	}

	// The entry outlives the request, so its logging context is rooted in
	// context.Background() rather than in ctx.
	var requestLogCtx context.Context = logger.CloneToNewIfPresent(ctx, context.Background())
	if metadata != nil {
		requestLogCtx = logger.ContextWithField(requestLogCtx, "association-id", metadata.AssociationId())
	}

	return cacheEntry{
		originatingRequest: request,
		requestLogCtx:      requestLogCtx,
		credentials:        iamCredentials,
	}, nil
}
// onCredentialRenewal is registered as the refresh callback of internalCache
// and is invoked with the key and entry being refreshed.
//
// If the refresh rate limiter allows it, the originating request is sent to
// the delegate again (bounded by renewalTimeout) and the result replaces the
// cached entry. On an irrecoverable API error the entry is deleted. On any
// other error, or when rate limited, the existing credentials are put back
// into the cache as long as they have more than minCredentialTtl left, with a
// refresh scheduled after retryInterval plus a random jitter below
// maxRetryJitter (capped at the credentials' remaining lifetime). Credentials
// that are too old are not re-inserted.
func (r *cachedCredentialRetriever) onCredentialRenewal(key string, entry cacheEntry) {
	ctx, cancel := context.WithTimeout(
		logger.ContextWithField(entry.requestLogCtx, "from", "renewal-thread"), renewalTimeout)
	defer cancel()
	log := logger.FromContext(ctx)

	// Allow takes a token from the limiter when it returns true, so no further
	// Wait is needed; waiting as well would spend two tokens per refresh and
	// could block this callback.
	if r.refreshRateLimiter.Allow() {
		_, _, err := r.callDelegateAndCache(ctx, entry.originatingRequest)
		if err == nil {
			// if we retrieved the credentials successfully, exit we don't need to do anything else
			promCacheState.WithLabelValues("hit").Inc()
			return
		}
		if eksauth.IsIrrecoverableApiError(err) {
			log.Infof("Removing credentials from cache, got non recoverable error: %s", err.Error())
			promCacheError.WithLabelValues("NonRecoverable").Inc()
			r.internalCache.Delete(entry.originatingRequest.ServiceAccountToken)
			return
		}
		log.Infof("Could not renew, will try to keep existing creds. Error is recoverable: %s", err.Error())
	} else {
		log.Infof("Rate limited! Will try to keep creds locally")
	}

	// if there was an error, try to keep the old credentials in the agent if they haven't expired
	oldCredsDuration := entry.credentials.Expiration.Time.Sub(r.now())
	if oldCredsDuration <= r.minCredentialTtl {
		promCacheState.WithLabelValues("evicted").Inc()
		log.Infof("Evicting credentials since they are too old")
		return
	}

	calculatedRetryInterval := r.retryInterval
	// rand.Int63n panics for n <= 0, so only add jitter when one is configured.
	if r.maxRetryJitter > 0 {
		calculatedRetryInterval += time.Duration(rand.Int63n(int64(r.maxRetryJitter)))
	}
	newRefreshTtl := minDuration(oldCredsDuration, calculatedRetryInterval)
	log.WithField("ttl", newRefreshTtl).
		Infof("Credentials still valid for at least %0.2fs, keeping them will try again after ttl expires", oldCredsDuration.Seconds())
	r.internalCache.SetWithRefreshExpire(key, entry, newRefreshTtl, oldCredsDuration)
}
// onCredentialEviction is registered as the eviction callback of
// internalCache. It only logs the eviction, using the logger carried by the
// entry's own logging context.
func (r *cachedCredentialRetriever) onCredentialEviction(key string, entry cacheEntry) {
	logger.FromContext(entry.requestLogCtx).Infof("Credentials evicted")
}
// minDuration returns the smaller of a and b.
func minDuration(a time.Duration, b time.Duration) time.Duration {
	if b < a {
		return b
	}
	return a
}