internal/loader/configuration_client_manager.go (457 lines of code) (raw):

// Copyright (c) Microsoft Corporation. // Licensed under the MIT License. package loader import ( "azappconfig/provider/internal/properties" "context" "fmt" "math" "math/rand" "net" "net/url" "os" "strconv" "strings" "time" acpv1 "azappconfig/provider/api/v1" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig" "github.com/google/uuid" authv1 "k8s.io/api/authentication/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/kubernetes" "k8s.io/klog/v2" "sigs.k8s.io/controller-runtime/pkg/client" ctrlcfg "sigs.k8s.io/controller-runtime/pkg/client/config" ) //go:generate mockgen -destination=mocks/mock_configuration_client_manager.go -package mocks . ClientManager type ConfigurationClientManager struct { ReplicaDiscoveryEnabled bool LoadBalancingEnabled bool StaticClientWrappers []*ConfigurationClientWrapper DynamicClientWrappers []*ConfigurationClientWrapper validDomain string endpoint string credential azcore.TokenCredential secret string id string lastFallbackClientAttempt metav1.Time lastFallbackClientRefresh metav1.Time lastSuccessfulEndpoint string } type ConfigurationClientWrapper struct { Endpoint string Client *azappconfig.Client BackOffEndTime metav1.Time FailedAttempts int } type ClientManager interface { GetClients(ctx context.Context) ([]*ConfigurationClientWrapper, error) RefreshClients(ctx context.Context) } const ( TCP string = "tcp" Origin string = "origin" Alt string = "alt" EndpointSection string = "Endpoint" SecretSection string = "Secret" IdSection string = "Id" AzConfigDomainLabel string = ".azconfig." AppConfigDomainLabel string = ".appconfig." FallbackClientRefreshExpireInterval time.Duration = time.Hour MinimalClientRefreshInterval time.Duration = time.Second * 30 MaxBackoffDuration time.Duration = time.Minute * 10 MinBackoffDuration time.Duration = time.Second * 30 JitterRatio float64 = 0.25 SafeShiftLimit int = 63 ApiTokenExchangeAudience string = "api://AzureADTokenExchange" AnnotationClientID string = "azure.workload.identity/client-id" AnnotationTenantID string = "azure.workload.identity/tenant-id" ) var ( clientOptionWithModuleInfo *azappconfig.ClientOptions = &azappconfig.ClientOptions{ ClientOptions: policy.ClientOptions{ Telemetry: policy.TelemetryOptions{ ApplicationID: fmt.Sprintf("%s/%s", properties.ModuleName, properties.ModuleVersion), }, }, } ) func NewConfigurationClientManager(ctx context.Context, provider acpv1.AzureAppConfigurationProvider) (ClientManager, error) { manager := &ConfigurationClientManager{ ReplicaDiscoveryEnabled: provider.Spec.ReplicaDiscoveryEnabled, LoadBalancingEnabled: provider.Spec.LoadBalancingEnabled, lastSuccessfulEndpoint: "", } var err error var staticClient *azappconfig.Client if provider.Spec.ConnectionStringReference != nil { connectionString, err := getConnectionStringParameter(ctx, types.NamespacedName{Namespace: provider.Namespace, Name: *provider.Spec.ConnectionStringReference}) if err != nil { return nil, fmt.Errorf("fail to retrieve connection string from secret '%s': %s", *provider.Spec.ConnectionStringReference, err.Error()) } if manager.endpoint, err = parseConnectionString(connectionString, EndpointSection); err != nil { return nil, err } if err = verifyEndpointFromConnectionString(manager.endpoint); err != nil { return nil, err } if manager.secret, err = parseConnectionString(connectionString, SecretSection); err != nil { return nil, err } if manager.id, err = parseConnectionString(connectionString, IdSection); err != nil { return nil, err } if staticClient, err = azappconfig.NewClientFromConnectionString(connectionString, clientOptionWithModuleInfo); err != nil { return nil, err } } else { if manager.credential, err = CreateTokenCredential(ctx, provider.Spec.Auth, provider.Namespace); err != nil { return nil, err } if staticClient, err = azappconfig.NewClient(*provider.Spec.Endpoint, manager.credential, clientOptionWithModuleInfo); err != nil { return nil, err } manager.endpoint = *provider.Spec.Endpoint } manager.validDomain = getValidDomain(manager.endpoint) manager.StaticClientWrappers = []*ConfigurationClientWrapper{{ Endpoint: manager.endpoint, Client: staticClient, BackOffEndTime: metav1.Time{}, FailedAttempts: 0, }} return manager, nil } func (manager *ConfigurationClientManager) GetClients(ctx context.Context) ([]*ConfigurationClientWrapper, error) { currentTime := metav1.Now() clients := make([]*ConfigurationClientWrapper, 0) for _, clientWrapper := range manager.StaticClientWrappers { if currentTime.After(clientWrapper.BackOffEndTime.Time) { clients = append(clients, clientWrapper) } } if !manager.ReplicaDiscoveryEnabled { return clients, nil } if currentTime.After(manager.lastFallbackClientAttempt.Time.Add(MinimalClientRefreshInterval)) && (manager.DynamicClientWrappers == nil || currentTime.After(manager.lastFallbackClientRefresh.Time.Add(FallbackClientRefreshExpireInterval))) { manager.lastFallbackClientAttempt = currentTime url, _ := url.Parse(manager.endpoint) manager.DiscoverFallbackClients(ctx, url.Host) } for _, clientWrapper := range manager.DynamicClientWrappers { if currentTime.After(clientWrapper.BackOffEndTime.Time) { clients = append(clients, clientWrapper) } } return clients, nil } func (manager *ConfigurationClientManager) RefreshClients(ctx context.Context) { currentTime := metav1.Now() if manager.ReplicaDiscoveryEnabled && currentTime.After(manager.lastFallbackClientAttempt.Time.Add(MinimalClientRefreshInterval)) { manager.lastFallbackClientAttempt = currentTime url, _ := url.Parse(manager.endpoint) manager.DiscoverFallbackClients(ctx, url.Host) } } func (manager *ConfigurationClientManager) DiscoverFallbackClients(ctx context.Context, host string) { newCtx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() resultChan := make(chan []string) errChan := make(chan error) go func() { srvTargetHosts, err := QuerySrvTargetHost(newCtx, host) if err != nil { errChan <- err } else { resultChan <- srvTargetHosts } close(resultChan) close(errChan) }() select { case <-newCtx.Done(): klog.Warningf("fail to build fallback clients, SRV DNS lookup is timeout") break case err := <-errChan: klog.Warningf("fail to build fallback clients %s", err.Error()) break case srvTargetHosts := <-resultChan: // Shuffle the list of SRV target hosts for i := range srvTargetHosts { j := rand.Intn(i + 1) srvTargetHosts[i], srvTargetHosts[j] = srvTargetHosts[j], srvTargetHosts[i] } newDynamicClients := make([]*ConfigurationClientWrapper, 0) for _, host := range srvTargetHosts { if isValidEndpoint(host, manager.validDomain) { targetEndpoint := "https://" + host if strings.EqualFold(targetEndpoint, manager.endpoint) { continue } client, err := manager.newConfigurationClient(targetEndpoint) if err != nil { klog.Warningf("build fallback clients failed, %s", err.Error()) return } newDynamicClients = append(newDynamicClients, &ConfigurationClientWrapper{ Endpoint: targetEndpoint, Client: client, BackOffEndTime: metav1.Time{}, FailedAttempts: 0, }) } } manager.DynamicClientWrappers = newDynamicClients manager.lastFallbackClientRefresh = metav1.Now() break } } func QuerySrvTargetHost(ctx context.Context, host string) ([]string, error) { results := make([]string, 0) _, originRecords, err := net.DefaultResolver.LookupSRV(ctx, Origin, TCP, host) if err != nil { // If the host does not have SRV records => no replicas if dnsErr, ok := err.(*net.DNSError); ok && dnsErr.IsNotFound { return results, nil } else { return results, err } } if len(originRecords) == 0 { return results, nil } originHost := strings.TrimSuffix(originRecords[0].Target, ".") results = append(results, originHost) index := 0 for { currentAlt := Alt + strconv.Itoa(index) _, altRecords, err := net.DefaultResolver.LookupSRV(ctx, currentAlt, TCP, originHost) if err != nil { // If the host does not have SRV records => no more replicas if dnsErr, ok := err.(*net.DNSError); ok && dnsErr.IsNotFound { break } else { return results, err } } for _, record := range altRecords { altHost := strings.TrimSuffix(record.Target, ".") if altHost != "" { results = append(results, altHost) } } index = index + 1 } return results, nil } func (manager *ConfigurationClientManager) newConfigurationClient(endpoint string) (*azappconfig.Client, error) { if manager.credential != nil { return azappconfig.NewClient(endpoint, manager.credential, clientOptionWithModuleInfo) } connectionStr := buildConnectionString(endpoint, manager.secret, manager.id) if connectionStr == "" { return nil, fmt.Errorf("failed to build connection string for fallback client") } return azappconfig.NewClientFromConnectionString(connectionStr, clientOptionWithModuleInfo) } func isValidEndpoint(host string, validDomain string) bool { if validDomain == "" { return false } return strings.HasSuffix(strings.ToLower(host), strings.ToLower(validDomain)) } func getValidDomain(endpoint string) string { url, _ := url.Parse(endpoint) TrustedDomainLabels := []string{AzConfigDomainLabel, AppConfigDomainLabel} for _, label := range TrustedDomainLabels { index := strings.LastIndex(strings.ToLower(url.Host), strings.ToLower(label)) if index != -1 { return url.Host[index:] } } return "" } func buildConnectionString(endpoint string, secret string, id string) string { if secret == "" || id == "" { return "" } return fmt.Sprintf("%s=%s;%s=%s;%s=%s", EndpointSection, endpoint, IdSection, id, SecretSection, secret) } func parseConnectionString(connectionString string, token string) (string, error) { if connectionString == "" { return "", fmt.Errorf("connectionString is empty") } parseToken := token + "=" startIndex := strings.Index(connectionString, parseToken) if startIndex < 0 { return "", fmt.Errorf("invalid connectionString %s", connectionString) } endIndex := strings.Index(connectionString[startIndex:], ";") if endIndex < 0 { endIndex = len(connectionString) } else { endIndex += startIndex } return connectionString[startIndex+len(parseToken) : endIndex], nil } func verifyEndpointFromConnectionString(endpoint string) error { url, err := url.Parse(strings.ToLower(endpoint)) if err != nil { return fmt.Errorf("invalid endpoint %q from connectionString", endpoint) } if url.Host == "" { return fmt.Errorf("invalid endpoint %q from connectionString, host must be specified", endpoint) } if url.Scheme != "https" { return fmt.Errorf("invalid endpoint %q from connectionString, only https scheme is allowed", endpoint) } if strings.Trim(url.Path, "/") != "" { return fmt.Errorf("invalid endpoint %q from connectionString, only host name is allowed", endpoint) } return nil } func CreateTokenCredential(ctx context.Context, acpAuth *acpv1.AzureAppConfigurationProviderAuth, namespace string) (azcore.TokenCredential, error) { // If User explicitly specify the authentication method if acpAuth != nil { if acpAuth.WorkloadIdentity != nil { if acpAuth.WorkloadIdentity.ServiceAccountName != nil { return newClientAssertionCredential(ctx, *acpAuth.WorkloadIdentity.ServiceAccountName, namespace) } workloadIdentityClientId, err := getWorkloadIdentityClientId(ctx, acpAuth.WorkloadIdentity, namespace) if err != nil { return nil, fmt.Errorf("fail to retrieve workload identity client ID from configMap '%s' : %s", acpAuth.WorkloadIdentity.ManagedIdentityClientIdReference.ConfigMap, err.Error()) } return azidentity.NewWorkloadIdentityCredential(&azidentity.WorkloadIdentityCredentialOptions{ ClientID: workloadIdentityClientId, }) } if acpAuth.ServicePrincipalReference != nil { parameter, err := getServicePrincipleAuthenticationParameters(ctx, types.NamespacedName{Namespace: namespace, Name: *acpAuth.ServicePrincipalReference}) if err != nil { return nil, fmt.Errorf("fail to retrieve service principal secret from '%s': %s", *acpAuth.ServicePrincipalReference, err.Error()) } return azidentity.NewClientSecretCredential(parameter.TenantId, parameter.ClientId, parameter.ClientSecret, nil) } if acpAuth.ManagedIdentityClientId != nil { return azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ ID: azidentity.ClientID(*acpAuth.ManagedIdentityClientId), }) } } else { return azidentity.NewManagedIdentityCredential(nil) } return nil, nil } func getWorkloadIdentityClientId(ctx context.Context, workloadIdentityAuth *acpv1.WorkloadIdentityParameters, namespace string) (string, error) { if workloadIdentityAuth.ManagedIdentityClientIdReference == nil { return *workloadIdentityAuth.ManagedIdentityClientId, nil } else { configMap, err := getConfigMap(ctx, types.NamespacedName{Namespace: namespace, Name: workloadIdentityAuth.ManagedIdentityClientIdReference.ConfigMap}) if err != nil { return "", err } if _, ok := configMap.Data[workloadIdentityAuth.ManagedIdentityClientIdReference.Key]; !ok { return "", fmt.Errorf("key '%s' does not exist", workloadIdentityAuth.ManagedIdentityClientIdReference.Key) } managedIdentityClientId := configMap.Data[workloadIdentityAuth.ManagedIdentityClientIdReference.Key] if _, err = uuid.Parse(managedIdentityClientId); err != nil { return "", fmt.Errorf("managedIdentityClientId %q is not a valid uuid", managedIdentityClientId) } return managedIdentityClientId, nil } } func getConnectionStringParameter(ctx context.Context, namespacedSecretName types.NamespacedName) (string, error) { secret, err := GetSecret(ctx, namespacedSecretName) if err != nil { return "", err } if _, ok := secret.Data[AzureAppConfigurationConnectionString]; !ok { return "", fmt.Errorf("key '%s' does not exist", AzureAppConfigurationConnectionString) } return string(secret.Data[AzureAppConfigurationConnectionString]), nil } func getServicePrincipleAuthenticationParameters(ctx context.Context, namespacedSecretName types.NamespacedName) (*ServicePrincipleAuthenticationParameters, error) { secret, err := GetSecret(ctx, namespacedSecretName) if err != nil { return nil, err } return &ServicePrincipleAuthenticationParameters{ ClientId: string(secret.Data[AzureClientId]), ClientSecret: string(secret.Data[AzureClientSecret]), TenantId: string(secret.Data[AzureTenantId]), }, nil } func calculateBackoffDuration(failedAttempts int) time.Duration { if failedAttempts <= 1 { return MinBackoffDuration } calculatedMilliseconds := math.Max(1, float64(MinBackoffDuration.Milliseconds())) * math.Pow(2, math.Min(float64(failedAttempts-1), float64(SafeShiftLimit))) if calculatedMilliseconds > float64(MaxBackoffDuration.Milliseconds()) || calculatedMilliseconds <= 0 { calculatedMilliseconds = float64(MaxBackoffDuration.Milliseconds()) } calculatedDuration := time.Duration(calculatedMilliseconds) * time.Millisecond return Jitter(calculatedDuration) } func Jitter(duration time.Duration) time.Duration { // Calculate the amount of jitter to add to the duration jitter := float64(duration) * JitterRatio // Generate a random number between -jitter and +jitter randomJitter := rand.Float64()*(2*jitter) - jitter // Apply the random jitter to the original duration return duration + time.Duration(randomJitter) } func newClientAssertionCredential(ctx context.Context, serviceAccountName string, serviceAccountNamespace string) (azcore.TokenCredential, error) { cfg, err := ctrlcfg.GetConfig() if err != nil { return nil, err } client, err := client.New(cfg, client.Options{}) if err != nil { return nil, err } serviceAccountObj := &corev1.ServiceAccount{} err = client.Get(ctx, types.NamespacedName{Namespace: serviceAccountNamespace, Name: serviceAccountName}, serviceAccountObj) if err != nil { return nil, err } if _, ok := serviceAccountObj.Annotations[AnnotationClientID]; !ok { return nil, fmt.Errorf("annotation '%s' of service account %s/%s is required", AnnotationClientID, serviceAccountNamespace, serviceAccountName) } tenantId := "" if _, ok := serviceAccountObj.Annotations[AnnotationTenantID]; ok { tenantId = serviceAccountObj.Annotations[AnnotationTenantID] } else if _, ok := os.LookupEnv(strings.ToUpper(AzureTenantId)); ok { tenantId = os.Getenv(strings.ToUpper(AzureTenantId)) } else { return nil, fmt.Errorf("annotation '%s' of service account %s/%s is required since using global service account for workload identity is disabled", AnnotationTenantID, serviceAccountNamespace, serviceAccountName) } getAssertionFunc := newGetAssertionFunc(serviceAccountNamespace, serviceAccountName) clientAssertionCredential, err := azidentity.NewClientAssertionCredential(tenantId, serviceAccountObj.Annotations[AnnotationClientID], getAssertionFunc, nil) if err != nil { return nil, err } return clientAssertionCredential, nil } func newGetAssertionFunc(serviceAccountNamespace string, serviceAccountName string) func(ctx context.Context) (string, error) { audiences := []string{ApiTokenExchangeAudience} return func(ctx context.Context) (string, error) { cfg, err := ctrlcfg.GetConfig() if err != nil { return "", err } kubeClient, err := kubernetes.NewForConfig(cfg) if err != nil { return "", err } token, err := kubeClient.CoreV1().ServiceAccounts(serviceAccountNamespace).CreateToken(ctx, serviceAccountName, &authv1.TokenRequest{ Spec: authv1.TokenRequestSpec{ Audiences: audiences, }, }, metav1.CreateOptions{}) if err != nil { return "", err } return token.Status.Token, nil } }