cli/azd/pkg/infra/provisioning/bicep/prompt.go (391 lines of code) (raw):

// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. package bicep import ( "context" "encoding/json" "fmt" "log" "slices" "strconv" "strings" "sync" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/cognitiveservices/armcognitiveservices" "github.com/azure/azure-dev/cli/azd/pkg/account" "github.com/azure/azure-dev/cli/azd/pkg/azure" "github.com/azure/azure-dev/cli/azd/pkg/environment" "github.com/azure/azure-dev/cli/azd/pkg/input" "github.com/azure/azure-dev/cli/azd/pkg/output" "github.com/azure/azure-dev/cli/azd/pkg/output/ux" "github.com/azure/azure-dev/cli/azd/pkg/password" "github.com/azure/azure-dev/cli/azd/pkg/prompt" "github.com/azure/azure-dev/cli/azd/pkg/infra/provisioning" ) // promptDialogItemForParameter builds the input.PromptDialogItem for the given required parameter. func (p *BicepProvider) promptDialogItemForParameter( key string, param azure.ArmTemplateParameterDefinition, ) input.PromptDialogItem { help, _ := param.Description() paramType := p.mapBicepTypeToInterfaceType(param.Type) var dialogItem input.PromptDialogItem dialogItem.ID = key dialogItem.DisplayName = key dialogItem.Required = true if help != "" { dialogItem.Description = to.Ptr(help) } if paramType == provisioning.ParameterTypeBoolean { dialogItem.Kind = "select" dialogItem.Choices = []input.PromptDialogChoice{{Value: "true"}, {Value: "false"}} } else if param.AllowedValues != nil { dialogItem.Kind = "select" for _, v := range *param.AllowedValues { dialogItem.Choices = append(dialogItem.Choices, input.PromptDialogChoice{Value: fmt.Sprintf("%v", v)}) } } else if param.Secure() { dialogItem.Kind = "password" } else { dialogItem.Kind = "string" } return dialogItem } func autoGenerate(parameter string, azdMetadata azure.AzdMetadata) (string, error) { if azdMetadata.AutoGenerateConfig == nil { return "", fmt.Errorf("auto generation metadata config is missing for parameter '%s'", parameter) } genValue, err := password.Generate(password.GenerateConfig{ Length: azdMetadata.AutoGenerateConfig.Length, NoLower: azdMetadata.AutoGenerateConfig.NoLower, NoUpper: azdMetadata.AutoGenerateConfig.NoUpper, NoNumeric: azdMetadata.AutoGenerateConfig.NoNumeric, NoSpecial: azdMetadata.AutoGenerateConfig.NoSpecial, MinLower: azdMetadata.AutoGenerateConfig.MinLower, MinUpper: azdMetadata.AutoGenerateConfig.MinUpper, MinNumeric: azdMetadata.AutoGenerateConfig.MinNumeric, MinSpecial: azdMetadata.AutoGenerateConfig.MinSpecial, }) if err != nil { return "", err } return genValue, nil } // locationsWithQuotaFor checks which locations have available quota for a specified list of SKU. // It concurrently queries the Azure API for usage data in each location and filters the results // based on the quota and capacity requirements. // // Parameters: // - ctx: The context for controlling cancellation and deadlines. // - subId: The subscription ID to query against. // - locations: A list of Azure locations to check for quota availability. // - quotaFor: list of SKU name and optional capacity (comma-separated) to check for. Example: "OpenAI.S0.AccountCount, 2" // or ["OpenAI.S1.SkuDescription, 1", "OpenAI.S0.AccountCount, 2"] // // Returns: // - A slice of location strings that have the required quota and capacity available. // - An error if any issues occur during the process or if no locations meet the criteria. // // The function first queries the Azure API for usage data in each location concurrently. // It then filters the results based on the specified list of SKU name and capacity requirements. // If no locations meet the criteria, it returns an error with details about the maximum available capacity found. func (a *BicepProvider) locationsWithQuotaFor( ctx context.Context, subId string, locations []string, quotaFor []string) ([]string, error) { var sharedResults sync.Map var wg sync.WaitGroup azureAiServicesLocations, err := a.azureClient.GetResourceSkuLocations( ctx, subId, "AIServices", "S0", "Standard", "accounts") if err != nil { return nil, fmt.Errorf("getting Azure AI Services locations: %w", err) } if locations == nil { // If no locations are provided, use the Azure AI Services locations locations = azureAiServicesLocations } for _, location := range locations { if !slices.Contains(azureAiServicesLocations, location) { // Skip locations that are not in the list of Azure AI Services locations continue } wg.Add(1) go func(location string) { defer wg.Done() results, err := a.azureClient.GetAiUsages(ctx, subId, location) if err != nil { // log the error but don't return it log.Println("error getting usage for location", location, ":", err) return } sharedResults.Store(location, results) }(location) } wg.Wait() var results []string var iterationError error sharedResults.Range(func(location, quotaDetails any) bool { usages := quotaDetails.([]*armcognitiveservices.Usage) hasS0SkuQuota := slices.ContainsFunc(usages, func(q *armcognitiveservices.Usage) bool { // The minimum quota for the S0 SKU in Microsoft.CognitiveServices/accounts is 2 capacity units return *q.Name.Value == "OpenAI.S0.AccountCount" && (*q.Limit-*q.CurrentValue) >= 2 }) if !hasS0SkuQuota { // If the S0 SKU quota is not available, skip this location return true } // Check if all requested quotas can be satisfied in this location for _, definedUsageName := range quotaFor { usageDetails, err := usageNameDetailsFromString(definedUsageName) if err != nil { iterationError = fmt.Errorf("parsing quota '%s': %w", definedUsageName, err) return false } hasQuotaForModels := slices.ContainsFunc(usages, func(usage *armcognitiveservices.Usage) bool { hasQuota := *usage.Name.Value == usageDetails.UsageName if !hasQuota { return false } remaining := *usage.Limit - *usage.CurrentValue return *usage.Name.Value == usageDetails.UsageName && remaining >= usageDetails.Capacity }) if !hasQuotaForModels { // If the quota for this model is not available, skip this location return true } } // If the quota for this model is available, add the location to the results results = append(results, location.(string)) return true }) if iterationError != nil { return nil, fmt.Errorf("looking for location with quota: %w", iterationError) } if len(results) == 0 { formattedQuota := make([]string, len(quotaFor)) for i, quota := range quotaFor { f, err := usageNameDetailsFromString(quota) if err != nil { return nil, fmt.Errorf("parsing quota '%s': %w", quota, err) } formattedQuota[i] = fmt.Sprintf("%s ( Cap: %.0f )", f.UsageName, f.Capacity) } return nil, fmt.Errorf( "no location found with enough quota for %s", ux.ListAsText(formattedQuota)) } return results, nil } type usageNameDetails struct { UsageName string Capacity float64 } func usageNameDetailsFromString(usageName string) (usageNameDetails, error) { usage := strings.TrimSpace(usageName) if len(usage) == 0 { return usageNameDetails{}, fmt.Errorf("empty usage name") } parts := strings.Split(usage, ",") if len(parts) == 1 { return usageNameDetails{ UsageName: usage, Capacity: 1, }, nil } if len(parts) != 2 { return usageNameDetails{}, fmt.Errorf("invalid usage name format '%s'", usage) } usageName = strings.TrimSpace(parts[0]) capacity, err := strconv.ParseFloat(strings.Trim(parts[1], " "), 64) if err != nil { return usageNameDetails{}, fmt.Errorf("invalid capacity '%s': %w", parts[1], err) } if capacity <= 0 { return usageNameDetails{}, fmt.Errorf("invalid capacity '%.0f': must be greater than 0", capacity) } return usageNameDetails{ UsageName: usageName, Capacity: capacity, }, nil } func (p *BicepProvider) promptForParameter( ctx context.Context, key string, param azure.ArmTemplateParameterDefinition, mappedToAzureLocationParams []string, ) (any, error) { securedParam := "parameter" isSecuredParam := param.Secure() if isSecuredParam { securedParam = "secured parameter" } msg := fmt.Sprintf("Enter a value for the '%s' infrastructure %s:", key, securedParam) help, _ := param.Description() azdMetadata, _ := param.AzdMetadata() paramType := p.mapBicepTypeToInterfaceType(param.Type) var value any if paramType == provisioning.ParameterTypeString && azdMetadata.Type != nil && *azdMetadata.Type == azure.AzdMetadataTypeLocation { // when more than one parameter is mapped to AZURE_LOCATION and AZURE_LOCATION is not set in the environment, // AZD will prompt just once and immediately set the value in the .env for the next parameter to re-use the value paramIsMappedToAzureLocation := slices.Contains(mappedToAzureLocationParams, key) valueFromEnv, valueDefinedInEnv := p.env.LookupEnv(environment.LocationEnvVarName) if paramIsMappedToAzureLocation && valueDefinedInEnv { return valueFromEnv, nil } // location can be combined with allowedValues and with usageName metadata // allowedValues == nil => all locations are allowed // allowedValues != nil => only the locations in the allowedValues are allowed // usageName != nil => the usageName is validated for quota for each allowed location (this is for Ai models), // reducing the allowed locations to only those that have quota available // usageName == nil => No quota validation is done var allowedLocations []string if param.AllowedValues != nil { allowedLocations = make([]string, len(*param.AllowedValues)) for i, option := range *param.AllowedValues { allowedLocations[i] = option.(string) } } if len(azdMetadata.UsageName) > 0 { withQuotaLocations, err := p.locationsWithQuotaFor( ctx, p.env.GetSubscriptionId(), allowedLocations, azdMetadata.UsageName) if err != nil { return nil, fmt.Errorf("getting locations with quota: %w", err) } allowedLocations = withQuotaLocations } location, err := p.prompters.PromptLocation( ctx, p.env.GetSubscriptionId(), msg, func(loc account.Location) bool { return locationParameterFilterImpl(allowedLocations, loc) }, defaultPromptValue(param)) if err != nil { return nil, err } if paramIsMappedToAzureLocation && !valueDefinedInEnv { // set the location in the environment variable p.env.SetLocation(location) if err := p.envManager.Save(ctx, p.env); err != nil { return nil, fmt.Errorf("setting location in environment variable: %w", err) } } value = location } else if paramType == provisioning.ParameterTypeString && azdMetadata.Type != nil && *azdMetadata.Type == azure.AzdMetadataTypeResourceGroup { p.console.Message(ctx, fmt.Sprintf( "Parameter %s requires an %s resource group.", output.WithUnderline("%s", key), output.WithBold("existing"))) rgName, err := p.prompters.PromptResourceGroup(ctx, prompt.PromptResourceOptions{ DisableCreateNew: true, }) if err != nil { return nil, err } value = rgName } else if paramType == provisioning.ParameterTypeString && azdMetadata.Type != nil && *azdMetadata.Type == azure.AzdMetadataTypeGenerateOrManual { var manualUserInput bool defaultOption := "Auto generate" options := []string{defaultOption, "Manual input"} choice, err := p.console.Select(ctx, input.ConsoleOptions{ Message: fmt.Sprintf( "Parameter %s can be either autogenerated or you can enter its value. What would you like to do?", key), Options: options, DefaultValue: defaultOption, }) if err != nil { return nil, err } manualUserInput = options[choice] != defaultOption if manualUserInput { resultValue, err := promptWithValidation(ctx, p.console, input.ConsoleOptions{ Message: msg, Help: help, IsPassword: isSecuredParam, }, convertString, validateLengthRange(key, param.MinLength, param.MaxLength)) if err != nil { return nil, err } value = resultValue } else { genValue, err := autoGenerate(key, azdMetadata) if err != nil { return nil, err } value = genValue } } else if param.AllowedValues != nil { options := make([]string, 0, len(*param.AllowedValues)) for _, option := range *param.AllowedValues { options = append(options, fmt.Sprintf("%v", option)) } if len(options) == 0 { return nil, fmt.Errorf("parameter '%s' has no allowed values defined", key) } choice, err := p.console.Select(ctx, input.ConsoleOptions{ Message: msg, Help: help, Options: options, }) if err != nil { return nil, err } value = (*param.AllowedValues)[choice] } else { switch paramType { case provisioning.ParameterTypeBoolean: options := []string{"False", "True"} choice, err := p.console.Select(ctx, input.ConsoleOptions{ Message: msg, Help: help, Options: options, }) if err != nil { return nil, err } value = (options[choice] == "True") case provisioning.ParameterTypeNumber: userValue, err := promptWithValidation(ctx, p.console, input.ConsoleOptions{ Message: msg, Help: help, }, convertInt, validateValueRange(key, param.MinValue, param.MaxValue)) if err != nil { return nil, err } value = userValue case provisioning.ParameterTypeString: userValue, err := promptWithValidation(ctx, p.console, input.ConsoleOptions{ Message: msg, Help: help, IsPassword: isSecuredParam, }, convertString, validateLengthRange(key, param.MinLength, param.MaxLength)) if err != nil { return nil, err } value = userValue case provisioning.ParameterTypeArray: userValue, err := promptWithValidation(ctx, p.console, input.ConsoleOptions{ Message: msg, Help: help, }, convertJson[[]any], validateJsonArray) if err != nil { return nil, err } value = userValue case provisioning.ParameterTypeObject: userValue, err := promptWithValidation(ctx, p.console, input.ConsoleOptions{ Message: msg, Help: help, }, convertJson[map[string]any], validateJsonObject) if err != nil { return nil, err } value = userValue default: panic(fmt.Sprintf("unknown parameter type: %s", p.mapBicepTypeToInterfaceType(param.Type))) } } return value, nil } // promptWithValidation prompts for a value using the console and then validates that it satisfies all the validation // functions. If it does, it is converted from a string to a value using the converter and returned. If any validation // fails, the prompt is retried after printing the error (prefixed with "Error: ") to the console. If there are is an // error prompting it is returned as is. func promptWithValidation[T any]( ctx context.Context, console input.Console, options input.ConsoleOptions, converter func(string) T, validators ...func(string) error, ) (T, error) { for { userValue, err := console.Prompt(ctx, options) if err != nil { return *new(T), err } isValid := true for _, validator := range validators { if err := validator(userValue); err != nil { console.Message(ctx, output.WithErrorFormat("Error: %s.", err)) isValid = false break } } if isValid { return converter(userValue), nil } } } func convertString(s string) string { return s } func convertInt(s string) int { if i, err := strconv.ParseInt(s, 10, 64); err != nil { panic(fmt.Sprintf("convertInt: %v", err)) } else { return int(i) } } func convertJson[T any](s string) T { var t T if err := json.Unmarshal([]byte(s), &t); err != nil { panic(fmt.Sprintf("convertJson: %v", err)) } return t }