e2etest/newe2e_oauth_cache.go (109 lines of code) (raw):
package e2etest
import (
"errors"
"fmt"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-storage-azcopy/v10/common"
"sync"
"time"
)
// OAuth token scopes ("<resource>/.default" form), suitable as the scope
// argument to (*OAuthCache).GetAccessToken.
const (
// AzureManagementResource is the scope for Azure management-plane tokens.
AzureManagementResource = "https://management.core.windows.net/.default"
// AzureStorageResource is the scope for Azure Storage tokens.
AzureStorageResource = "https://storage.azure.com/.default"
// AzureDisksResource is the scope for Azure Disks tokens.
AzureDisksResource = "https://disk.azure.com/.default"
)
// PrimaryOAuthCache is the shared token cache assigned by SetupOAuthCache.
// It stays nil until SetupOAuthCache assigns it; calling GetAccessToken on a
// nil cache returns OAuthCacheDisabledError instead of panicking.
var PrimaryOAuthCache *OAuthCache
// SetupOAuthCache initializes PrimaryOAuthCache from GlobalConfig. It uses the
// first credential source that applies, checked in this order:
//   - running in Azure Pipelines: DefaultAzureCredential for the dynamic
//     workload tenant;
//   - service principal options present: ClientSecretCredential;
//   - static OAuth with CLIInherit: AzureCLICredential;
//   - static OAuth with PSInherit: the PowerShell context credential.
//
// When none applies, it logs that the cache is unconfigured and returns without
// assigning PrimaryOAuthCache. Callers then get OAuthCacheDisabledError from
// GetAccessToken, because a nil *OAuthCache is treated as disabled.
func SetupOAuthCache(a Asserter) {
	var (
		cred        azcore.TokenCredential
		err         error
		tenantID    string
		staticOAuth = GlobalConfig.E2EAuthConfig.StaticStgAcctInfo.StaticOAuth
	)

	// We don't consider workload identity in here because it's only used in a few tests
	if GlobalConfig.E2EAuthConfig.SubscriptionLoginInfo.Environment == AzurePipeline {
		tenantID = GlobalConfig.E2EAuthConfig.SubscriptionLoginInfo.DynamicOAuth.Workload.TenantId
		cred, err = azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{
			TenantID: tenantID,
		})
	} else if useSpn, tenant, appId, secret := GlobalConfig.GetSPNOptions(); useSpn {
		tenantID = tenant
		cred, err = azidentity.NewClientSecretCredential(tenant, appId, secret, nil)
	} else if staticOAuth.OAuthSource.CLIInherit {
		tenantID = staticOAuth.TenantID
		cred, err = azidentity.NewAzureCLICredential(&azidentity.AzureCLICredentialOptions{
			TenantID: tenantID,
		})
	} else if staticOAuth.OAuthSource.PSInherit {
		tenantID = staticOAuth.TenantID
		cred, err = common.NewPowershellContextCredential(&common.PowershellContextCredentialOptions{
			TenantID: tenantID,
		})
	} else {
		// No OAuth source configured: leave the cache disabled (nil) rather than
		// falling through to the non-nil credential assertion below, which would
		// always fail on this path.
		a.Log("OAuth Cache unconfigured.")
		return
	}

	a.NoError("create credentials", err)
	if err != nil {
		// A failed constructor may return a typed nil pointer, which makes cred a
		// non-nil interface. Never build a cache around such a credential.
		return
	}
	a.AssertNow("cred cannot be nil", Not{IsNil{}}, cred)
	PrimaryOAuthCache = NewOAuthCache(cred, tenantID)
}
// OAuthCache keeps all AccessTokens in one place, keyed by scope, so that any
// part of the test framework can request them. The goal is to avoid being
// rejected for an authentication loop caused by repeatedly fetching new tokens.
type OAuthCache struct {
tc azcore.TokenCredential // credential used to fetch new tokens
tenant string // tenant ID sent with every token request
tokens map[string]*azcore.AccessToken // cached tokens, keyed by scope
mut *sync.RWMutex // guards tokens
}
// NewOAuthCache returns an empty OAuthCache that obtains tokens from cred and
// requests them for the given tenant.
func NewOAuthCache(cred azcore.TokenCredential, tenant string) *OAuthCache {
	cache := new(OAuthCache)
	cache.tc = cred
	cache.tenant = tenant
	cache.tokens = map[string]*azcore.AccessToken{}
	cache.mut = new(sync.RWMutex)
	return cache
}
// OAuthCacheDisabledError is returned by (*OAuthCache).GetAccessToken when it is
// called on a nil cache. Go convention would name this ErrOAuthCacheDisabled;
// the existing exported name is kept for compatibility.
var OAuthCacheDisabledError = errors.New("the OAuth cache is currently disabled")
// GetAccessToken returns an AzCoreAccessToken for scope. It reuses the cached
// token while that token has more than three minutes of validity left; otherwise
// it requests a new one from the cache's credential, with the cache's tenant and
// CAE enabled, and stores it.
//
// Calling this on a nil *OAuthCache returns OAuthCacheDisabledError. Errors from
// the credential are wrapped.
func (o *OAuthCache) GetAccessToken(scope string) (*AzCoreAccessToken, error) {
	if o == nil {
		return nil, OAuthCacheDisabledError
	}

	// Refresh ahead of expiry so callers don't receive a token that lapses mid-use.
	const refreshMargin = time.Minute * 3
	usable := func(t *azcore.AccessToken) bool {
		return !time.Now().Add(refreshMargin).After(t.ExpiresOn)
	}

	o.mut.RLock()
	tok, ok := o.tokens[scope]
	o.mut.RUnlock()
	if ok && usable(tok) {
		return &AzCoreAccessToken{tok: tok, parent: o, Scope: scope}, nil
	}

	o.mut.Lock()
	// Deferred so the lock is also released when GetToken fails; otherwise every
	// later caller would block forever.
	defer o.mut.Unlock()

	// Another goroutine may have refreshed this scope while we waited for the
	// write lock; re-check so we don't issue a redundant token request.
	if tok, ok = o.tokens[scope]; ok && usable(tok) {
		return &AzCoreAccessToken{tok: tok, parent: o, Scope: scope}, nil
	}

	newTok, err := o.tc.GetToken(ctx, policy.TokenRequestOptions{
		Scopes:    []string{scope},
		TenantID:  o.tenant,
		EnableCAE: true,
	})
	if err != nil {
		return nil, fmt.Errorf("failed fetching new AccessToken: %w", err)
	}
	o.tokens[scope] = &newTok
	return &AzCoreAccessToken{tok: &newTok, parent: o, Scope: scope}, nil
}
// AccessToken is a source of bearer token strings.
type AccessToken interface {
	// CurrentToken returns the token string currently held, without refreshing it.
	CurrentToken() string
	// FreshToken returns a token string that implementations are expected to
	// refresh first when it is near expiry, or an error if that refresh fails.
	FreshToken() (string, error)
}
// AzCoreAccessToken wraps an azcore.AccessToken handed out by an OAuthCache and
// provides the FreshToken/CurrentToken methods of the AccessToken interface.
type AzCoreAccessToken struct {
tok *azcore.AccessToken // token currently held; replaced by FreshToken when near expiry
parent *OAuthCache // cache asked for a replacement token on refresh
Scope string // scope the token was requested for, reused on refresh. NOTE(review): exported unlike the other fields; presumably so callers can read it — confirm before changing
}
// FreshToken returns the held token string. If the token expires within one
// minute, it first replaces the token with one obtained from the parent cache.
// If that refresh fails, it returns "" and a wrapped error and leaves the held
// token unchanged.
func (a *AzCoreAccessToken) FreshToken() (string, error) {
	deadline := a.tok.ExpiresOn
	if !time.Now().Add(time.Minute).After(deadline) {
		// Still comfortably valid; no refresh needed.
		return a.tok.Token, nil
	}

	replacement, err := a.parent.GetAccessToken(a.Scope)
	if err != nil {
		return "", fmt.Errorf("failed to refresh token: %w", err)
	}
	a.tok = replacement.tok
	return a.tok.Token, nil
}
// CurrentToken returns the held token string as-is, without checking its expiry
// or contacting the parent cache.
func (a *AzCoreAccessToken) CurrentToken() string {
	held := a.tok
	return held.Token
}