oss/credentials/fetcher_credentials_provider.go (146 lines of code) (raw):
package credentials
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
)
var (
// Default expiration time adjustment factor
defaultExpiredFactor = 0.8
// backoff of refresh time
defaultRefreshDuration = 120 * time.Second
)
// CredentialsFetcherOptions are the options
type CredentialsFetcherOptions struct {
ExpiredFactor float64
RefreshDuration time.Duration
}
type CredentialsFetcher interface {
Fetch(ctx context.Context) (Credentials, error)
}
// CredentialsFetcherFunc provides a helper wrapping a function value to
// satisfy the CredentialsFetcher interface.
type CredentialsFetcherFunc func(context.Context) (Credentials, error)
// Fetch delegates to the function value the CredentialsFetcherFunc wraps.
func (fn CredentialsFetcherFunc) Fetch(ctx context.Context) (Credentials, error) {
return fn(ctx)
}
type CredentialsFetcherProvider struct {
m sync.Mutex
//credentials *fetcherCredentials
credentials atomic.Value
fetcher CredentialsFetcher
expiredFactor float64
refreshDuration time.Duration
}
type fetcherCredentials struct {
Creds Credentials
ExpiryWindow time.Duration
}
func NewCredentialsFetcherProvider(fetcher CredentialsFetcher, optFns ...func(*CredentialsFetcherOptions)) CredentialsProvider {
options := CredentialsFetcherOptions{
ExpiredFactor: defaultExpiredFactor,
RefreshDuration: defaultRefreshDuration,
}
for _, fn := range optFns {
fn(&options)
}
return &CredentialsFetcherProvider{
fetcher: fetcher,
expiredFactor: options.ExpiredFactor,
refreshDuration: options.RefreshDuration,
}
}
func (c *CredentialsFetcherProvider) GetCredentials(ctx context.Context) (Credentials, error) {
fcreds := c.getCreds()
if c.isExpired(fcreds) {
c.m.Lock()
defer c.m.Unlock()
creds, err := c.fetch(ctx)
if err == nil {
c.updateCreds(&creds)
}
return creds, err
} else {
if c.isSoonExpire(fcreds) && c.m.TryLock() {
defer c.m.Unlock()
fcreds1 := c.getCreds()
if fcreds1 == fcreds {
creds, err := c.fetch(ctx)
if err == nil {
c.updateCreds(&creds)
return creds, nil
} else {
c.updateExpiryWindow(fcreds1)
err = nil
}
}
fcreds = fcreds1
}
return fcreds.Creds, nil
}
}
type asyncFetchResult struct {
val Credentials
err error
}
func (c *CredentialsFetcherProvider) asyncFetch(ctx context.Context) <-chan asyncFetchResult {
doChan := func() <-chan asyncFetchResult {
ch := make(chan asyncFetchResult, 1)
go func() {
cred, err := c.fetcher.Fetch(ctx)
ch <- asyncFetchResult{cred, err}
}()
return ch
}
return doChan()
}
func (c *CredentialsFetcherProvider) fetch(ctx context.Context) (Credentials, error) {
if c.fetcher == nil {
return Credentials{}, fmt.Errorf("fetcher is null.")
}
select {
case result, _ := <-c.asyncFetch(ctx):
return result.val, result.err
case <-ctx.Done():
return Credentials{}, fmt.Errorf("FetchCredentialsCanceled")
}
}
func (p *CredentialsFetcherProvider) getCreds() *fetcherCredentials {
v := p.credentials.Load()
if v == nil {
return nil
}
creds, _ := v.(*fetcherCredentials)
return creds
}
func (c *CredentialsFetcherProvider) updateCreds(cred *Credentials) {
fcred := fetcherCredentials{
Creds: *cred,
}
if cred.Expires != nil {
curr := time.Now().Round(0)
durationS := c.expiredFactor * float64(cred.Expires.Sub(curr).Seconds())
duration := time.Duration(durationS * float64(time.Second))
if duration > c.refreshDuration {
fcred.ExpiryWindow = duration
}
}
c.credentials.Store(&fcred)
}
func (c *CredentialsFetcherProvider) updateExpiryWindow(fcreds *fetcherCredentials) {
if fcreds.ExpiryWindow > 0 {
fcreds1 := *fcreds
fcreds1.ExpiryWindow -= c.refreshDuration
c.credentials.Store(&fcreds1)
}
}
func (c *CredentialsFetcherProvider) isExpired(fcreds *fetcherCredentials) bool {
return fcreds == nil || fcreds.Creds.Expired()
}
func (c *CredentialsFetcherProvider) isSoonExpire(fcreds *fetcherCredentials) bool {
if fcreds == nil || fcreds.Creds.Expired() {
return true
}
if fcreds.ExpiryWindow > 0 && fcreds.Creds.Expires != nil {
if !fcreds.Creds.Expires.After(time.Now().Round(0).Add(fcreds.ExpiryWindow)) {
return true
}
}
return false
}