internal/clients/account.go (199 lines of code) (raw):

package clients import ( "bytes" "context" "encoding/base64" "encoding/json" "errors" "fmt" "log" "os/exec" "strings" "sync" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" ) type ObjectIDProvider func(ctx context.Context) (string, error) type ResourceManagerAccount struct { tenantId *string subscriptionId *string objectId *string mutex *sync.Mutex objectIDProvider ObjectIDProvider } func NewResourceManagerAccount(tenantId, subscriptionId string, provider ObjectIDProvider) ResourceManagerAccount { out := ResourceManagerAccount{ mutex: &sync.Mutex{}, } if tenantId != "" { out.tenantId = &tenantId } if subscriptionId != "" { out.subscriptionId = &subscriptionId } // We lazy load object ID because it's not always needed and could cause a performance hit out.objectIDProvider = provider return out } func (account *ResourceManagerAccount) GetTenantId() string { account.mutex.Lock() defer account.mutex.Unlock() if account.tenantId != nil { return *account.tenantId } err := account.loadDefaultsFromAzCmd() if err != nil { log.Printf("[DEBUG] Error getting default tenant ID: %s", err) } if account.tenantId == nil { log.Printf("[DEBUG] No default tenant ID found") return "" } return *account.tenantId } func (account *ResourceManagerAccount) GetSubscriptionId() string { account.mutex.Lock() defer account.mutex.Unlock() if account.subscriptionId != nil { return *account.subscriptionId } err := account.loadDefaultsFromAzCmd() if err != nil { log.Printf("[DEBUG] Error getting default subscription ID: %s", err) } if account.subscriptionId == nil { log.Printf("[DEBUG] No subscription ID found") return "" } return *account.subscriptionId } func (account *ResourceManagerAccount) GetObjectId(ctx context.Context) string { account.mutex.Lock() defer account.mutex.Unlock() if account.objectId != nil { return *account.objectId } if account.objectIDProvider != nil { objectId, err := account.objectIDProvider(ctx) if err != nil { log.Printf("[DEBUG] Error getting object ID: %s", err) } if objectId != "" { account.objectId = &objectId return *account.objectId } } err := account.loadSignedInUserFromAzCmd() if err != nil { log.Printf("[DEBUG] Error getting user object ID from az cli: %s", err) } if account.objectId == nil { log.Printf("[DEBUG] No object ID found") return "" } return *account.objectId } func (account *ResourceManagerAccount) loadSignedInUserFromAzCmd() error { var userModel struct { ObjectId string `json:"id"` } err := jsonUnmarshalAzCmd(&userModel, "ad", "signed-in-user", "show") if err != nil { return fmt.Errorf("obtaining defaults from az cmd: %s", err) } account.objectId = &userModel.ObjectId return nil } func (account *ResourceManagerAccount) loadDefaultsFromAzCmd() error { var accountModel struct { SubscriptionID string `json:"id"` TenantId string `json:"tenantId"` } err := jsonUnmarshalAzCmd(&accountModel, "account", "show") if err != nil { return fmt.Errorf("obtaining defaults from az cmd: %s", err) } account.tenantId = &accountModel.TenantId account.subscriptionId = &accountModel.SubscriptionID return nil } // jsonUnmarshalAzCmd executes an Azure CLI command and unmarshalls the JSON output. func jsonUnmarshalAzCmd(i interface{}, arg ...string) error { var stderr bytes.Buffer var stdout bytes.Buffer arg = append(arg, "-o=json") cmd := exec.Command("az", arg...) cmd.Stderr = &stderr cmd.Stdout = &stdout if err := cmd.Start(); err != nil { err := fmt.Errorf("launching Azure CLI: %+v", err) if stdErrStr := stderr.String(); stdErrStr != "" { err = fmt.Errorf("%s: %s", err, strings.TrimSpace(stdErrStr)) } return err } if err := cmd.Wait(); err != nil { err := fmt.Errorf("running Azure CLI: %+v", err) if stdErrStr := stderr.String(); stdErrStr != "" { err = fmt.Errorf("%s: %s", err, strings.TrimSpace(stdErrStr)) } return err } if err := json.Unmarshal(stdout.Bytes(), &i); err != nil { return fmt.Errorf("unmarshaling the output of Azure CLI: %v", err) } return nil } func parseTokenClaims(token string) (*tokenClaims, error) { // Parse the token to get the claims parts := strings.Split(token, ".") if len(parts) != 3 { return nil, errors.New("parseTokenClaims: token does not have 3 parts") } decoded, err := base64.RawURLEncoding.DecodeString(parts[1]) if err != nil { return nil, fmt.Errorf("parseTokenClaims: error decoding token: %s", err) } var claims tokenClaims err = json.Unmarshal(decoded, &claims) if err != nil { return nil, fmt.Errorf("parseTokenClaims: error unmarshalling claims: %w", err) } return &claims, nil } type tokenClaims struct { Audience string `json:"aud"` Expires int64 `json:"exp"` IssuedAt int64 `json:"iat"` Issuer string `json:"iss"` IdentityProvider string `json:"idp"` ObjectId string `json:"oid"` Roles []string `json:"roles"` Scopes string `json:"scp"` Subject string `json:"sub"` TenantRegionScope string `json:"tenant_region_scope"` TenantId string `json:"tid"` Version string `json:"ver"` AppDisplayName string `json:"app_displayname,omitempty"` AppId string `json:"appid,omitempty"` IdType string `json:"idtyp,omitempty"` } func ParsedTokenClaimsObjectIDProvider(cred azcore.TokenCredential, cloudCfg cloud.Configuration) ObjectIDProvider { return func(ctx context.Context) (string, error) { tok, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{ EnableCAE: true, Scopes: []string{cloudCfg.Services[cloud.ResourceManager].Audience + "/.default"}}) if err != nil { return "", fmt.Errorf("getting requesting token from credentials: %w", err) } if tok.Token == "" { return "", errors.New("token is empty") } cl, err := parseTokenClaims(tok.Token) if err != nil { return "", fmt.Errorf("getting object id from token: %w", err) } if cl == nil || cl.ObjectId == "" { return "", errors.New("object id is empty") } return cl.ObjectId, nil } }