azext_edge/edge/util/az_client.py (203 lines of code) (raw):

# coding=utf-8
# ----------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License file in the project root for license information.
# ----------------------------------------------------------------------------------------------
"""Factories for the vendored Azure SDK clients plus small polling / resource-id helpers."""

import sys
from time import sleep
from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional, Tuple, TypeVar

from azure.cli.core.azclierror import ValidationError
from knack.log import get_logger

from ...constants import USER_AGENT
from .common import ensure_azure_namespace_path

if sys.version_info >= (3, 9):
    from collections.abc import MutableMapping
else:
    from typing import MutableMapping

JSON = MutableMapping[str, Any]  # pylint: disable=unsubscriptable-object

# NOTE(review): presumably this must run before the azure.core / azure.identity imports
# below so that they resolve -- keep the call ahead of them. Confirm against .common.
ensure_azure_namespace_path()

from azure.core.pipeline.policies import HttpLoggingPolicy, UserAgentPolicy  # noqa: E402
from azure.identity import AzureCliCredential  # noqa: E402

# Single credential instance shared by every client built in this module.
AZURE_CLI_CREDENTIAL = AzureCliCredential()

# Default polling budget used by the wait_for_terminal_state(s) helpers.
POLL_RETRIES = 240
POLL_WAIT_SEC = 15

logger = get_logger(__name__)

_ClientT = TypeVar("_ClientT")


if TYPE_CHECKING:
    from azure.core.polling import LROPoller

    from ..vendor.clients.authzmgmt import AuthorizationManagementClient
    from ..vendor.clients.clusterconfigmgmt import KubernetesConfigurationClient
    from ..vendor.clients.connectedclustermgmt import ConnectedKubernetesClient
    from ..vendor.clients.deviceregistrymgmt import (
        MicrosoftDeviceRegistryManagementService,
    )
    from ..vendor.clients.iotopsmgmt import MicrosoftIoTOperationsManagementService
    from ..vendor.clients.resourcesmgmt import ResourceManagementClient
    from ..vendor.clients.storagemgmt import StorageManagementClient
    from ..vendor.clients.msimgmt import ManagedServiceIdentityClient
    from ..vendor.clients.secretsyncmgmt import MicrosoftSecretSyncController
    from ..vendor.clients.keyvault import KeyVaultClient
    from ..vendor.clients.extendedlocmgmt import CustomLocations


# TODO @digimaun - simplify client init pattern. Consider multi-profile vs static API client.


def _build_mgmt_client(client_cls: Callable[..., _ClientT], subscription_id: str, **kwargs) -> _ClientT:
    """Instantiate `client_cls` with the wiring shared by the management client factories.

    Supplies the module-wide CLI credential, a user-agent policy built from USER_AGENT and,
    unless the caller passed its own `http_logging_policy`, the default logging policy.
    Remaining keyword arguments are forwarded to the client constructor untouched.
    """
    # Only build the default policy when the caller did not provide one.
    if "http_logging_policy" not in kwargs:
        kwargs["http_logging_policy"] = get_default_logging_policy()

    return client_cls(
        credential=AZURE_CLI_CREDENTIAL,
        subscription_id=subscription_id,
        user_agent_policy=UserAgentPolicy(user_agent=USER_AGENT),
        **kwargs,
    )


def get_extloc_mgmt_client(subscription_id: str, **kwargs) -> "CustomLocations":
    """Return a CustomLocations client for `subscription_id`."""
    # Imported lazily, as in every factory below, so the vendored client loads only on use.
    from ..vendor.clients.extendedlocmgmt import CustomLocations

    return _build_mgmt_client(CustomLocations, subscription_id, **kwargs)


def get_ssc_mgmt_client(subscription_id: str, **kwargs) -> "MicrosoftSecretSyncController":
    """Return a MicrosoftSecretSyncController client for `subscription_id`."""
    from ..vendor.clients.secretsyncmgmt import MicrosoftSecretSyncController

    return _build_mgmt_client(MicrosoftSecretSyncController, subscription_id, **kwargs)


def get_msi_mgmt_client(subscription_id: str, **kwargs) -> "ManagedServiceIdentityClient":
    """Return a ManagedServiceIdentityClient for `subscription_id`."""
    from ..vendor.clients.msimgmt import ManagedServiceIdentityClient

    return _build_mgmt_client(ManagedServiceIdentityClient, subscription_id, **kwargs)


def get_clusterconfig_mgmt_client(subscription_id: str, **kwargs) -> "KubernetesConfigurationClient":
    """Return a KubernetesConfigurationClient for `subscription_id`."""
    from ..vendor.clients.clusterconfigmgmt import KubernetesConfigurationClient

    return _build_mgmt_client(KubernetesConfigurationClient, subscription_id, **kwargs)


def get_connectedk8s_mgmt_client(subscription_id: str, **kwargs) -> "ConnectedKubernetesClient":
    """Return a ConnectedKubernetesClient for `subscription_id`."""
    from ..vendor.clients.connectedclustermgmt import ConnectedKubernetesClient

    return _build_mgmt_client(ConnectedKubernetesClient, subscription_id, **kwargs)


def get_storage_mgmt_client(subscription_id: str, **kwargs) -> "StorageManagementClient":
    """Return a StorageManagementClient for `subscription_id`."""
    from ..vendor.clients.storagemgmt import StorageManagementClient

    return _build_mgmt_client(StorageManagementClient, subscription_id, **kwargs)


REGISTRY_PREVIEW_API_VERSION = "2024-09-01-preview"
REGISTRY_API_VERSION = "2024-11-01"


def get_registry_mgmt_client(subscription_id: str, **kwargs) -> "MicrosoftDeviceRegistryManagementService":
    """Return a MicrosoftDeviceRegistryManagementService client for `subscription_id`."""
    from ..vendor.clients.deviceregistrymgmt import (
        MicrosoftDeviceRegistryManagementService,
    )

    return _build_mgmt_client(MicrosoftDeviceRegistryManagementService, subscription_id, **kwargs)


def get_iotops_mgmt_client(subscription_id: str, **kwargs) -> "MicrosoftIoTOperationsManagementService":
    """Return a MicrosoftIoTOperationsManagementService client for `subscription_id`."""
    from ..vendor.clients.iotopsmgmt import MicrosoftIoTOperationsManagementService

    return _build_mgmt_client(MicrosoftIoTOperationsManagementService, subscription_id, **kwargs)


def get_resource_client(subscription_id: str, **kwargs) -> "ResourceManagementClient":
    """Return a ResourceManagementClient for `subscription_id`."""
    from ..vendor.clients.resourcesmgmt import ResourceManagementClient

    return _build_mgmt_client(ResourceManagementClient, subscription_id, **kwargs)


def get_authz_client(subscription_id: str, **kwargs) -> "AuthorizationManagementClient":
    """Return an AuthorizationManagementClient for `subscription_id`."""
    from ..vendor.clients.authzmgmt import AuthorizationManagementClient

    return _build_mgmt_client(AuthorizationManagementClient, subscription_id, **kwargs)


def get_keyvault_client(subscription_id: str, **kwargs) -> "KeyVaultClient":
    """Return a KeyVaultClient for `subscription_id`.

    Unlike the other factories this one does not inject a default `http_logging_policy`
    and pins `credential_scopes` to the public-cloud Key Vault scope.
    """
    from ..vendor.clients.keyvault import KeyVaultClient

    # TODO: this only supports azure public cloud for now
    client = KeyVaultClient(
        credential=AZURE_CLI_CREDENTIAL,
        subscription_id=subscription_id,
        user_agent_policy=UserAgentPolicy(user_agent=USER_AGENT),
        credential_scopes=["https://vault.azure.net/.default"],
        **kwargs,
    )
    return client


def wait_for_terminal_state(poller: "LROPoller", wait_sec: int = POLL_WAIT_SEC, **_) -> JSON:
    """Sleep-poll `poller` until it reports done, then return `poller.result()`.

    The poller is checked after each `wait_sec` sleep, at most POLL_RETRIES times.
    `poller.result()` is called whether or not the poller finished within that budget.
    Extra keyword arguments are accepted and ignored.
    """
    # resource client does not handle sigint well
    counter = 0
    while counter < POLL_RETRIES:
        sleep(wait_sec)
        counter = counter + 1
        if poller.done():
            break
    return poller.result()


def wait_for_terminal_states(
    *pollers: "LROPoller", retries: int = POLL_RETRIES, wait_sec: int = POLL_WAIT_SEC, **_
) -> Tuple["LROPoller", ...]:
    """Sleep-poll until every poller reports done or `retries` is exhausted.

    Returns the same pollers that were passed in; results are not fetched here.
    Extra keyword arguments are accepted and ignored.
    """
    counter = 0
    while counter < retries:
        sleep(wait_sec)
        counter = counter + 1
        batch_done = all(poller.done() for poller in pollers)
        if batch_done:
            break

    return pollers


def get_tenant_id() -> str:
    """Return the `tenantId` of the subscription reported by the Azure CLI profile."""
    from azure.cli.core._profile import Profile

    profile = Profile()
    sub = profile.get_subscription()
    return sub["tenantId"]


def get_default_logging_policy() -> HttpLoggingPolicy:
    """Build the HttpLoggingPolicy used by default for the management clients.

    Adds `api-version`, `$filter` and `$expand` to the allowed query parameters and
    `x-ms-correlation-request-id` to the allowed header names of a policy bound to
    this module's logger.
    """
    http_logging_policy = HttpLoggingPolicy(logger=logger)
    http_logging_policy.allowed_query_params.add("api-version")
    http_logging_policy.allowed_query_params.add("$filter")
    http_logging_policy.allowed_query_params.add("$expand")
    http_logging_policy.allowed_header_names.add("x-ms-correlation-request-id")

    return http_logging_policy


class ResourceIdContainer(NamedTuple):
    """Pieces extracted from a resource Id by `parse_resource_id`."""

    subscription_id: str  # segment following "subscriptions"
    resource_group_name: str  # segment following "resourceGroups"
    resource_name: str  # last segment of the Id
    resource_id: str  # the full, unmodified Id


def parse_resource_id(resource_id: str) -> Optional[ResourceIdContainer]:
    """Split an Azure resource Id into subscription, resource group and resource name.

    A falsy `resource_id` (None or empty string) is returned unchanged.

    Raises:
        ValidationError: if the Id has fewer than nine "/"-separated segments
            (the leading "/" contributes an empty first segment).
    """
    if not resource_id:
        return resource_id  # TODO - cheap.

    parts = resource_id.split("/")
    if len(parts) < 9:
        raise ValidationError(
            f"Malformed resource Id '{resource_id}'. An Azure resource Id has the form:\n"
            "/subscriptions/{subscriptionId}/resourceGroups/{resourceGroup}"
            "/providers/Microsoft.Provider/{resourcePath}/{resourceName}"
        )

    # Extract the subscription, resource group, and resource name
    subscription_id = parts[2]
    resource_group_name = parts[4]
    resource_name = parts[-1]

    return ResourceIdContainer(
        subscription_id=subscription_id,
        resource_group_name=resource_group_name,
        resource_name=resource_name,
        resource_id=resource_id,
    )