internal/provider/provider.go (433 lines of code) (raw):
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package provider
import (
"context"
"fmt"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/Azure/alzlib"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armpolicy"
"github.com/Azure/terraform-provider-alz/internal/provider/gen"
"github.com/hashicorp/terraform-plugin-framework/attr"
"github.com/hashicorp/terraform-plugin-framework/datasource"
"github.com/hashicorp/terraform-plugin-framework/diag"
"github.com/hashicorp/terraform-plugin-framework/function"
"github.com/hashicorp/terraform-plugin-framework/provider"
"github.com/hashicorp/terraform-plugin-framework/resource"
"github.com/hashicorp/terraform-plugin-framework/types"
"github.com/hashicorp/terraform-plugin-framework/types/basetypes"
"github.com/hashicorp/terraform-plugin-log/tflog"
)
// Run go generate to automatically generate provider, data source and resource types
// from the intermediate representation JSON file `ir.json`.
//go:generate tfplugingen-framework generate provider --package gen --output ./gen
//go:generate tfplugingen-framework generate data-sources --package gen --output ./gen
//go:generate tfplugingen-framework generate resources --package gen --output ./gen
// Provider-wide constants.
const (
// userAgentBase is the product part of the user agent string; Configure
// appends "/<provider version>" to it.
userAgentBase = "AzureTerraformAlzProvider"
// alzLibDirBase is not referenced in the visible part of this file.
// NOTE(review): presumably the base directory name for downloaded library content - confirm against its users.
alzLibDirBase = ".alzlib"
// alzLibRef is the "ref" of the default library reference set by configureDefaults.
alzLibRef = "2024.10.1"
// alzLibPath is the "path" of the default library reference set by configureDefaults.
alzLibPath = "platform/alz"
)
// Compile-time checks that AlzProvider satisfies the provider interfaces it is
// expected to implement.
var (
_ provider.Provider = &AlzProvider{}
_ provider.ProviderWithFunctions = &AlzProvider{}
)
// AlzProvider defines the provider implementation.
type AlzProvider struct {
// version is set to the provider version on release, "dev" when the
// provider is built and ran locally, and "test" when running acceptance
// testing.
version string
// data holds the result of the first successful Configure call. Later
// Configure calls hand out this same value instead of rebuilding it.
data *alzProviderData
}
// AlzProviderClients groups the Azure SDK clients created by getClients during
// provider configuration.
type AlzProviderClients struct {
// RoleAssignmentsClient is the Azure authorization role assignments client.
RoleAssignmentsClient *armauthorization.RoleAssignmentsClient
}
// alzProviderData is the value Configure passes to data sources and resources
// through the configure response.
type alzProviderData struct {
// AlzLib is the initialised library; embedding promotes its methods.
*alzlib.AlzLib
// mu is allocated in Configure.
// NOTE(review): presumably serialises access to the embedded AlzLib - confirm against its users.
mu *sync.Mutex
// clients holds the Azure SDK clients built by getClients.
clients *AlzProviderClients
// suppressWarningPolicyRoleAssignments mirrors the provider setting of the
// same name after environment and default handling.
suppressWarningPolicyRoleAssignments bool
}
// Metadata reports the provider type name ("alz") and the version the provider
// was built with.
func (p *AlzProvider) Metadata(ctx context.Context, req provider.MetadataRequest, resp *provider.MetadataResponse) {
	resp.Version, resp.TypeName = p.version, "alz"
}
// Schema returns the provider schema produced by the generated gen package.
func (p *AlzProvider) Schema(ctx context.Context, req provider.SchemaRequest, resp *provider.SchemaResponse) {
	providerSchema := gen.AlzProviderSchema(ctx)
	resp.Schema = providerSchema
}
// Configure builds the shared provider data and hands it to data sources and
// resources through resp.
//
// Steps, in order: read the provider config, fill unset attributes from
// environment variables, resolve the auxiliary tenant IDs, export the
// environment variables that azidentity reads, apply defaults, create the
// credential and clients, then fetch and initialise the ALZ library.
//
// The result is stored on the provider, so later calls return it immediately.
// Any error is reported through resp.Diagnostics and stops configuration.
func (p *AlzProvider) Configure(ctx context.Context, req provider.ConfigureRequest, resp *provider.ConfigureResponse) {
	tflog.Debug(ctx, "Provider configuration started")

	// Reuse the data from an earlier call when it exists.
	if p.data != nil {
		tflog.Debug(ctx, "Provider AlzLib already present, skipping configuration")
		resp.DataSourceData = p.data
		resp.ResourceData = p.data
		return
	}

	tflog.Debug(ctx, "Provider AlzLib not present, beginning configuration")

	var data gen.AlzModel
	resp.Diagnostics.Append(req.Config.Get(ctx, &data)...)
	if resp.Diagnostics.HasError() {
		return
	}

	// Read the environment variables and set in data
	// if the data is not already set and the environment variable is set.
	configureFromEnvironment(&data)

	// Configure aux tenant ids from config and environment. This has to happen
	// before the azidentity environment is exported: otherwise tenants given
	// only through ARM_AUXILIARY_TENANT_IDS would never be copied to
	// AZURE_ADDITIONALLY_ALLOWED_TENANTS.
	resp.Diagnostics.Append(configureAuxTenants(ctx, &data)...)
	if resp.Diagnostics.HasError() {
		return
	}

	// Set the go sdk's azidentity specific environment variables.
	configureAzIdentityEnvironment(&data)

	// Set the default values if not already set in the config or by environment.
	configureDefaults(ctx, &data)

	// Get a token credential.
	cred, diags := getTokenCredential(data)
	resp.Diagnostics.Append(diags...)
	if resp.Diagnostics.HasError() {
		return
	}

	// The same user agent is sent by every client created below.
	userAgent := fmt.Sprintf("%s/%s", userAgentBase, p.version)

	// Create the clients.
	clients, diags := getClients(cred, data, userAgent)
	resp.Diagnostics.Append(diags...)
	if resp.Diagnostics.HasError() {
		return
	}

	// Create the AlzLib.
	alz, diags := configureAlzLib(cred, data, userAgent)
	resp.Diagnostics.Append(diags...)
	if resp.Diagnostics.HasError() {
		return
	}

	// Convert the supplied libraries to alzlib.LibraryReferences.
	libRefs, diags := generateLibraryDefinitions(ctx, &data)
	resp.Diagnostics.Append(diags...)
	if resp.Diagnostics.HasError() {
		return
	}

	// Fetch the library dependencies if enabled.
	// If not, the refs passed to alzlib.Init() will be fetched on demand without dependencies.
	if data.LibraryFetchDependencies.ValueBool() {
		tflog.Debug(ctx, "Begin fetch library dependencies", map[string]interface{}{
			"library_references": libRefs,
		})

		var err error
		libRefs, err = libRefs.FetchWithDependencies(ctx)
		if err != nil {
			resp.Diagnostics.AddError("Failed to fetch library dependencies", err.Error())
			return
		}

		tflog.Debug(ctx, "End fetch library dependencies", map[string]interface{}{
			"library_references": libRefs,
		})
	}

	// Init alzlib.
	if err := alz.Init(ctx, libRefs...); err != nil {
		resp.Diagnostics.AddError("Failed to initialize AlzLib", err.Error())
		return
	}

	// Store the alz pointer in the provider struct so we don't have to do all this work every time `.Configure` is called.
	// Due to fetch from Azure, it takes approx 30 seconds each time and is called 4-5 times during a single acceptance test.
	p.data = &alzProviderData{
		AlzLib:                               alz,
		mu:                                   &sync.Mutex{},
		clients:                              clients,
		suppressWarningPolicyRoleAssignments: data.SuppressWarningPolicyRoleAssignments.ValueBool(),
	}

	resp.DataSourceData = p.data
	resp.ResourceData = p.data

	tflog.Debug(ctx, "Provider configuration finished")
}
// Resources returns the resource constructors offered by the provider; the
// list is currently empty (but non-nil).
func (p *AlzProvider) Resources(ctx context.Context) []func() resource.Resource {
	constructors := make([]func() resource.Resource, 0)
	return constructors
}
// DataSources returns the data source constructors offered by the provider:
// the architecture data source followed by the metadata data source.
func (p *AlzProvider) DataSources(ctx context.Context) []func() datasource.DataSource {
	constructors := make([]func() datasource.DataSource, 0, 2)
	constructors = append(constructors, NewArchitectureDataSource, NewMetadataDataSource)
	return constructors
}
// Functions returns the provider-defined function constructors; the list is
// currently empty (but non-nil).
func (p *AlzProvider) Functions(ctx context.Context) []func() function.Function {
	constructors := make([]func() function.Function, 0)
	return constructors
}
// New returns a factory that creates a fresh, unconfigured AlzProvider carrying
// the given version string each time it is called.
func New(version string) func() provider.Provider {
	factory := func() provider.Provider {
		p := new(AlzProvider)
		p.version = version
		return p
	}
	return factory
}
// generateLibraryDefinitions converts the provider's library_references list
// into alzlib.LibraryReferences.
//
// An entry with a non-null custom URL becomes a custom library reference;
// every other entry becomes an ALZ library reference built from its path and
// ref. Diagnostics are returned only when decoding the list fails.
func generateLibraryDefinitions(ctx context.Context, data *gen.AlzModel) (alzlib.LibraryReferences, diag.Diagnostics) {
	ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
	defer cancel()

	elements := make([]gen.LibraryReferencesValue, len(data.LibraryReferences.Elements()))
	if diags := data.LibraryReferences.ElementsAs(ctx, &elements, false); diags.HasError() {
		return nil, diags
	}

	refs := make(alzlib.LibraryReferences, 0, len(elements))
	for _, element := range elements {
		if !element.CustomUrl.IsNull() {
			refs = append(refs, alzlib.NewCustomLibraryReference(element.CustomUrl.ValueString()))
			continue
		}
		refs = append(refs, alzlib.NewAlzLibraryReference(element.Path.ValueString(), element.Ref.ValueString()))
	}

	return refs, nil
}
// getFirstSetEnvVar returns the value of the first environment variable in
// envVars that has a non-empty value. It returns "" when none of them do, or
// when no names are given.
func getFirstSetEnvVar(envVars ...string) string {
	for i := range envVars {
		val := os.Getenv(envVars[i])
		if val == "" {
			continue
		}
		return val
	}
	return ""
}
// configureAuxTenants fills data.AuxiliaryTenantIds when the provider config
// left it null. The value comes from the semicolon-separated environment
// variable `ARM_AUXILIARY_TENANT_IDS`; when that is empty or unset the list is
// built from a nil slice. A value already present in the config is left alone
// and nil diagnostics are returned.
func configureAuxTenants(ctx context.Context, data *gen.AlzModel) diag.Diagnostics {
	// Explicit configuration wins over the environment.
	if !data.AuxiliaryTenantIds.IsNull() {
		return nil
	}

	var tenantIDs []string
	if raw := os.Getenv("ARM_AUXILIARY_TENANT_IDS"); raw != "" {
		tenantIDs = strings.Split(raw, ";")
	}

	list, diags := types.ListValueFrom(ctx, types.StringType, tenantIDs)
	data.AuxiliaryTenantIds = list
	return diags
}
// configureFromEnvironment fills provider attributes from environment
// variables. An attribute is only written when it is null in data and one of
// its environment variables holds a non-empty value; where several variable
// names are listed, the first non-empty one is used. Boolean attributes are
// parsed with str2Bool.
func configureFromEnvironment(data *gen.AlzModel) {
	// stringFromEnv fills a null string attribute from the given variables.
	stringFromEnv := func(field *types.String, names ...string) {
		if !field.IsNull() {
			return
		}
		if val := getFirstSetEnvVar(names...); val != "" {
			*field = types.StringValue(val)
		}
	}
	// boolFromEnv fills a null bool attribute from the given variables.
	boolFromEnv := func(field *types.Bool, names ...string) {
		if !field.IsNull() {
			return
		}
		if val := getFirstSetEnvVar(names...); val != "" {
			*field = types.BoolValue(str2Bool(val))
		}
	}

	stringFromEnv(&data.ClientCertificatePassword, "ARM_CLIENT_CERTIFICATE_PASSWORD")
	stringFromEnv(&data.ClientCertificatePath, "ARM_CLIENT_CERTIFICATE_PATH")
	stringFromEnv(&data.ClientId, "ARM_CLIENT_ID")
	stringFromEnv(&data.ClientSecret, "ARM_CLIENT_SECRET")
	stringFromEnv(&data.Environment, "ARM_ENVIRONMENT")
	stringFromEnv(&data.OidcRequestToken, "ARM_OIDC_REQUEST_TOKEN", "ACTIONS_ID_TOKEN_REQUEST_TOKEN")
	stringFromEnv(&data.OidcRequestUrl, "ARM_OIDC_REQUEST_URL", "ACTIONS_ID_TOKEN_REQUEST_URL")
	stringFromEnv(&data.OidcToken, "ARM_OIDC_TOKEN")
	stringFromEnv(&data.OidcTokenFilePath, "ARM_OIDC_TOKEN_FILE_PATH")
	stringFromEnv(&data.TenantId, "ARM_TENANT_ID")

	boolFromEnv(&data.UseCli, "ARM_USE_CLI")
	boolFromEnv(&data.UseMsi, "ARM_USE_MSI")
	boolFromEnv(&data.UseOidc, "ARM_USE_OIDC")
	boolFromEnv(&data.SkipProviderRegistration, "ARM_SKIP_PROVIDER_REGISTRATION")
	boolFromEnv(&data.SuppressWarningPolicyRoleAssignments, "ALZ_PROVIDER_SUPPRESS_WARNING_POLICY_ROLE_ASSIGNMENTS")
}
// str2Bool parses val with strconv.ParseBool and returns the result. Any
// string that ParseBool rejects (including the empty string) yields false.
func str2Bool(val string) bool {
	if parsed, err := strconv.ParseBool(val); err == nil {
		return parsed
	}
	return false
}
// configureAzIdentityEnvironment copies provider authentication settings into
// the AZURE_* environment variables of the current process. Each variable is
// only set when the corresponding attribute is non-null; the auxiliary tenant
// IDs are joined with ";" and only set when the list has elements. Errors from
// os.Setenv are ignored.
func configureAzIdentityEnvironment(data *gen.AlzModel) {
	// Provider attribute -> azidentity environment variable.
	mappings := []struct {
		envVar string
		value  types.String
	}{
		{"AZURE_TENANT_ID", data.TenantId},
		{"AZURE_CLIENT_ID", data.ClientId},
		{"AZURE_CLIENT_SECRET", data.ClientSecret},
		{"AZURE_CLIENT_CERTIFICATE_PATH", data.ClientCertificatePath},
		{"AZURE_CLIENT_CERTIFICATE_PASSWORD", data.ClientCertificatePassword},
	}
	for _, m := range mappings {
		if m.value.IsNull() {
			continue
		}
		// #nosec G104
		os.Setenv(m.envVar, m.value.ValueString())
	}

	tenantElements := data.AuxiliaryTenantIds.Elements()
	if len(tenantElements) == 0 {
		return
	}
	joined := strings.Join(listElementsToStrings(tenantElements), ";")
	// #nosec G104
	os.Setenv("AZURE_ADDITIONALLY_ALLOWED_TENANTS", joined)
}
// listElementsToStrings extracts the string value of every element in list.
// It returns nil for an empty list, and also nil as soon as any element is not
// a basetypes.StringValue.
func listElementsToStrings(list []attr.Value) []string {
	if len(list) == 0 {
		return nil
	}

	values := make([]string, 0, len(list))
	for _, element := range list {
		str, isString := element.(basetypes.StringValue)
		if !isString {
			return nil
		}
		values = append(values, str.ValueString())
	}
	return values
}
// configureAlzLib creates the AlzLib used by the provider and attaches an
// Azure Policy client factory to it.
//
// token is the credential used by the policy clients, data supplies the
// library overwrite and provider registration settings, and userAgent is
// added to requests through a per-retry policy. On failure the returned
// AlzLib is nil and diags carries the error.
func configureAlzLib(token *azidentity.ChainedTokenCredential, data gen.AlzModel, userAgent string) (*alzlib.AlzLib, diag.Diagnostics) {
	var diags diag.Diagnostics

	popts := new(policy.ClientOptions)
	popts.DisableRPRegistration = data.SkipProviderRegistration.ValueBool()
	popts.PerRetryPolicies = append(popts.PerRetryPolicies, withUserAgent(userAgent))

	opts := &alzlib.AlzLibOptions{
		AllowOverwrite: data.LibraryOverwriteEnabled.ValueBool(),
		Parallelism:    10,
	}
	alz := alzlib.NewAlzLib(opts)

	cf, err := armpolicy.NewClientFactory("", token, popts)
	if err != nil {
		// AddError takes (summary, detail) and does no formatting, so the
		// error text goes in the detail rather than behind a format verb.
		diags.AddError("Failed to create Azure Policy client factory", err.Error())
		return nil, diags
	}

	alz.AddPolicyClient(cf)
	return alz, diags
}
// getClients creates the Azure SDK clients used by the provider.
//
// token is the credential used by the clients, data supplies the provider
// registration setting, and userAgent is added to requests through a per-retry
// policy. On failure diags carries the error and the returned struct has no
// role assignments client set.
func getClients(token *azidentity.ChainedTokenCredential, data gen.AlzModel, userAgent string) (*AlzProviderClients, diag.Diagnostics) {
	var diags diag.Diagnostics
	clients := new(AlzProviderClients)

	popts := new(policy.ClientOptions)
	popts.DisableRPRegistration = data.SkipProviderRegistration.ValueBool()
	popts.PerRetryPolicies = append(popts.PerRetryPolicies, withUserAgent(userAgent))

	// Create the clients.
	client, err := armauthorization.NewRoleAssignmentsClient("", token, popts)
	if err != nil {
		// AddError takes (summary, detail) and does no formatting, so the
		// error text goes in the detail rather than behind a format verb.
		diags.AddError("Failed to create Azure Role Assignments client", err.Error())
		return clients, diags
	}

	clients.RoleAssignmentsClient = client
	return clients, diags
}
// getTokenCredential builds the chained token credential for the configured
// cloud. The environment name is matched case-insensitively against "public",
// "usgovernment" and "china"; any other value yields an error diagnostic and a
// nil credential. The auxiliary tenant IDs, cloud configuration and tenant ID
// are passed on to newDefaultAzureCredential.
func getTokenCredential(data gen.AlzModel) (*azidentity.ChainedTokenCredential, diag.Diagnostics) {
	supportedClouds := map[string]cloud.Configuration{
		"public":       cloud.AzurePublic,
		"usgovernment": cloud.AzureGovernment,
		"china":        cloud.AzureChina,
	}

	cloudConfig, known := supportedClouds[strings.ToLower(data.Environment.ValueString())]
	if !known {
		var diags diag.Diagnostics
		diags.AddError("Could not determine cloud configuration", "Valid values are 'public', 'usgovernment', or 'china'")
		return nil, diags
	}

	credOptions := new(azidentity.DefaultAzureCredentialOptions)
	credOptions.AdditionallyAllowedTenants = listElementsToStrings(data.AuxiliaryTenantIds.Elements())
	credOptions.ClientOptions = azcore.ClientOptions{Cloud: cloudConfig}
	credOptions.TenantID = data.TenantId.ValueString()

	return newDefaultAzureCredential(data, credOptions)
}
// configureDefaults assigns a default to every provider attribute that is
// still null after the config and environment have been read:
//
//   - environment: "public"
//   - skip provider registration, OIDC auth, MSI auth: false
//   - CLI auth: true
//   - library overwrite: false
//   - fetch library dependencies: true
//   - library references: a single entry using alzLibPath and alzLibRef
//   - suppress policy role assignment warnings: false
func configureDefaults(ctx context.Context, data *gen.AlzModel) {
	// boolDefault sets a boolean attribute only when it is null.
	boolDefault := func(field *types.Bool, value bool) {
		if field.IsNull() {
			*field = types.BoolValue(value)
		}
	}

	if data.Environment.IsNull() {
		data.Environment = types.StringValue("public")
	}

	boolDefault(&data.SkipProviderRegistration, false)
	boolDefault(&data.UseOidc, false)
	boolDefault(&data.UseMsi, false)
	boolDefault(&data.UseCli, true)
	boolDefault(&data.LibraryOverwriteEnabled, false)
	boolDefault(&data.LibraryFetchDependencies, true)

	if data.LibraryReferences.IsNull() {
		attrTypes := gen.NewLibraryReferencesValueNull().AttributeTypes(ctx)
		attrValues := map[string]attr.Value{
			"ref":        types.StringValue(alzLibRef),
			"path":       types.StringValue(alzLibPath),
			"custom_url": types.StringNull(),
		}
		defaultRef := gen.NewLibraryReferencesValueMust(attrTypes, attrValues)
		data.LibraryReferences = types.ListValueMust(defaultRef.Type(ctx), []attr.Value{defaultRef})
	}

	boolDefault(&data.SuppressWarningPolicyRoleAssignments, false)
}
// newDefaultAzureCredential assembles a chained token credential from the
// authentication methods enabled in data. Credentials are tried in this order:
// OIDC (when use_oidc is set), environment, managed identity (when use_msi is
// set), Azure CLI (when use_cli is set).
//
// A credential type that fails to initialise is skipped and recorded as a
// warning. An error is returned when no credential could be initialised or the
// chain cannot be built; in that case the collected warnings are returned with
// it. On success the diagnostics are nil, so the warnings are not surfaced.
//
// A nil options is treated as an empty options struct.
func newDefaultAzureCredential(data gen.AlzModel, options *azidentity.DefaultAzureCredentialOptions) (*azidentity.ChainedTokenCredential, diag.Diagnostics) {
	var (
		creds []azcore.TokenCredential
		diags diag.Diagnostics
	)

	if options == nil {
		options = &azidentity.DefaultAzureCredentialOptions{}
	}

	// AddWarning/AddError take (summary, detail) and do no formatting, so the
	// error text is passed as the detail instead of behind a format verb.

	if data.UseOidc.ValueBool() {
		oidcCred, err := NewOidcCredential(&OidcCredentialOptions{
			ClientOptions: azcore.ClientOptions{
				Cloud: options.Cloud,
			},
			AdditionallyAllowedTenants: options.AdditionallyAllowedTenants,
			TenantID:                   data.TenantId.ValueString(),
			ClientID:                   data.ClientId.ValueString(),
			RequestToken:               data.OidcRequestToken.ValueString(),
			RequestUrl:                 data.OidcRequestUrl.ValueString(),
			Token:                      data.OidcToken.ValueString(),
			TokenFilePath:              data.OidcTokenFilePath.ValueString(),
		})
		if err != nil {
			diags.AddWarning("newDefaultAzureCredential failed to initialize oidc credential", err.Error())
		} else {
			creds = append(creds, oidcCred)
		}
	}

	envCred, err := azidentity.NewEnvironmentCredential(&azidentity.EnvironmentCredentialOptions{
		ClientOptions:            options.ClientOptions,
		DisableInstanceDiscovery: options.DisableInstanceDiscovery,
	})
	if err != nil {
		diags.AddWarning("newDefaultAzureCredential failed to initialize environment credential", err.Error())
	} else {
		creds = append(creds, envCred)
	}

	if data.UseMsi.ValueBool() {
		o := &azidentity.ManagedIdentityCredentialOptions{ClientOptions: options.ClientOptions}
		// Use the identity named by AZURE_CLIENT_ID whenever the variable is present.
		if id, ok := os.LookupEnv("AZURE_CLIENT_ID"); ok {
			o.ID = azidentity.ClientID(id)
		}
		miCred, err := newManagedIdentityCredential(o)
		if err != nil {
			diags.AddWarning("newDefaultAzureCredential failed to initialize msi credential", err.Error())
		} else {
			creds = append(creds, miCred)
		}
	}

	if data.UseCli.ValueBool() {
		cliCred, err := azidentity.NewAzureCLICredential(&azidentity.AzureCLICredentialOptions{
			AdditionallyAllowedTenants: options.AdditionallyAllowedTenants,
			TenantID:                   options.TenantID,
		})
		if err != nil {
			diags.AddWarning("newDefaultAzureCredential failed to initialize cli credential", err.Error())
		} else {
			creds = append(creds, cliCred)
		}
	}

	if len(creds) == 0 {
		diags.AddError("newDefaultAzureCredential failed to initialize any credential", "None of the credentials were initialized")
		return nil, diags
	}

	chain, err := azidentity.NewChainedTokenCredential(creds, nil)
	if err != nil {
		diags.AddError("newDefaultAzureCredential failed to initialize chained credential", err.Error())
		return nil, diags
	}

	return chain, nil
}