azext_edge/edge/util/az_client.py (203 lines of code) (raw):
# coding=utf-8
# ----------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License file in the project root for license information.
# ----------------------------------------------------------------------------------------------
import sys
from time import sleep
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple
from azure.cli.core.azclierror import ValidationError
from knack.log import get_logger
from ...constants import USER_AGENT
from .common import ensure_azure_namespace_path
# collections.abc.MutableMapping only supports subscripting (PEP 585) from Python 3.9;
# older interpreters fall back to the typing alias.
if sys.version_info >= (3, 9):
    from collections.abc import MutableMapping
else:
    from typing import MutableMapping

# Loose type for decoded JSON objects (e.g. the return type of wait_for_terminal_state).
JSON = MutableMapping[str, Any]  # pylint: disable=unsubscriptable-object

# Deliberately called before the azure.* imports below. NOTE(review): presumably it fixes up
# the import path so the azure namespace packages resolve - confirm in .common.
ensure_azure_namespace_path()

from azure.core.pipeline.policies import HttpLoggingPolicy, UserAgentPolicy
from azure.identity import AzureCliCredential

# One credential instance, created at import time and passed to every client factory in this module.
AZURE_CLI_CREDENTIAL = AzureCliCredential()
# Default polling budget for the wait_for_terminal_state(s) helpers:
# up to 240 sleeps of 15 seconds each (60 minutes).
POLL_RETRIES = 240
POLL_WAIT_SEC = 15
# Module logger; also handed to the HTTP logging policy built by get_default_logging_policy.
logger = get_logger(__name__)
# Imported for static type checking only (TYPE_CHECKING is False at runtime). The client
# factories below import their vendored client classes locally at call time instead.
if TYPE_CHECKING:
    from azure.core.polling import LROPoller

    from ..vendor.clients.authzmgmt import AuthorizationManagementClient
    from ..vendor.clients.clusterconfigmgmt import KubernetesConfigurationClient
    from ..vendor.clients.connectedclustermgmt import ConnectedKubernetesClient
    from ..vendor.clients.deviceregistrymgmt import (
        MicrosoftDeviceRegistryManagementService,
    )
    from ..vendor.clients.iotopsmgmt import MicrosoftIoTOperationsManagementService
    from ..vendor.clients.resourcesmgmt import ResourceManagementClient
    from ..vendor.clients.storagemgmt import StorageManagementClient
    from ..vendor.clients.msimgmt import ManagedServiceIdentityClient
    from ..vendor.clients.secretsyncmgmt import MicrosoftSecretSyncController
    from ..vendor.clients.keyvault import KeyVaultClient
    from ..vendor.clients.extendedlocmgmt import CustomLocations
# TODO @digimaun - simplify client init pattern. Consider multi-profile vs static API client.
def _build_mgmt_client(client_cls: Any, subscription_id: str, client_kwargs: dict) -> Any:
    """
    Instantiate a vendored management client with this module's shared defaults.

    The client is authenticated with AZURE_CLI_CREDENTIAL and sends the extension USER_AGENT.
    If the caller did not supply ``http_logging_policy``, get_default_logging_policy() is used.

    :param client_cls: The vendored client class to instantiate.
    :param subscription_id: Target Azure subscription id.
    :param client_kwargs: The factory's own **kwargs dict, forwarded verbatim to the client.
        It is taken as a plain dict (not **kwargs) so caller keys can never collide with
        this helper's parameter names.
    :return: The constructed client instance.
    """
    if "http_logging_policy" not in client_kwargs:
        client_kwargs["http_logging_policy"] = get_default_logging_policy()
    return client_cls(
        credential=AZURE_CLI_CREDENTIAL,
        subscription_id=subscription_id,
        user_agent_policy=UserAgentPolicy(user_agent=USER_AGENT),
        **client_kwargs,
    )


def get_extloc_mgmt_client(subscription_id: str, **kwargs) -> "CustomLocations":
    """Return a CustomLocations (extended location) client for the given subscription."""
    from ..vendor.clients.extendedlocmgmt import CustomLocations

    return _build_mgmt_client(CustomLocations, subscription_id, kwargs)


def get_ssc_mgmt_client(subscription_id: str, **kwargs) -> "MicrosoftSecretSyncController":
    """Return a MicrosoftSecretSyncController client for the given subscription."""
    from ..vendor.clients.secretsyncmgmt import MicrosoftSecretSyncController

    return _build_mgmt_client(MicrosoftSecretSyncController, subscription_id, kwargs)


def get_msi_mgmt_client(subscription_id: str, **kwargs) -> "ManagedServiceIdentityClient":
    """Return a ManagedServiceIdentityClient for the given subscription."""
    from ..vendor.clients.msimgmt import ManagedServiceIdentityClient

    return _build_mgmt_client(ManagedServiceIdentityClient, subscription_id, kwargs)


def get_clusterconfig_mgmt_client(subscription_id: str, **kwargs) -> "KubernetesConfigurationClient":
    """Return a KubernetesConfigurationClient for the given subscription."""
    from ..vendor.clients.clusterconfigmgmt import KubernetesConfigurationClient

    return _build_mgmt_client(KubernetesConfigurationClient, subscription_id, kwargs)


def get_connectedk8s_mgmt_client(subscription_id: str, **kwargs) -> "ConnectedKubernetesClient":
    """Return a ConnectedKubernetesClient for the given subscription."""
    from ..vendor.clients.connectedclustermgmt import ConnectedKubernetesClient

    return _build_mgmt_client(ConnectedKubernetesClient, subscription_id, kwargs)


def get_storage_mgmt_client(subscription_id: str, **kwargs) -> "StorageManagementClient":
    """Return a StorageManagementClient for the given subscription."""
    from ..vendor.clients.storagemgmt import StorageManagementClient

    return _build_mgmt_client(StorageManagementClient, subscription_id, kwargs)


# Device Registry API version strings (preview and GA). Not used by the code visible in
# this module; presumably consumed by callers.
REGISTRY_PREVIEW_API_VERSION = "2024-09-01-preview"
REGISTRY_API_VERSION = "2024-11-01"


def get_registry_mgmt_client(subscription_id: str, **kwargs) -> "MicrosoftDeviceRegistryManagementService":
    """Return a MicrosoftDeviceRegistryManagementService client for the given subscription."""
    from ..vendor.clients.deviceregistrymgmt import (
        MicrosoftDeviceRegistryManagementService,
    )

    return _build_mgmt_client(MicrosoftDeviceRegistryManagementService, subscription_id, kwargs)


def get_iotops_mgmt_client(subscription_id: str, **kwargs) -> "MicrosoftIoTOperationsManagementService":
    """Return a MicrosoftIoTOperationsManagementService client for the given subscription."""
    from ..vendor.clients.iotopsmgmt import MicrosoftIoTOperationsManagementService

    return _build_mgmt_client(MicrosoftIoTOperationsManagementService, subscription_id, kwargs)


def get_resource_client(subscription_id: str, **kwargs) -> "ResourceManagementClient":
    """Return a ResourceManagementClient for the given subscription."""
    from ..vendor.clients.resourcesmgmt import ResourceManagementClient

    return _build_mgmt_client(ResourceManagementClient, subscription_id, kwargs)


def get_authz_client(subscription_id: str, **kwargs) -> "AuthorizationManagementClient":
    """Return an AuthorizationManagementClient for the given subscription."""
    from ..vendor.clients.authzmgmt import AuthorizationManagementClient

    return _build_mgmt_client(AuthorizationManagementClient, subscription_id, kwargs)
def get_keyvault_client(subscription_id: str, **kwargs) -> "KeyVaultClient":
    """
    Return a KeyVaultClient for the given subscription.

    Like the other client factories in this module, the client authenticates with
    AZURE_CLI_CREDENTIAL and sends the extension USER_AGENT. Unless the caller supplies
    ``http_logging_policy``, get_default_logging_policy() is used.

    :param subscription_id: Target Azure subscription id.
    :param kwargs: Extra keyword arguments forwarded to the client constructor. Pass
        ``credential_scopes`` to override the Azure public cloud default scope.
    :return: The constructed KeyVaultClient.
    """
    from ..vendor.clients.keyvault import KeyVaultClient

    if "http_logging_policy" not in kwargs:
        kwargs["http_logging_policy"] = get_default_logging_policy()
    # TODO: the default scope only supports azure public cloud for now; other clouds must
    # pass credential_scopes explicitly. Previously an explicit value raised TypeError
    # (duplicate keyword argument), so it is now only a default.
    kwargs.setdefault("credential_scopes", ["https://vault.azure.net/.default"])
    return KeyVaultClient(
        credential=AZURE_CLI_CREDENTIAL,
        subscription_id=subscription_id,
        user_agent_policy=UserAgentPolicy(user_agent=USER_AGENT),
        **kwargs,
    )
def wait_for_terminal_state(
    poller: "LROPoller", wait_sec: int = POLL_WAIT_SEC, retries: int = POLL_RETRIES, **_
) -> JSON:
    """
    Wait for a long-running operation to finish and return its result.

    Rather than blocking in poller.result() immediately, completion is checked between short
    sleeps, because the resource client does not handle SIGINT well. Presumably this keeps
    the command responsive to Ctrl+C while waiting.

    :param poller: The poller tracking the long-running operation.
    :param wait_sec: Seconds to sleep between completion checks.
    :param retries: Maximum number of sleeps. If the poller is still not done afterwards,
        poller.result() is called anyway and blocks until the operation completes.
    :return: The value of poller.result().
    """
    # Check before sleeping so an already-finished poller returns without any delay.
    for _attempt in range(retries):
        if poller.done():
            break
        sleep(wait_sec)
    return poller.result()
def wait_for_terminal_states(
    *pollers: "LROPoller", retries: int = POLL_RETRIES, wait_sec: int = POLL_WAIT_SEC, **_
) -> Tuple["LROPoller", ...]:
    """
    Wait, with bounded polling, until every given poller reports done.

    :param pollers: The pollers to wait on.
    :param retries: Maximum number of sleeps before giving up waiting.
    :param wait_sec: Seconds to sleep between checks.
    :return: The same pollers, in order. Some may still be unfinished if the retry budget
        ran out, so callers should check each poller's state.
    """
    # Check before sleeping so already-finished batches (or no pollers) return immediately.
    for _attempt in range(retries):
        if all(poller.done() for poller in pollers):
            break
        sleep(wait_sec)
    return pollers
def get_tenant_id() -> str:
    """
    Return the tenant id of the subscription that the Azure CLI profile reports.

    Profile().get_subscription() is called with no arguments; the "tenantId" field of the
    returned subscription record is returned.
    """
    from azure.cli.core._profile import Profile

    active_subscription = Profile().get_subscription()
    return active_subscription["tenantId"]
def get_default_logging_policy() -> HttpLoggingPolicy:
    """
    Build the default HTTP logging policy used by this module's client factories.

    Requests are logged through the module logger. The api-version, $filter and $expand
    query parameters and the x-ms-correlation-request-id header are added to the policy's
    allowlists so their values are shown in logs rather than redacted.
    """
    policy = HttpLoggingPolicy(logger=logger)
    for query_param in ("api-version", "$filter", "$expand"):
        policy.allowed_query_params.add(query_param)
    policy.allowed_header_names.add("x-ms-correlation-request-id")
    return policy
# Immutable record of the pieces extracted by parse_resource_id:
#   subscription_id     - the segment following "subscriptions"
#   resource_group_name - the segment following "resourceGroups"
#   resource_name       - the final path segment
#   resource_id         - the full resource id string as given
ResourceIdContainer = NamedTuple(
    "ResourceIdContainer",
    [
        ("subscription_id", str),
        ("resource_group_name", str),
        ("resource_name", str),
        ("resource_id", str),
    ],
)
def parse_resource_id(resource_id: str) -> Optional[ResourceIdContainer]:
    """
    Split an Azure resource id into subscription, resource group and resource name.

    Expected form (keyword segments are matched case-insensitively; leading and trailing
    slashes are tolerated):
    /subscriptions/{subscriptionId}/resourceGroups/{resourceGroup}/providers/{namespace}/{type}/{name}

    :param resource_id: The resource id to parse.
    :return: A ResourceIdContainer whose resource_name is the last path segment, or None when
        resource_id is empty or None.
    :raises ValidationError: When resource_id does not have the expected form.
    """
    if not resource_id:
        # Previously "" was returned as-is; None matches the Optional return annotation.
        return None

    # Stripping outer slashes keeps indexes stable with or without a leading "/" and prevents
    # a trailing "/" from producing an empty resource name.
    segments = resource_id.strip("/").split("/")
    # Minimum shape: subscriptions/{sub}/resourceGroups/{rg}/providers/{namespace}/{type}/{name}
    if (
        len(segments) < 8
        or segments[0].lower() != "subscriptions"
        or segments[2].lower() != "resourcegroups"
        or segments[4].lower() != "providers"
    ):
        raise ValidationError(
            f"Malformed resource Id '{resource_id}'. An Azure resource Id has the form:\n"
            "/subscriptions/{subscriptionId}/resourceGroups/{resourceGroup}"
            "/providers/Microsoft.Provider/{resourcePath}/{resourceName}"
        )

    return ResourceIdContainer(
        subscription_id=segments[1],
        resource_group_name=segments[3],
        resource_name=segments[-1],
        resource_id=resource_id,
    )