internal/conn/storage/credcache.go (146 lines of code) (raw):

package storage import ( "context" "errors" "fmt" "log/slog" "sync/atomic" "time" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/sas" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/service" "github.com/Azure/retry/exponential" ) // credData is the data stored in the credCache. type credData struct { cred *service.UserDelegationCredential expires time.Time err error } // getCred is an interface for getting a user delegation credential. I am violating go naming // conventions because GetUserDelegationCredential is already violating it, and I want a shorter name. // This is implemented by *service.Client. type getCreder interface { GetUserDelegationCredential(ctx context.Context, info service.KeyInfo, o *service.GetUserDelegationCredentialOptions) (*service.UserDelegationCredential, error) } // credCache is a cache for user delegation credentials. It is non-blocking and updates // the credential in the background. type credCache struct { now func() time.Time cli getCreder cred atomic.Pointer[credData] log *slog.Logger closeCh chan struct{} fakeRefreshCred func(ctx context.Context, now time.Time) error start bool } type ccOption func(*credCache) error // withLogger sets the logger on the credCache. func withLogger(log *slog.Logger) ccOption { return func(c *credCache) error { c.log = log return nil } } // newCredCache creates a new credCache. func newCredCache(client getCreder, options ...ccOption) (*credCache, error) { cc := &credCache{ now: time.Now, cli: client, log: slog.Default(), closeCh: make(chan struct{}), start: true, } for _, o := range options { if err := o(cc); err != nil { return nil, err } } if cc.start { if err := cc.refreshCred(context.Background(), cc.now().UTC()); err != nil { return nil, fmt.Errorf("credCache: problem getting credential: %w", err) } go cc.refresher() } return cc, nil } // close closes the credCache. func (c *credCache) close() { close(c.closeCh) } // get gets the user delegation credential from the cache. If the credential is expired you will receive an // error. This only occurs if the background goroutine fails to refresh the credential. // In that case, this will return the last error received. func (c *credCache) get(ctx context.Context) (*service.UserDelegationCredential, error) { cred := c.cred.Load() if cred == nil { return nil, errors.New("no credential") } if cred.expires.Before(c.now()) { if cred.err != nil { return nil, cred.err } return nil, errors.New("credential expired") } return cred.cred, nil } // refresher is a background goroutine that refreshes the user delegation credential. func (c *credCache) refresher() { const ( nextRefresh = 23 * time.Hour ) boff, err := exponential.New() if err != nil { // We aren't passing any options, this should never happen and // we should panic if it does. panic(err) } ctx := context.Background() for { next := time.Now().Add(nextRefresh) // This will block until the next refresh time. // An error will only be returned if the cache is closed, so it can be ignored. if err := c.refresh(ctx, boff, next); err != nil { return } } } // refresh refreshes the user delegation credential after a next time. // It will retry forever with a backoff policy or until close() is called. // Only returns an error if closed is called, which can be ignored. func (c *credCache) refresh(ctx context.Context, boff *exponential.Backoff, next time.Time) error { select { case <-c.closeCh: return errors.New("closed") case <-time.After(next.Sub(c.now())): // This will retry forever. On any failures it will log the error and continue. // Every retry can only take up to 30 seconds. Uses the default policy which has a // maximum time of 1min - 1min30s between attempts. err := boff.Retry(ctx, func(ctx context.Context, r exponential.Record) error { select { case <-c.closeCh: return fmt.Errorf("credCache closed: %w", exponential.ErrPermanent) default: } ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() if err := c.refreshCred(ctx, c.now().UTC()); err != nil { c.log.Error(fmt.Sprintf("credCache: problem refreshing credential: %s", err.Error())) return err } return nil }) return err } } // refresh refreshes the user delegation credential. func (c *credCache) refreshCred(ctx context.Context, now time.Time) error { if c.fakeRefreshCred != nil { return c.fakeRefreshCred(ctx, now) } start := now.Truncate(time.Second) expiry := start.Add(7 * 24 * time.Hour) cred, err := c.cli.GetUserDelegationCredential( ctx, service.KeyInfo{ Expiry: toPtr(expiry.UTC().Format(sas.TimeFormat)), Start: toPtr(start.UTC().Format(sas.TimeFormat)), }, nil, ) if err != nil { current := c.cred.Load() if current == nil { current = &credData{} } if current.expires.Before(c.now()) { cd := &credData{ expires: current.expires, err: err, } c.cred.Store(cd) } return err } cd := &credData{ cred: cred, expires: expiry, } c.cred.Store(cd) return nil }