common/oauthTokenManager.go (600 lines of code) (raw):
// Copyright © 2017 Microsoft <wastore@microsoft.com>
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package common
import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/go-autorest/autorest/date"
// importing the cache module registers the cache implementation for the current platform
_ "github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache"
)
// ApplicationID is the 1st party application (client) ID for AzCopy. It is the
// client ID used by the device code login flow in this file.
// const ApplicationID = "a45c21f4-7066-40b4-97d8-14f4313c3caa" // 3rd party test ApplicationID for AzCopy.
const ApplicationID = "579a7132-0e58-4d80-b1e1-7a1e2d337859"
// Resource is the resource identifier used in Azure Storage OAuth authentication.
const Resource = "https://storage.azure.com"
// MDResource is the resource identifier for managed disks.
const MDResource = "https://disk.azure.com/" // There must be a trailing slash: the service checks explicitly for "https://disk.azure.com/".
// StorageScope is the scope requested by every GetToken/Authenticate call in this file.
const StorageScope = "https://storage.azure.com/.default"
// ManagedDiskScope is the token scope for managed disks.
const ManagedDiskScope = "https://disk.azure.com//.default" // The resource keeps its trailing slash (hence "//"): the service checks explicitly for "https://disk.azure.com/".
// DefaultTenantID is the tenant applied by validateAndPersistLogin when a login specifies none.
const DefaultTenantID = "common"
// DefaultActiveDirectoryEndpoint is the authority host applied by validateAndPersistLogin when a login specifies none.
const DefaultActiveDirectoryEndpoint = "https://login.microsoftonline.com"
// TokenCache is the name given to the persistent azidentity cache created for device code logins.
const TokenCache = "AzCopyTokenCache"
// CredCacheImplementation is the storage abstraction this file uses to persist
// and retrieve an OAuthTokenInfo.
type CredCacheImplementation interface {
// HasCachedToken reports whether a token is currently stored.
HasCachedToken() (bool, error)
// LoadToken returns the stored token info.
LoadToken() (*OAuthTokenInfo, error)
// SaveToken stores the given token info.
SaveToken(OAuthTokenInfo) error
// RemoveCachedToken deletes the stored token info.
RemoveCachedToken() error
}
// UserOAuthTokenManager manages OAuth token info for the current process: it
// performs the various login flows, persists the result through credCache when
// requested, and serves the token info back via GetTokenInfo.
type UserOAuthTokenManager struct {
// oauthClient is the HTTP client created by newAzcopyHTTPClient at construction time.
oauthClient *http.Client
// credCache is the backing store for persisted token info.
credCache CredCacheImplementation
// Stash the credential info as we delete the environment variable after reading it, and we need to get it multiple times.
// It is also set after every successful login, and short-circuits GetTokenInfo and HasCachedToken.
stashedInfo *OAuthTokenInfo
}
// NewUserOAuthTokenManagerInstance builds a token manager whose credential
// cache is configured from credCacheOptions and whose HTTP client comes from
// newAzcopyHTTPClient. No token info is stashed initially.
func NewUserOAuthTokenManagerInstance(credCacheOptions CredCacheOptions) *UserOAuthTokenManager {
	manager := new(UserOAuthTokenManager)
	manager.oauthClient = newAzcopyHTTPClient()
	manager.credCache = NewCredCache(credCacheOptions)
	return manager
}
// newAzcopyHTTPClient returns a fresh HTTP client with AzCopy's transport
// settings: proxy resolution through GlobalProxyLookup, 10 second dial,
// keep-alive and TLS handshake timeouts, compression disabled, and up to 1000
// idle connections per host.
func newAzcopyHTTPClient() *http.Client {
	dialer := &net.Dialer{
		Timeout:   10 * time.Second,
		KeepAlive: 10 * time.Second,
		DualStack: true,
	}

	transport := &http.Transport{
		Proxy: GlobalProxyLookup,
		// Dial is used rather than DialContext, as DialContext has been reported to cause slower performance.
		Dial:                   dialer.Dial,
		MaxIdleConns:           0, // No limit
		MaxIdleConnsPerHost:    1000,
		IdleConnTimeout:        180 * time.Second,
		TLSHandshakeTimeout:    10 * time.Second,
		ExpectContinueTimeout:  1 * time.Second,
		DisableKeepAlives:      false,
		DisableCompression:     true,
		MaxResponseHeaderBytes: 0,
		// ResponseHeaderTimeout is intentionally left unset.
	}

	return &http.Client{Transport: transport}
}
// GetTokenInfo returns the token info for this process, resolving it in order:
//  1. The info stashed by an earlier call or login, if any.
//  2. The OAuth token info environment variable (intended for testing only).
//     When the variable is set it takes precedence over the cache, and any
//     failure to use it is returned as-is.
//  3. The credential cache (session mode).
//
// On success the result is stashed for subsequent calls. The method returns
// either a usable token info or an error, never both nil.
func (uotm *UserOAuthTokenManager) GetTokenInfo(ctx context.Context) (*OAuthTokenInfo, error) {
	if stashed := uotm.stashedInfo; stashed != nil {
		return stashed, nil
	}

	info, err := uotm.getTokenInfoFromEnvVar(ctx)
	switch {
	case err == nil:
		// Scenario-Test: unattended testing with the token info supplied through the environment variable.
	case IsErrorEnvVarOAuthTokenInfoNotSet(err):
		// Scenario: session mode, the token comes from the cache.
		info, err = uotm.getCachedTokenInfo(ctx)
		if err != nil {
			return nil, err
		}
	default:
		// The environment variable exists but its token info could not be used.
		return nil, err
	}

	if info == nil || info.IsEmpty() {
		return nil, errors.New("invalid state, cannot get valid token info")
	}

	uotm.stashedInfo = info
	return info, nil
}
// validateAndPersistLogin completes a login described by oAuthTokenInfo.
// It fills in the default tenant and Active Directory endpoint when they are
// empty, proves the credential works by requesting a storage-scoped token,
// stashes the info on the manager, and saves it to the credential cache when
// oAuthTokenInfo.Persist is set.
func (uotm *UserOAuthTokenManager) validateAndPersistLogin(oAuthTokenInfo *OAuthTokenInfo) error {
	// Fall back to the defaults for anything the caller left unspecified.
	if oAuthTokenInfo.Tenant == "" {
		oAuthTokenInfo.Tenant = DefaultTenantID
	}
	if oAuthTokenInfo.ActiveDirectoryEndpoint == "" {
		oAuthTokenInfo.ActiveDirectoryEndpoint = DefaultActiveDirectoryEndpoint
	}

	cred, err := oAuthTokenInfo.GetTokenCredential()
	if err != nil {
		return err
	}

	// The token itself is discarded; only the success of the request matters here.
	requestOpts := policy.TokenRequestOptions{Scopes: []string{StorageScope}, EnableCAE: true}
	if _, err = cred.GetToken(context.TODO(), requestOpts); err != nil {
		return err
	}

	uotm.stashedInfo = oAuthTokenInfo
	if !oAuthTokenInfo.Persist {
		return nil
	}
	return uotm.credCache.SaveToken(*oAuthTokenInfo)
}
// WorkloadIdentityLogin logs in with the Workload login type; persist controls
// whether the resulting token info is saved to the credential cache.
func (uotm *UserOAuthTokenManager) WorkloadIdentityLogin(persist bool) error {
	var info OAuthTokenInfo
	info.LoginType = EAutoLoginType.Workload()
	info.Persist = persist
	return uotm.validateAndPersistLogin(&info)
}
// AzCliLogin logs in with the AzCLI login type under the given tenant.
// Note: AzCLI creds do not need to be persisted, AzCLI handles persistence;
// persist is nevertheless forwarded unchanged.
func (uotm *UserOAuthTokenManager) AzCliLogin(tenantID string, persist bool) error {
	var info OAuthTokenInfo
	info.LoginType = EAutoLoginType.AzCLI()
	info.Tenant = tenantID
	info.Persist = persist
	return uotm.validateAndPersistLogin(&info)
}
// PSContextToken logs in with the PsCred login type under the given tenant.
// Note: Powershell creds do not need to be persisted, Powershell handles
// persistence; persist is nevertheless forwarded unchanged.
func (uotm *UserOAuthTokenManager) PSContextToken(tenantID string, persist bool) error {
	var info OAuthTokenInfo
	info.LoginType = EAutoLoginType.PsCred()
	info.Tenant = tenantID
	info.Persist = persist
	return uotm.validateAndPersistLogin(&info)
}
// MSILogin logs in with a managed identity (MSI). identityInfo is validated
// first and a validation failure is returned without attempting the login.
// persist indicates whether the token info is saved to the credential cache.
func (uotm *UserOAuthTokenManager) MSILogin(identityInfo IdentityInfo, persist bool) error {
	validationErr := identityInfo.Validate()
	if validationErr != nil {
		return validationErr
	}

	var info OAuthTokenInfo
	info.LoginType = EAutoLoginType.MSI()
	info.IdentityInfo = identityInfo
	info.Persist = persist
	return uotm.validateAndPersistLogin(&info)
}
// SecretLogin non-interactively logs in as a service principal using a client
// secret. The certificate path is left empty, which is what later selects the
// client-secret credential over the certificate credential.
func (uotm *UserOAuthTokenManager) SecretLogin(tenantID, activeDirectoryEndpoint, secret, applicationID string, persist bool) error {
	var info OAuthTokenInfo
	info.LoginType = EAutoLoginType.SPN()
	info.Tenant = tenantID
	info.ActiveDirectoryEndpoint = activeDirectoryEndpoint
	info.ApplicationID = applicationID
	info.SPNInfo = SPNInfo{Secret: secret, CertPath: ""}
	info.Persist = persist
	return uotm.validateAndPersistLogin(&info)
}
// CertLogin non-interactively logs in as a service principal using the
// certificate at certPath, the certificate password certPass, and the given
// Active Directory endpoint.
//
// certPath is resolved to an absolute path before being stored. An empty
// certPath, or a path that cannot be made absolute, is reported as an error
// rather than being silently turned into an unrelated path (an empty path
// would otherwise resolve to the current working directory).
func (uotm *UserOAuthTokenManager) CertLogin(tenantID, activeDirectoryEndpoint, certPath, certPass, applicationID string, persist bool) error {
	if certPath == "" {
		return errors.New("certificate path must not be empty")
	}

	absCertPath, err := filepath.Abs(certPath)
	if err != nil {
		return fmt.Errorf("resolving certificate path %q: %w", certPath, err)
	}

	oAuthTokenInfo := &OAuthTokenInfo{
		LoginType:               EAutoLoginType.SPN(),
		Tenant:                  tenantID,
		ActiveDirectoryEndpoint: activeDirectoryEndpoint,
		ApplicationID:           applicationID,
		SPNInfo: SPNInfo{
			// For certificate logins the secret field carries the certificate password.
			Secret:   certPass,
			CertPath: absCertPath,
		},
		Persist: persist,
	}
	return uotm.validateAndPersistLogin(oAuthTokenInfo)
}
// UserLogin interactively logs in through the device code flow with the given
// tenantID and activeDirectoryEndpoint, using AzCopy's own ApplicationID.
// The login starts from an empty authentication record. persist indicates
// whether the token info is saved to the credential cache.
func (uotm *UserOAuthTokenManager) UserLogin(tenantID, activeDirectoryEndpoint string, persist bool) error {
	var info OAuthTokenInfo
	info.LoginType = EAutoLoginType.Device()
	info.Tenant = tenantID
	info.ActiveDirectoryEndpoint = activeDirectoryEndpoint
	info.ApplicationID = ApplicationID
	info.DeviceCodeInfo = &azidentity.AuthenticationRecord{}
	info.Persist = persist
	return uotm.validateAndPersistLogin(&info)
}
// getCachedTokenInfo loads the token info from the credential cache and
// refreshes it before returning.
// It fails when the cache holds no token, when loading fails, or when the
// refresh fails (in which case the user is asked to log in again).
// The refreshed token is written back to the cache only if its access token or
// refresh token differs from the cached one.
func (uotm *UserOAuthTokenManager) getCachedTokenInfo(ctx context.Context) (*OAuthTokenInfo, error) {
	found, err := uotm.credCache.HasCachedToken()
	switch {
	case err != nil:
		return nil, fmt.Errorf("no cached token found, please log in with azcopy's login command, %w", err)
	case !found:
		return nil, errors.New("no cached token found, please log in with azcopy's login command")
	}

	cached, err := uotm.credCache.LoadToken()
	if err != nil {
		return nil, fmt.Errorf("get cached token failed, %w", err)
	}

	refreshed, err := cached.Refresh(ctx)
	if err != nil {
		return nil, fmt.Errorf("get cached token failed to ensure token fresh, please log in with azcopy's login command again, %w", err)
	}

	// Nothing to persist when the refresh produced the same tokens.
	unchanged := refreshed.AccessToken == cached.AccessToken && refreshed.RefreshToken == cached.RefreshToken
	if unchanged {
		return cached, nil
	}

	cached.Token = *refreshed
	if saveErr := uotm.credCache.SaveToken(*cached); saveErr != nil {
		return nil, saveErr
	}
	return cached, nil
}
// HasCachedToken reports whether the manager has token info available: either
// stashed in memory, or present in the credential cache.
func (uotm *UserOAuthTokenManager) HasCachedToken() (bool, error) {
	if uotm.stashedInfo == nil {
		return uotm.credCache.HasCachedToken()
	}
	return true, nil
}
// RemoveCachedToken deletes the token held by the credential cache.
// The in-memory stashed info is not touched.
func (uotm *UserOAuthTokenManager) RemoveCachedToken() error {
	removeErr := uotm.credCache.RemoveCachedToken()
	return removeErr
}
// ====================================================================================
// EnvVarOAuthTokenInfo is the name of the environment variable through which OAuth token info can be passed into AzCopy.
// Note: this is only used for testing, and is not encouraged in production environments.
const EnvVarOAuthTokenInfo = "AZCOPY_OAUTH_TOKEN_INFO"
// ErrorCodeEnvVarOAuthTokenInfoNotSet is the error text used when environment variable AZCOPY_OAUTH_TOKEN_INFO is not set.
// IsErrorEnvVarOAuthTokenInfoNotSet recognizes the condition by searching an error's message for this text.
const ErrorCodeEnvVarOAuthTokenInfoNotSet = "environment variable AZCOPY_OAUTH_TOKEN_INFO is not set"
// stashedEnvOAuthTokenExists remembers that EnvVarOAuthTokenInfoExists has seen the variable set,
// so the answer stays true after getTokenInfoFromEnvVar clears the variable.
var stashedEnvOAuthTokenExists = false
// EnvVarOAuthTokenInfoExists reports whether the OAuth token info environment
// variable is set, or was seen set by an earlier call to this function.
// Note: this only answers the existence question; use getTokenInfoFromEnvVar
// directly when the token info itself is needed.
func EnvVarOAuthTokenInfoExists() bool {
	if GetEnvironmentVariable(EEnvironmentVariable.OAuthTokenInfo()) != "" {
		// Remember the sighting: the variable is cleared once it has been consumed.
		stashedEnvOAuthTokenExists = true
	}
	return stashedEnvOAuthTokenExists
}
// IsErrorEnvVarOAuthTokenInfoNotSet reports whether err indicates that the
// environment variable AZCOPY_OAUTH_TOKEN_INFO is not set. The check is a
// substring match of ErrorCodeEnvVarOAuthTokenInfoNotSet against the error
// message; a nil err yields false.
func IsErrorEnvVarOAuthTokenInfoNotSet(err error) bool {
	if err == nil {
		return false
	}
	return strings.Contains(err.Error(), ErrorCodeEnvVarOAuthTokenInfoNotSet)
}
// getTokenInfoFromEnvVar reads and parses the OAuth token info environment
// variable. When the variable is empty it returns an error carrying
// ErrorCodeEnvVarOAuthTokenInfoNotSet. Unless the login type is TokenStore,
// the parsed info is refreshed before being returned.
func (uotm *UserOAuthTokenManager) getTokenInfoFromEnvVar(ctx context.Context) (*OAuthTokenInfo, error) {
	raw := GetEnvironmentVariable(EEnvironmentVariable.OAuthTokenInfo())
	if raw == "" {
		return nil, errors.New(ErrorCodeEnvVarOAuthTokenInfoNotSet)
	}

	// Clear the variable as soon as it has been read, so that it does not
	// spread into child processes unexpectedly.
	ClearEnvironmentVariable(EEnvironmentVariable.OAuthTokenInfo())

	parsed, err := jsonToTokenInfo([]byte(raw))
	if err != nil {
		return nil, fmt.Errorf("get token from environment variable failed to unmarshal token, %w", err)
	}

	// Token store logins are returned as parsed, without a refresh.
	if parsed.LoginType == EAutoLoginType.TokenStore() {
		return parsed, nil
	}

	fresh, err := parsed.Refresh(ctx)
	if err != nil {
		return nil, fmt.Errorf("get token from environment variable failed to ensure token fresh, %w", err)
	}
	parsed.Token = *fresh
	return parsed, nil
}
// ====================================================================================
// OAuthTokenInfo contains the info necessary to obtain and refresh OAuth credentials.
// It is the value serialized to JSON for the credential cache and for the
// OAuth token info environment variable.
type OAuthTokenInfo struct {
// TokenCredential is the credential built by GetTokenCredential, cached so later calls reuse it. It is not serialized.
azcore.TokenCredential `json:"-"`
// Token is the most recently obtained token.
Token
// Tenant is the tenant ID; validateAndPersistLogin defaults it to DefaultTenantID.
Tenant string `json:"_tenant"`
// ActiveDirectoryEndpoint is the authority host; validateAndPersistLogin defaults it to DefaultActiveDirectoryEndpoint.
ActiveDirectoryEndpoint string `json:"_ad_endpoint"`
// LoginType selects which credential GetTokenCredential constructs.
LoginType AutoLoginType `json:"_token_refresh_source"`
// ApplicationID is the application (client) ID used for the service principal credentials.
ApplicationID string `json:"_application_id"`
// IdentityInfo identifies the managed identity for MSI logins.
IdentityInfo IdentityInfo
// SPNInfo holds the secret or certificate details for service principal logins.
SPNInfo SPNInfo
// Note: ClientID should be only used for internal integrations through env var with refresh token.
// It indicates the Application ID assigned to your app when you registered it with Azure AD.
// In this case AzCopy refreshes the token on behalf of the caller.
// For more details, please refer to
// https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-protocols-oauth-code#refreshing-the-access-tokens
ClientID string `json:"_client_id"`
// DeviceCodeInfo is the authentication record of a device code login; GetDeviceCodeCredential fills it in after authenticating.
DeviceCodeInfo *azidentity.AuthenticationRecord `json:"_authentication_record,omitempty"`
// Persist indicates whether the login should be saved to the credential cache.
Persist bool `json:"_persist"`
}
// Token encapsulates the access token used to authorize Azure requests.
// The JSON field names follow the service-to-service access token response:
// https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response
type Token struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn json.Number `json:"expires_in"`
// ExpiresOn is read by Expires as a number of seconds; Refresh writes it as whole seconds since the Unix epoch.
ExpiresOn json.Number `json:"expires_on"`
NotBefore json.Number `json:"not_before"`
Resource string `json:"resource"`
Type string `json:"token_type"`
}
// IsZero reports whether every field of the token still has its zero value.
func (t Token) IsZero() bool {
	var zero Token
	return t == zero
}
// Expires returns, in UTC, the time at which the Token expires, derived from
// ExpiresOn interpreted as seconds. When ExpiresOn cannot be parsed as a
// number, -3600 seconds is used instead, which places the expiry in the past.
func (t Token) Expires() time.Time {
	seconds, parseErr := t.ExpiresOn.Float64()
	if parseErr != nil {
		seconds = -3600
	}
	return time.Time(date.NewUnixTimeFromSeconds(seconds)).UTC()
}
// IsExpired reports whether the Token's expiry time is now or in the past.
func (t Token) IsExpired() bool {
	expiry := t.Expires()
	return !time.Now().Before(expiry)
}
// IdentityInfo contains the info identifying a managed identity (MSI).
// Validate enforces that at most one of the fields is set.
type IdentityInfo struct {
// ClientID selects the managed identity by client ID.
ClientID string `json:"_identity_client_id"`
// ObjectID selects the managed identity by object ID; GetManagedIdentityCredential rejects it as no longer supported.
ObjectID string `json:"_identity_object_id"`
// MSIResID selects the managed identity by resource ID.
MSIResID string `json:"_identity_msi_res_id"`
}
// SPNInfo contains the info for authenticating as a service principal (SPN).
type SPNInfo struct {
// Secret is used for two purposes: the certificate password, and a client secret.
// The secret is persisted to the JSON file because AAD does not issue a refresh token.
// Thus, the original secret is needed to refresh.
Secret string `json:"_spn_secret"`
// CertPath is the path of the certificate file; when empty, the client-secret credential is used instead.
CertPath string `json:"_spn_cert_path"`
}
// Validate checks that at most one of ClientID, ObjectID and MSIResID is set.
// It returns an error when two or more of them are non-empty, and nil
// otherwise (including when none is set).
//
// The fields are counted individually rather than collected into a set of
// values: two different fields holding the same string are still two
// identifiers and must be rejected.
func (identityInfo *IdentityInfo) Validate() error {
	specified := 0
	for _, id := range [...]string{identityInfo.ClientID, identityInfo.ObjectID, identityInfo.MSIResID} {
		if id != "" {
			specified++
		}
	}

	if specified > 1 {
		return errors.New("client ID, object ID and MSI resource ID are mutually exclusive")
	}
	return nil
}
// Refresh obtains a storage-scoped token through the credential returned by
// GetTokenCredential (cached on credInfo, or created on first use) and
// returns it as a Token. Only AccessToken and ExpiresOn are populated;
// ExpiresOn is expressed in whole seconds since the Unix epoch.
// TODO: I think this method is only necessary until datalake is migrated.
func (credInfo *OAuthTokenInfo) Refresh(ctx context.Context) (*Token, error) {
	cred, err := credInfo.GetTokenCredential()
	if err != nil {
		return nil, err
	}

	requestOpts := policy.TokenRequestOptions{
		Scopes:    []string{StorageScope},
		EnableCAE: true,
	}
	accessToken, err := cred.GetToken(ctx, requestOpts)
	if err != nil {
		return nil, err
	}

	secondsSinceEpoch := int64(accessToken.ExpiresOn.Sub(date.UnixEpoch()) / time.Second)
	refreshed := &Token{
		AccessToken: accessToken.Token,
		ExpiresOn:   json.Number(strconv.FormatInt(secondsSinceEpoch, 10)),
	}
	return refreshed, nil
}
// tokenStoreCredCache is the single token store credential cache shared by the entire azcopy process.
// Its key and account names embed the current process ID, so each process addresses its own entry.
var tokenStoreCredCache CredCacheImplementation = NewCredCacheInternalIntegration(CredCacheOptions{
KeyName: "azcopy/aadtoken/" + strconv.Itoa(os.Getpid()),
ServiceName: "azcopy",
AccountName: "aadtoken/" + strconv.Itoa(os.Getpid()),
})
// IsEmpty reports whether the OAuthTokenInfo carries no useful info: no
// tenant, no Active Directory endpoint, and a zero-valued token.
func (credInfo OAuthTokenInfo) IsEmpty() bool {
	return credInfo.Tenant == "" &&
		credInfo.ActiveDirectoryEndpoint == "" &&
		credInfo.Token.IsZero()
}
// toJSON serializes the OAuthTokenInfo to its JSON representation.
func (credInfo OAuthTokenInfo) toJSON() ([]byte, error) {
	encoded, err := json.Marshal(credInfo)
	return encoded, err
}
// getAuthorityURL parses activeDirectoryEndpoint into a URL. On a parse
// failure it returns a nil URL together with the parse error.
func getAuthorityURL(activeDirectoryEndpoint string) (*url.URL, error) {
	parsed, parseErr := url.Parse(activeDirectoryEndpoint)
	if parseErr == nil {
		return parsed, nil
	}
	return nil, parseErr
}
// minimumTokenValidDuration is the remaining lifetime a held token must exceed
// for TokenStoreCredential.GetToken to return it without reloading from the cache.
const minimumTokenValidDuration = time.Minute * 5
// TokenStoreCredential is a token credential that serves a token held in
// memory and reloads it from a credential cache when it is close to expiry.
type TokenStoreCredential struct {
// token is the token currently held.
token *azcore.AccessToken
// lock guards token.
lock sync.RWMutex
// credCache is the cache the token is reloaded from.
credCache CredCacheImplementation
}
// globalTokenStoreCredential is created to make sure that all
// service clients share the same cred object. This is required so that
// we do not make repeated GetToken calls.
// This is a temporary fix for an issue where we would request a
// new token from Stg Exp even while they've not yet populated the
// tokenstore.
//
// This is okay because we use the same credential on both source and
// destination. If we move to a case where the credentials are
// different, this should be removed.
//
// We should move to a method where the token is always read from
// tokenstore, and azcopy is invoked after tokenstore is populated.
var globalTokenStoreCredential *TokenStoreCredential
// globalTsc ensures globalTokenStoreCredential is initialized only once (see GetTokenStoreCredential).
var globalTsc sync.Once
// GetToken returns the held token while it remains valid for longer than
// minimumTokenValidDuration; otherwise it reloads the token from the
// credential cache, stores it, and returns the reloaded token.
// The context and request options are ignored.
//
// It returns an error when the cache reports no token, when the cache lookup
// fails, or when loading the token fails.
func (tsc *TokenStoreCredential) GetToken(_ context.Context, _ policy.TokenRequestOptions) (azcore.AccessToken, error) {
	// Fast path: the held token is still fresh. The token is copied while the
	// read lock is held, so a concurrent reload cannot race with the read.
	tsc.lock.RLock()
	if time.Until(tsc.token.ExpiresOn) > minimumTokenValidDuration {
		current := *tsc.token
		tsc.lock.RUnlock()
		return current, nil
	}
	tsc.lock.RUnlock()

	tsc.lock.Lock()
	defer tsc.lock.Unlock()

	// Another goroutine may have reloaded the token while this one waited for
	// the write lock; if so, there is no need to hit the cache again.
	if time.Until(tsc.token.ExpiresOn) > minimumTokenValidDuration {
		return *tsc.token, nil
	}

	hasToken, err := tsc.credCache.HasCachedToken()
	if err != nil || !hasToken {
		AzcopyCurrentJobLogger.Log(LogDebug, fmt.Sprintf("no token found %v", err))
		if err != nil {
			return azcore.AccessToken{}, fmt.Errorf("no cached token found in Token Store Mode(SE), %w", err)
		}
		// No underlying error to wrap: the cache simply holds no token.
		return azcore.AccessToken{}, errors.New("no cached token found in Token Store Mode(SE)")
	}

	tokenInfo, err := tsc.credCache.LoadToken()
	if err != nil {
		AzcopyCurrentJobLogger.Log(LogDebug, fmt.Sprintf("get token failed %s", err.Error()))
		return azcore.AccessToken{}, fmt.Errorf("get cached token failed in Token Store Mode(SE), %w", err)
	}

	tsc.token = &azcore.AccessToken{
		Token:     tokenInfo.AccessToken,
		ExpiresOn: tokenInfo.Expires(),
	}
	return *tsc.token, nil
}
// GetTokenStoreCredential returns the process-wide token store credential
// (Credential Manager in Windows, keyring in Linux and keychain in MacOS).
// The first call creates it, seeded with accessToken and expiresOn and backed
// by tokenStoreCredCache; later calls return the same instance and ignore
// their arguments.
// Note: This approach should only be used in internal integrations.
func GetTokenStoreCredential(accessToken string, expiresOn time.Time) azcore.TokenCredential {
	initCredential := func() {
		seed := azcore.AccessToken{Token: accessToken, ExpiresOn: expiresOn}
		globalTokenStoreCredential = &TokenStoreCredential{
			token:     &seed,
			credCache: tokenStoreCredCache,
		}
	}
	globalTsc.Do(initCredential)
	return globalTokenStoreCredential
}
// GetTokenStoreCredential stores the process-wide token store credential on
// credInfo, passing credInfo's access token and expiry as its seed, and
// returns it. The error is always nil.
func (credInfo *OAuthTokenInfo) GetTokenStoreCredential() (azcore.TokenCredential, error) {
	shared := GetTokenStoreCredential(credInfo.AccessToken, credInfo.Expires())
	credInfo.TokenCredential = shared
	return shared, nil
}
// GetManagedIdentityCredential creates a managed identity credential, caches
// it on credInfo, and returns it. The identity is selected by client ID if
// set, otherwise by MSI resource ID; with neither set no explicit ID is
// passed. An object ID alone is rejected with an error.
func (credInfo *OAuthTokenInfo) GetManagedIdentityCredential() (azcore.TokenCredential, error) {
	identity := credInfo.IdentityInfo

	var id azidentity.ManagedIDKind
	switch {
	case identity.ClientID != "":
		id = azidentity.ClientID(identity.ClientID)
	case identity.MSIResID != "":
		id = azidentity.ResourceID(identity.MSIResID)
	case identity.ObjectID != "":
		return nil, fmt.Errorf("object ID is deprecated and no longer supported for managed identity. Please use client ID or resource ID instead")
	}

	opts := &azidentity.ManagedIdentityCredentialOptions{
		ClientOptions: azcore.ClientOptions{
			Transport: newAzcopyHTTPClient(),
		},
		ID: id,
	}
	cred, err := azidentity.NewManagedIdentityCredential(opts)
	if err != nil {
		return nil, err
	}

	credInfo.TokenCredential = cred
	return cred, nil
}
// GetClientCertificateCredential creates a service principal credential from
// the certificate file at SPNInfo.CertPath, using SPNInfo.Secret as the
// certificate password, caches it on credInfo, and returns it.
func (credInfo *OAuthTokenInfo) GetClientCertificateCredential() (azcore.TokenCredential, error) {
	authority, err := getAuthorityURL(credInfo.ActiveDirectoryEndpoint)
	if err != nil {
		return nil, err
	}

	rawCert, err := os.ReadFile(credInfo.SPNInfo.CertPath)
	if err != nil {
		return nil, err
	}

	certChain, privateKey, err := azidentity.ParseCertificates(rawCert, []byte(credInfo.SPNInfo.Secret))
	if err != nil {
		return nil, err
	}

	opts := &azidentity.ClientCertificateCredentialOptions{
		ClientOptions: azcore.ClientOptions{
			Cloud:     cloud.Configuration{ActiveDirectoryAuthorityHost: authority.String()},
			Transport: newAzcopyHTTPClient(),
		},
	}
	cred, err := azidentity.NewClientCertificateCredential(credInfo.Tenant, credInfo.ApplicationID, certChain, privateKey, opts)
	if err != nil {
		return nil, err
	}

	credInfo.TokenCredential = cred
	return cred, nil
}
// GetClientSecretCredential creates a service principal credential from the
// client secret in SPNInfo.Secret, caches it on credInfo, and returns it.
func (credInfo *OAuthTokenInfo) GetClientSecretCredential() (azcore.TokenCredential, error) {
	authority, err := getAuthorityURL(credInfo.ActiveDirectoryEndpoint)
	if err != nil {
		return nil, err
	}

	opts := &azidentity.ClientSecretCredentialOptions{
		ClientOptions: azcore.ClientOptions{
			Cloud:     cloud.Configuration{ActiveDirectoryAuthorityHost: authority.String()},
			Transport: newAzcopyHTTPClient(),
		},
	}
	cred, err := azidentity.NewClientSecretCredential(credInfo.Tenant, credInfo.ApplicationID, credInfo.SPNInfo.Secret, opts)
	if err != nil {
		return nil, err
	}

	credInfo.TokenCredential = cred
	return cred, nil
}
// GetAzCliCredential creates an Azure CLI credential, caches it on credInfo,
// and returns it. A tenant equal to DefaultTenantID is first cleared on
// credInfo, so that no tenant is passed to the credential.
func (credInfo *OAuthTokenInfo) GetAzCliCredential() (azcore.TokenCredential, error) {
	if credInfo.Tenant == DefaultTenantID {
		credInfo.Tenant = ""
	}

	opts := &azidentity.AzureCLICredentialOptions{TenantID: credInfo.Tenant}
	cred, err := azidentity.NewAzureCLICredential(opts)
	if err != nil {
		return nil, err
	}

	credInfo.TokenCredential = cred
	return cred, nil
}
// GetPSContextCredential creates a Powershell context credential, caches it
// on credInfo, and returns it. A tenant equal to DefaultTenantID is first
// cleared on credInfo, so that no tenant is passed to the credential.
func (credInfo *OAuthTokenInfo) GetPSContextCredential() (azcore.TokenCredential, error) {
	if credInfo.Tenant == DefaultTenantID {
		credInfo.Tenant = ""
	}

	opts := &PowershellContextCredentialOptions{TenantID: credInfo.Tenant}
	cred, err := NewPowershellContextCredential(opts)
	if err != nil {
		return nil, err
	}

	credInfo.TokenCredential = cred
	return cred, nil
}
// GetWorkloadIdentityCredential creates a workload identity credential that
// uses AzCopy's HTTP client as transport, caches it on credInfo, and returns it.
func (credInfo *OAuthTokenInfo) GetWorkloadIdentityCredential() (azcore.TokenCredential, error) {
	opts := &azidentity.WorkloadIdentityCredentialOptions{
		ClientOptions: azcore.ClientOptions{
			Transport: newAzcopyHTTPClient(),
		},
	}

	cred, err := azidentity.NewWorkloadIdentityCredential(opts)
	if err != nil {
		return nil, err
	}

	credInfo.TokenCredential = cred
	return cred, nil
}
// GetDeviceCodeCredential creates a device code credential, caches it on
// credInfo, and returns it.
//
// When credInfo.Persist is set, the credential is backed by a persistent
// cache named TokenCache. When no authentication record is stored on credInfo
// yet, the user is taken through the interactive device code prompt and the
// resulting record is stored in credInfo.DeviceCodeInfo.
func (credInfo *OAuthTokenInfo) GetDeviceCodeCredential() (azcore.TokenCredential, error) {
	authority, err := getAuthorityURL(credInfo.ActiveDirectoryEndpoint)
	if err != nil {
		return nil, err
	}

	var tokenCache azidentity.Cache
	if credInfo.Persist {
		tokenCache, err = cache.New(&cache.Options{
			Name: TokenCache,
		})
		if err != nil {
			return nil, err
		}
	}

	// Read the stored record, if any.
	storedRecord := IffNotNil(credInfo.DeviceCodeInfo, azidentity.AuthenticationRecord{})

	// Tells the user where to go and which code to enter.
	prompt := func(ctx context.Context, message azidentity.DeviceCodeMessage) error {
		lcm.Info(fmt.Sprintf("Authentication is required. To sign in, open the webpage %s and enter the code %s to authenticate.",
			Iff(message.VerificationURL != "", message.VerificationURL, "https://aka.ms/devicelogin"), message.UserCode))
		return nil
	}

	opts := &azidentity.DeviceCodeCredentialOptions{
		TenantID:                       credInfo.Tenant,
		ClientID:                       ApplicationID,
		DisableAutomaticAuthentication: true,
		Cache:                          tokenCache,
		AuthenticationRecord:           storedRecord,
		ClientOptions: azcore.ClientOptions{
			Cloud:     cloud.Configuration{ActiveDirectoryAuthorityHost: authority.String()},
			Transport: newAzcopyHTTPClient(),
		},
		UserPrompt: prompt,
	}
	cred, err := azidentity.NewDeviceCodeCredential(opts)
	if err != nil {
		return nil, err
	}

	// A stored record means authentication already happened; nothing more to do.
	if storedRecord != (azidentity.AuthenticationRecord{}) {
		credInfo.TokenCredential = cred
		return cred, nil
	}

	// No stored record; call Authenticate to acquire one.
	newRecord, err := cred.Authenticate(context.TODO(), &policy.TokenRequestOptions{
		EnableCAE: true,
		Scopes:    []string{StorageScope},
	})
	if err != nil {
		return nil, err
	}

	if credInfo.Tenant == DefaultTenantID {
		fmt.Println("INFO: Logging in under the \"Common\" tenant. This will log the account in under its home tenant.")
		fmt.Println("INFO: If you plan to use AzCopy with a B2B account (where the account's home tenant is separate from the tenant of the target storage account), please sign in under the target tenant with --tenant-id")
	}

	// Store the record.
	credInfo.DeviceCodeInfo = &newRecord
	credInfo.TokenCredential = cred
	return cred, nil
}
// AuthenticateToken is a token credential that can additionally run an
// explicit authentication step yielding an authentication record.
type AuthenticateToken interface {
azcore.TokenCredential
// Authenticate performs the authentication with the given request options and returns the resulting record.
Authenticate(ctx context.Context, opts *policy.TokenRequestOptions) (azidentity.AuthenticationRecord, error)
}
// GetTokenCredential returns the credential matching credInfo.LoginType.
// A credential already cached on credInfo is returned as-is; otherwise the
// matching Get*Credential method builds (and caches) one. Service principal
// logins use the certificate credential when a certificate path is set and
// the client-secret credential otherwise. An unknown login type is an error.
func (credInfo *OAuthTokenInfo) GetTokenCredential() (azcore.TokenCredential, error) {
	// Token Credential is cached.
	if cached := credInfo.TokenCredential; cached != nil {
		return cached, nil
	}

	switch credInfo.LoginType {
	case EAutoLoginType.TokenStore():
		return credInfo.GetTokenStoreCredential()
	case EAutoLoginType.MSI():
		return credInfo.GetManagedIdentityCredential()
	case EAutoLoginType.SPN():
		if credInfo.SPNInfo.CertPath == "" {
			return credInfo.GetClientSecretCredential()
		}
		return credInfo.GetClientCertificateCredential()
	case EAutoLoginType.AzCLI():
		return credInfo.GetAzCliCredential()
	case EAutoLoginType.PsCred():
		return credInfo.GetPSContextCredential()
	case EAutoLoginType.Workload():
		return credInfo.GetWorkloadIdentityCredential()
	case EAutoLoginType.Device():
		return credInfo.GetDeviceCodeCredential()
	}

	return nil, fmt.Errorf("invalid auto-login type specified: %s", credInfo.LoginType)
}
// jsonToTokenInfo decodes JSON bytes into an OAuthTokenInfo. For the
// TokenStore login type, the token store credential is attached to the result
// right away.
func jsonToTokenInfo(b []byte) (*OAuthTokenInfo, error) {
	info := new(OAuthTokenInfo)
	if err := json.Unmarshal(b, info); err != nil {
		return nil, err
	}

	if info.LoginType == EAutoLoginType.TokenStore() {
		// Called for its side effect of setting the credential on info; the results are not needed here.
		_, _ = info.GetTokenStoreCredential()
	}
	return info, nil
}
// ====================================================================================
// TestOAuthInjection holds the settings that control OAuth testing injections.
type TestOAuthInjection struct {
// DoTokenRefreshInjection switches the token refresh injection on or off.
DoTokenRefreshInjection bool
// TokenRefreshDuration is the duration used by the token refresh injection.
// NOTE(review): the consumer of these fields is not in this file; confirm exact semantics there.
TokenRefreshDuration time.Duration
}
// GlobalTestOAuthInjection is the global setting for OAuth testing injection control.
// By default the injection is disabled and the refresh duration is 10 seconds.
var GlobalTestOAuthInjection = TestOAuthInjection{
DoTokenRefreshInjection: false,
TokenRefreshDuration: time.Second * 10,
}
// TODO: Add pipeline policy for token refreshing validating.