cli/azd/pkg/entraid/entraid.go (521 lines of code) (raw):
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package entraid
import (
	"context"
	"errors"
	"fmt"
	"log"
	"net/http"
	"strings"
	"time"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2"
	"github.com/azure/azure-dev/cli/azd/internal"
	"github.com/azure/azure-dev/cli/azd/pkg/account"
	"github.com/azure/azure-dev/cli/azd/pkg/azure"
	"github.com/azure/azure-dev/cli/azd/pkg/graphsdk"
	"github.com/google/uuid"
	"github.com/sethvargo/go-retry"
)
// Default issuer and audience values for federated identity credentials.
// The issuer URL points at the GitHub Actions token service.
// NOTE(review): neither constant is referenced in the visible part of this file;
// presumably they are consumed elsewhere in the package when federated
// credentials are built — confirm against callers.
const (
federatedIdentityIssuer = "https://token.actions.githubusercontent.com"
federatedIdentityAudience = "api://AzureADTokenExchange"
)
// AzureCredentials is the required model structure for Azure Credentials tools.
// It is the value returned by ResetPasswordCredentials; the struct tags fix the
// JSON property names to clientId, clientSecret, subscriptionId and tenantId.
type AzureCredentials struct {
// ClientId is the application (client) ID, copied from the application's AppId.
ClientId string `json:"clientId"`
// ClientSecret is the secret text of the newly added password credential.
ClientSecret string `json:"clientSecret"`
// SubscriptionId is the subscription ID the caller supplied.
SubscriptionId string `json:"subscriptionId"`
// TenantId is copied from the service principal's AppOwnerOrganizationId.
TenantId string `json:"tenantId"`
}
// EntraIdService provides actions on top of Azure Active Directory (AD)
// applications, service principals, credentials and role assignments.
// Every method takes the subscription ID used to select the credential for the call.
type EntraIdService interface {
// GetServicePrincipal resolves an application by app ID, falling back to display name,
// and returns the service principal found for that application.
GetServicePrincipal(
ctx context.Context,
subscriptionId string,
appIdOrName string,
) (*graphsdk.ServicePrincipal, error)
// CreateOrUpdateServicePrincipal reuses the application identified by appIdOrName or
// creates one with that display name, makes sure it has a service principal, and
// assigns options.RolesToAssign to it at subscription scope.
CreateOrUpdateServicePrincipal(
ctx context.Context,
subscriptionId string,
appIdOrName string,
options CreateOrUpdateServicePrincipalOptions,
) (*graphsdk.ServicePrincipal, error)
// ResetPasswordCredentials removes every existing password credential from the
// application with the given app ID, adds a new one, and returns it as AzureCredentials.
ResetPasswordCredentials(
ctx context.Context,
subscriptionId string,
appId string,
) (*AzureCredentials, error)
// ApplyFederatedCredentials creates each requested federated credential whose subject is
// not already present on the application and returns only the credentials it created.
ApplyFederatedCredentials(
ctx context.Context,
subscriptionId string,
clientId string,
federatedCredentials []*graphsdk.FederatedIdentityCredential,
) ([]*graphsdk.FederatedIdentityCredential, error)
// CreateRbac assigns the role identified by roleId to principalId at the given scope.
// roleId is appended to "/subscriptions/<subscriptionId>" to form the role definition ID.
CreateRbac(ctx context.Context, subscriptionId string, scope, roleId, principalId string) error
}
// entraIdService is the EntraIdService implementation returned by NewEntraIdService.
type entraIdService struct {
// credentialProvider supplies the credential used for a given subscription.
credentialProvider account.SubscriptionCredentialProvider
// clientCache holds one Graph client per subscription ID.
// It is read and written without any locking.
clientCache map[string]*graphsdk.GraphClient
// armClientOptions is passed to the ARM role definition and role assignment clients.
armClientOptions *arm.ClientOptions
// coreClientOptions is passed to the Graph client constructor.
coreClientOptions *azcore.ClientOptions
}
// NewEntraIdService builds an EntraIdService backed by the given credential provider.
// armClientOptions is handed to the ARM authorization clients and coreClientOptions to the
// Graph client. The returned service starts with an empty per-subscription Graph client cache.
func NewEntraIdService(
	credentialProvider account.SubscriptionCredentialProvider,
	armClientOptions *arm.ClientOptions,
	coreClientOptions *azcore.ClientOptions,
) EntraIdService {
	service := new(entraIdService)
	service.credentialProvider = credentialProvider
	service.clientCache = make(map[string]*graphsdk.GraphClient)
	service.armClientOptions = armClientOptions
	service.coreClientOptions = coreClientOptions

	return service
}
// GetServicePrincipal returns the service principal that belongs to the application
// identified by appIdOrName. The application is resolved first (by app ID, then by name);
// a failure to resolve it is returned unchanged.
func (ad *entraIdService) GetServicePrincipal(
	ctx context.Context,
	subscriptionId string,
	appIdOrName string,
) (*graphsdk.ServicePrincipal, error) {
	app, lookupErr := ad.getApplicationByNameOrId(ctx, subscriptionId, appIdOrName)
	if lookupErr != nil {
		return nil, lookupErr
	}

	return ad.getServicePrincipal(ctx, subscriptionId, app)
}
// CreateOrUpdateServicePrincipalOptions carries the optional inputs of
// CreateOrUpdateServicePrincipal.
type CreateOrUpdateServicePrincipalOptions struct {
// RolesToAssign lists role names; each is looked up by roleName at subscription
// scope and assigned to the service principal.
RolesToAssign []string
// Description is set on the application only when a new application is created.
Description *string
// ServiceManagementReference is set on the application only when a new application is created.
ServiceManagementReference *string
}
// CreateOrUpdateServicePrincipal makes sure an application and its service principal exist
// for appIdOrName and that the service principal holds every role in options.RolesToAssign.
//
// An existing application is reused as-is; otherwise a new one is created with appIdOrName as
// its display name. Returns the service principal, or the first error encountered.
func (ad *entraIdService) CreateOrUpdateServicePrincipal(
	ctx context.Context,
	subscriptionId string,
	appIdOrName string,
	options CreateOrUpdateServicePrincipalOptions,
) (*graphsdk.ServicePrincipal, error) {
	// The lookup error is intentionally discarded: any failure to resolve the application
	// is handled the same way as "not found" and leads to creating a new application.
	app, _ := ad.getApplicationByNameOrId(ctx, subscriptionId, appIdOrName)
	if app == nil {
		created, createErr := ad.createApplication(ctx, subscriptionId, appIdOrName, options)
		if createErr != nil {
			return nil, createErr
		}
		app = created
	}

	// Reuse the application's service principal, or create it.
	spn, err := ad.ensureServicePrincipal(ctx, subscriptionId, app)
	if err != nil {
		return nil, err
	}

	// Assign the requested roles.
	if err := ad.ensureRoleAssignments(ctx, subscriptionId, options.RolesToAssign, spn); err != nil {
		return nil, fmt.Errorf("failed applying role assignment: %w", err)
	}

	return spn, nil
}
// ResetPasswordCredentials replaces the password credentials of the application with the
// given app ID: every existing password credential is removed, then a single new one is added.
//
// Both the application and its service principal must be found before anything is removed.
// The returned AzureCredentials combines the application's app ID, the new secret text, the
// supplied subscription ID and the service principal's owning organization ID.
func (ad *entraIdService) ResetPasswordCredentials(
	ctx context.Context,
	subscriptionId string,
	appId string,
) (*AzureCredentials, error) {
	client, err := ad.getOrCreateGraphClient(ctx, subscriptionId)
	if err != nil {
		return nil, err
	}

	app, err := ad.getApplicationByAppId(ctx, subscriptionId, appId)
	if err != nil {
		return nil, fmt.Errorf("failed finding matching application: %w", err)
	}

	spn, err := ad.getServicePrincipal(ctx, subscriptionId, app)
	if err != nil {
		return nil, fmt.Errorf("failed finding matching service principal: %w", err)
	}

	// Drop every password credential currently registered on the application.
	for _, existing := range app.PasswordCredentials {
		removeErr := client.ApplicationById(*app.Id).RemovePassword(ctx, *existing.KeyId)
		if removeErr != nil {
			return nil, fmt.Errorf("failed removing credentials for KeyId '%s' : %w", *existing.KeyId, removeErr)
		}
	}

	added, err := client.ApplicationById(*app.Id).AddPassword(ctx)
	if err != nil {
		return nil, fmt.Errorf(
			"failed adding new password credential for application '%s' : %w",
			app.DisplayName,
			err,
		)
	}

	result := AzureCredentials{
		ClientId:       *app.AppId,
		ClientSecret:   *added.SecretText,
		SubscriptionId: subscriptionId,
		TenantId:       *spn.AppOwnerOrganizationId,
	}

	return &result, nil
}
// ApplyFederatedCredentials makes sure each requested federated credential exists on the
// application identified by clientId.
//
// The application's current federated credentials are fetched once. A requested credential
// whose subject is already present is skipped; the others are created. The returned slice
// holds only the credentials created by this call and is empty, not nil, when none were.
func (ad *entraIdService) ApplyFederatedCredentials(
	ctx context.Context,
	subscriptionId string,
	clientId string,
	federatedCredentials []*graphsdk.FederatedIdentityCredential,
) ([]*graphsdk.FederatedIdentityCredential, error) {
	client, err := ad.getOrCreateGraphClient(ctx, subscriptionId)
	if err != nil {
		return nil, err
	}

	app, err := ad.getApplicationByAppId(ctx, subscriptionId, clientId)
	if err != nil {
		return nil, fmt.Errorf("failed finding matching application: %w", err)
	}

	existingResponse, err := client.ApplicationById(*app.Id).FederatedIdentityCredentials().Get(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed retrieving federated credentials: %w", err)
	}
	existing := existingResponse.Value

	created := make([]*graphsdk.FederatedIdentityCredential, 0)
	for _, requested := range federatedCredentials {
		credential, ensureErr := ad.ensureFederatedCredential(ctx, subscriptionId, app, existing, requested)
		if ensureErr != nil {
			return nil, ensureErr
		}

		// A nil credential means one with the same subject was already present.
		if credential == nil {
			continue
		}
		created = append(created, credential)
	}

	return created, nil
}
// getApplicationByNameOrId resolves an application by treating appIdOrName first as an
// app ID and, if that yields nothing, as a display name.
//
// Errors from the individual lookups are intentionally dropped; when neither lookup yields
// an application a single "could not find" error is returned instead.
func (ad *entraIdService) getApplicationByNameOrId(
	ctx context.Context,
	subscriptionId string,
	appIdOrName string,
) (*graphsdk.Application, error) {
	if byId, _ := ad.getApplicationByAppId(ctx, subscriptionId, appIdOrName); byId != nil {
		return byId, nil
	}

	if byName, _ := ad.getApplicationByName(ctx, subscriptionId, appIdOrName); byName != nil {
		return byName, nil
	}

	return nil, fmt.Errorf("could not find application with ID or name '%s'", appIdOrName)
}
// getApplicationByAppId fetches the application whose app ID is appId.
// A Graph failure is wrapped with the app ID that was requested.
func (ad *entraIdService) getApplicationByAppId(
	ctx context.Context,
	subscriptionId string,
	appId string,
) (*graphsdk.Application, error) {
	client, err := ad.getOrCreateGraphClient(ctx, subscriptionId)
	if err != nil {
		return nil, err
	}

	app, getErr := client.ApplicationById(appId).GetByAppId(ctx)
	if getErr != nil {
		return nil, fmt.Errorf("failed retrieving application with id '%s': %w", appId, getErr)
	}

	return app, nil
}
// getApplicationByName looks up a single application by display name.
//
// The Graph query is a prefix match (startswith) on displayName, so it can return
// applications whose names merely begin with applicationName. Resolution rules:
//   - no result: error;
//   - exactly one result: that application is returned (even if it is only a prefix match);
//   - several results: the application whose display name equals applicationName exactly is
//     returned when there is exactly one such application, otherwise an error is returned.
func (ad *entraIdService) getApplicationByName(
	ctx context.Context,
	subscriptionId string,
	applicationName string,
) (*graphsdk.Application, error) {
	graphClient, err := ad.getOrCreateGraphClient(ctx, subscriptionId)
	if err != nil {
		return nil, err
	}

	// OData string literals escape a single quote by doubling it; without this a name such
	// as "bob's app" yields a malformed filter expression.
	escapedName := strings.ReplaceAll(applicationName, "'", "''")

	matchingItems, err := graphClient.
		Applications().
		Filter(fmt.Sprintf("startswith(displayName, '%s')", escapedName)).
		Get(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed retrieving application list: %w", err)
	}

	candidates := matchingItems.Value
	switch len(candidates) {
	case 0:
		return nil, fmt.Errorf("no application with name '%s' found", applicationName)
	case 1:
		return &candidates[0], nil
	}

	// Several applications share the prefix. An application named exactly applicationName
	// must not be hidden by longer-named siblings (e.g. "app" next to "app-2"), so prefer
	// the unique exact match and only report ambiguity when there is none or more than one.
	var exact *graphsdk.Application
	for i := range candidates {
		if candidates[i].DisplayName != applicationName {
			continue
		}
		if exact != nil {
			return nil, fmt.Errorf("more than 1 application with same name '%s'", applicationName)
		}
		exact = &candidates[i]
	}
	if exact == nil {
		return nil, fmt.Errorf("more than 1 application with same name '%s'", applicationName)
	}

	return exact, nil
}
// createApplication registers a new application named applicationName.
// The description and service management reference come from options, and the application
// starts with an empty password credential list. It does not check for an existing
// application of the same name.
func (ad *entraIdService) createApplication(
	ctx context.Context,
	subscriptionId string,
	applicationName string,
	options CreateOrUpdateServicePrincipalOptions,
) (*graphsdk.Application, error) {
	client, err := ad.getOrCreateGraphClient(ctx, subscriptionId)
	if err != nil {
		return nil, err
	}

	request := graphsdk.Application{
		DisplayName:                applicationName,
		Description:                options.Description,
		PasswordCredentials:        []*graphsdk.ApplicationPasswordCredential{},
		ServiceManagementReference: options.ServiceManagementReference,
	}

	created, postErr := client.Applications().Post(ctx, &request)
	if postErr != nil {
		return nil, fmt.Errorf("failed creating application '%s': %w", applicationName, postErr)
	}

	return created, nil
}
// getServicePrincipal returns the service principal whose display name equals the display
// name of the given application.
//
// Exactly one match is required: zero matches and multiple matches both produce an error.
// NOTE(review): the match is by display name rather than app ID, so an unrelated service
// principal with the same display name makes the lookup ambiguous — confirm this is intended.
func (ad *entraIdService) getServicePrincipal(
	ctx context.Context,
	subscriptionId string,
	application *graphsdk.Application,
) (*graphsdk.ServicePrincipal, error) {
	graphClient, err := ad.getOrCreateGraphClient(ctx, subscriptionId)
	if err != nil {
		return nil, err
	}

	// OData string literals escape a single quote by doubling it; without this a display
	// name containing an apostrophe yields a malformed filter expression.
	escapedName := strings.ReplaceAll(application.DisplayName, "'", "''")

	matchingItems, err := graphClient.
		ServicePrincipals().
		Filter(fmt.Sprintf("displayName eq '%s'", escapedName)).
		Get(ctx)
	if err != nil {
		// The failing query lists service principals, not applications.
		return nil, fmt.Errorf("failed retrieving service principal list: %w", err)
	}

	switch len(matchingItems.Value) {
	case 0:
		return nil, fmt.Errorf("no service principal found for application '%s'", application.DisplayName)
	case 1:
		return &matchingItems.Value[0], nil
	default:
		return nil, fmt.Errorf(
			"more than 1 service principal exists with same name '%s'",
			application.DisplayName,
		)
	}
}
// ensureServicePrincipal returns the service principal of the given application, creating
// it when the lookup does not produce one. Any lookup failure is handled the same way as
// "not found". A new service principal copies the application's app ID, display name and
// description.
func (ad *entraIdService) ensureServicePrincipal(
	ctx context.Context,
	subscriptionId string,
	application *graphsdk.Application,
) (*graphsdk.ServicePrincipal, error) {
	client, err := ad.getOrCreateGraphClient(ctx, subscriptionId)
	if err != nil {
		return nil, err
	}

	existing, lookupErr := ad.getServicePrincipal(ctx, subscriptionId, application)
	if lookupErr == nil && existing != nil {
		return existing, nil
	}

	// Nothing usable was found - create the service principal from the application.
	created, postErr := client.ServicePrincipals().Post(ctx, &graphsdk.ServicePrincipal{
		AppId:       *application.AppId,
		DisplayName: application.DisplayName,
		Description: application.Description,
	})
	if postErr != nil {
		return nil, fmt.Errorf("failed creating service principal '%s': %w", application.DisplayName, postErr)
	}

	return created, nil
}
// ensureFederatedCredential creates repoCredential on the application unless one of
// existingCredentials already has the same subject.
//
// Returns (nil, nil) when a credential with that subject already exists (the fact is
// logged), or the newly created credential otherwise.
func (ad *entraIdService) ensureFederatedCredential(
	ctx context.Context,
	subscriptionId string,
	application *graphsdk.Application,
	existingCredentials []graphsdk.FederatedIdentityCredential,
	repoCredential *graphsdk.FederatedIdentityCredential,
) (*graphsdk.FederatedIdentityCredential, error) {
	client, err := ad.getOrCreateGraphClient(ctx, subscriptionId)
	if err != nil {
		return nil, err
	}

	// A credential for the same subject means there is nothing to create.
	for i := range existingCredentials {
		if existingCredentials[i].Subject != repoCredential.Subject {
			continue
		}

		log.Printf(
			"federated credential with subject '%s' already exists on application '%s'",
			repoCredential.Subject,
			*application.Id,
		)

		return nil, nil
	}

	created, postErr := client.
		ApplicationById(*application.Id).
		FederatedIdentityCredentials().
		Post(ctx, repoCredential)
	if postErr != nil {
		return nil, fmt.Errorf("failed creating federated credential: %w", postErr)
	}

	return created, nil
}
// ensureRoleAssignments assigns each named role to the service principal, in order.
// It stops at, and returns, the first failure; roles assigned before that remain assigned.
func (ad *entraIdService) ensureRoleAssignments(
	ctx context.Context,
	subscriptionId string,
	roleNames []string,
	servicePrincipal *graphsdk.ServicePrincipal,
) error {
	for i := range roleNames {
		if err := ad.ensureRoleAssignment(ctx, subscriptionId, roleNames[i], servicePrincipal); err != nil {
			return err
		}
	}

	return nil
}
// ensureRoleAssignment assigns one role, identified by name, to the service principal.
// The role definition is resolved at subscription scope and then assigned with retries.
func (ad *entraIdService) ensureRoleAssignment(
	ctx context.Context,
	subscriptionId string,
	roleName string,
	servicePrincipal *graphsdk.ServicePrincipal,
) error {
	// Resolve the role name to its definition within the subscription.
	subscriptionScope := azure.SubscriptionRID(subscriptionId)
	definition, err := ad.getRoleDefinition(ctx, subscriptionId, subscriptionScope, roleName)
	if err != nil {
		return err
	}

	// Create the assignment; its error, if any, is the result.
	return ad.applyRoleAssignmentWithRetry(ctx, subscriptionId, definition, servicePrincipal)
}
// CreateRbac assigns the role roleId to the principal principalId at the given scope.
//
// The role definition ID is built as "/subscriptions/<subscriptionId>" followed directly by
// roleId, and roleId is also used as the role's name in error messages. The assignment goes
// through the same retrying implementation as the other role assignments.
func (ad *entraIdService) CreateRbac(
	ctx context.Context, subscriptionId string, scope, roleId, principalId string) error {
	roleDefinition := armauthorization.RoleDefinition{
		ID:   to.Ptr("/subscriptions/" + subscriptionId + roleId),
		Name: to.Ptr(roleId),
	}
	principal := graphsdk.ServicePrincipal{
		Id: to.Ptr(principalId),
	}

	return ad.applyRoleAssignmentWithRetryImpl(ctx, subscriptionId, scope, &roleDefinition, &principal)
}
// applyRoleAssignmentWithRetry assigns the role definition to the service principal at
// subscription scope. The retrying itself is done by applyRoleAssignmentWithRetryImpl.
func (ad *entraIdService) applyRoleAssignmentWithRetry(
	ctx context.Context,
	subscriptionId string,
	roleDefinition *armauthorization.RoleDefinition,
	servicePrincipal *graphsdk.ServicePrincipal,
) error {
	return ad.applyRoleAssignmentWithRetryImpl(
		ctx,
		subscriptionId,
		azure.SubscriptionRID(subscriptionId),
		roleDefinition,
		servicePrincipal,
	)
}
// applyRoleAssignmentWithRetryImpl creates a role assignment of roleDefinition for
// servicePrincipal at scope, retrying transient failures.
//
// A newly created application/service principal can take a while to become available in
// Azure AD, which makes the assignment fail at first; such failures are retried up to 10
// times, 5 seconds apart. Outcomes per attempt:
//   - success, or 409 Conflict (the role is already assigned): returns nil;
//   - 403 Forbidden: returns immediately, without retrying, an *internal.ErrorWithSuggestion
//     explaining which permission is missing;
//   - any other error: wrapped and retried.
func (ad *entraIdService) applyRoleAssignmentWithRetryImpl(
	ctx context.Context,
	subscriptionId string,
	scope string,
	roleDefinition *armauthorization.RoleDefinition,
	servicePrincipal *graphsdk.ServicePrincipal,
) error {
	roleAssignmentsClient, err := ad.createRoleAssignmentsClient(ctx, subscriptionId)
	if err != nil {
		return err
	}

	// One assignment name is generated up front and reused by every attempt.
	roleAssignmentId := uuid.New().String()
	parameters := armauthorization.RoleAssignmentCreateParameters{
		Properties: &armauthorization.RoleAssignmentProperties{
			PrincipalID:      servicePrincipal.Id,
			RoleDefinitionID: roleDefinition.ID,
		},
	}

	backoff := retry.WithMaxRetries(10, retry.NewConstant(time.Second*5))

	return retry.Do(ctx, backoff, func(ctx context.Context) error {
		// Use a closure-local error instead of writing to the captured outer variable.
		_, err := roleAssignmentsClient.Create(ctx, scope, roleAssignmentId, parameters, nil)
		if err == nil {
			return nil
		}

		var responseError *azcore.ResponseError
		if errors.As(err, &responseError) {
			switch responseError.StatusCode {
			case http.StatusConflict:
				// If the response is a 409 conflict then the role has already been assigned.
				return nil
			case http.StatusForbidden:
				// If the response is a 403 then the required role is missing. Returned as a
				// non-retryable error so the retry loop stops immediately.
				return &internal.ErrorWithSuggestion{
					Suggestion: "\nSuggested Action: Ensure you have either the `User Access Administrator`, " +
						"`Owner` or custom azure roles assigned to your subscription to perform action " +
						"'Microsoft.Authorization/roleAssignments/write', in order to manage role assignments\n",
					Err: err,
				}
			}
		}

		return retry.RetryableError(
			fmt.Errorf(
				"failed assigning role assignment '%s' to service principal '%s' : %w",
				*roleDefinition.Name,
				servicePrincipal.DisplayName,
				err,
			),
		)
	})
}
// getRoleDefinition finds the role definition named roleName within scope.
// Every page of the filtered listing is read; the first definition collected is returned,
// and an error is returned when the listing fails or contains no definition.
func (ad *entraIdService) getRoleDefinition(
	ctx context.Context,
	subscriptionId string,
	scope string,
	roleName string,
) (*armauthorization.RoleDefinition, error) {
	client, err := ad.createRoleDefinitionsClient(ctx, subscriptionId)
	if err != nil {
		return nil, err
	}

	nameFilter := fmt.Sprintf("roleName eq '%s'", roleName)
	listOptions := armauthorization.RoleDefinitionsClientListOptions{Filter: &nameFilter}
	pager := client.NewListPager(scope, &listOptions)

	var matches []*armauthorization.RoleDefinition
	for pager.More() {
		page, pageErr := pager.NextPage(ctx)
		if pageErr != nil {
			return nil, fmt.Errorf("failed getting next page of role definitions: %w", pageErr)
		}

		matches = append(matches, page.RoleDefinitionListResult.Value...)
	}

	if len(matches) > 0 {
		return matches[0], nil
	}

	return nil, fmt.Errorf("role definition with scope: '%s' and name: '%s' was not found", scope, roleName)
}
// createRoleDefinitionsClient builds an ARM role definitions client using the credential
// obtained for the given subscription and the service's ARM client options.
func (ad *entraIdService) createRoleDefinitionsClient(
	ctx context.Context,
	subscriptionId string,
) (*armauthorization.RoleDefinitionsClient, error) {
	cred, err := ad.credentialProvider.CredentialForSubscription(ctx, subscriptionId)
	if err != nil {
		return nil, err
	}

	roleDefinitions, newErr := armauthorization.NewRoleDefinitionsClient(cred, ad.armClientOptions)
	if newErr != nil {
		return nil, fmt.Errorf("creating ARM Role Definitions client: %w", newErr)
	}

	return roleDefinitions, nil
}
// createRoleAssignmentsClient builds an ARM role assignments client for the given
// subscription, using the credential obtained for it and the service's ARM client options.
func (ad *entraIdService) createRoleAssignmentsClient(
	ctx context.Context,
	subscriptionId string,
) (*armauthorization.RoleAssignmentsClient, error) {
	cred, err := ad.credentialProvider.CredentialForSubscription(ctx, subscriptionId)
	if err != nil {
		return nil, err
	}

	roleAssignments, newErr := armauthorization.NewRoleAssignmentsClient(subscriptionId, cred, ad.armClientOptions)
	if newErr != nil {
		return nil, fmt.Errorf("creating ARM Role Assignments client: %w", newErr)
	}

	return roleAssignments, nil
}
// getOrCreateGraphClient returns the Graph client cached for the subscription, creating and
// caching one on first use. A new client uses the credential obtained for the subscription
// and the service's core client options. Failures are not cached.
func (ad *entraIdService) getOrCreateGraphClient(
	ctx context.Context,
	subscriptionId string,
) (*graphsdk.GraphClient, error) {
	cached, found := ad.clientCache[subscriptionId]
	if found {
		return cached, nil
	}

	cred, err := ad.credentialProvider.CredentialForSubscription(ctx, subscriptionId)
	if err != nil {
		return nil, err
	}

	created, newErr := graphsdk.NewGraphClient(cred, ad.coreClientOptions)
	if newErr != nil {
		return nil, fmt.Errorf("creating Graph Users client: %w", newErr)
	}

	ad.clientCache[subscriptionId] = created

	return created, nil
}