azext_iot/digitaltwins/providers/resource.py (714 lines of code) (raw):
# coding=utf-8
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from functools import partial
from typing import List, Optional
from azure.cli.core.azclierror import (
ArgumentUsageError,
InvalidArgumentValueError,
RequiredArgumentMissingError,
ResourceNotFoundError,
MutuallyExclusiveArgumentError
)
from azext_iot.digitaltwins.common import (
ADX_DEFAULT_TABLE,
DEFAULT_CONSUMER_GROUP,
LRO_TIMER,
MAX_ADT_DH_CREATE_RETRIES,
SYSTEM_IDENTITY,
ADTEndpointAuthType,
ADTPublicNetworkAccessType,
MAX_ADT_CREATE_RETRIES,
IdentityType,
)
from azext_iot.digitaltwins.providers import (
DigitalTwinsResourceManager,
CloudError,
ErrorResponseException,
)
from azext_iot.digitaltwins.providers.generic import generic_check_state
from azext_iot.digitaltwins.providers.rbac import RbacProvider
from azext_iot.sdk.digitaltwins.controlplane.models import (
DigitalTwinsDescription,
)
from azext_iot.common.utility import (
handle_service_exception
)
from knack.log import get_logger
from azext_iot.sdk.digitaltwins.controlplane.models import DigitalTwinsIdentity
logger = get_logger(__name__)
class ResourceProvider(DigitalTwinsResourceManager):
def __init__(self, cmd):
super(ResourceProvider, self).__init__(cmd=cmd)
self.mgmt_sdk = self.get_mgmt_sdk()
self.rbac = RbacProvider(cmd.cli_ctx)
def create(
self,
name: str,
resource_group_name: str,
location: Optional[str] = None,
tags: Optional[str] = None,
timeout: int = 60,
assign_identity: Optional[str] = None,
scopes: Optional[List[str]] = None,
role_type: str = "Contributor",
public_network_access: str = ADTPublicNetworkAccessType.enabled.value,
system_identity: Optional[str] = None,
user_identities: Optional[str] = None
):
if not location:
from azext_iot.common.embedded_cli import EmbeddedCLI
resource_group_meta = (
EmbeddedCLI(cli_ctx=self.cmd.cli_ctx)
.invoke("group show --name {}".format(resource_group_name))
.as_json()
)
location = resource_group_meta["location"]
try:
# Temporary make the deprecated assign_identity param work for now
if assign_identity and not system_identity:
system_identity = assign_identity
if system_identity:
if scopes and not role_type:
raise RequiredArgumentMissingError(
"Both --scopes and --role values are required when assigning the instance identity."
)
identity = {"type": IdentityType.none.value}
if system_identity and user_identities:
identity["type"] = IdentityType.system_assigned_user_assigned.value
elif user_identities:
identity["type"] = IdentityType.user_assigned.value
elif system_identity:
identity["type"] = IdentityType.system_assigned.value
if user_identities:
identity["userAssignedIdentities"] = {}
for user_identity in user_identities:
identity["userAssignedIdentities"][user_identity] = {}
digital_twins_create = DigitalTwinsDescription(
location=location,
tags=tags,
identity=identity,
public_network_access=public_network_access,
)
create_or_update = self.mgmt_sdk.digital_twins.create_or_update(
resource_name=name,
resource_group_name=resource_group_name,
digital_twins_create=digital_twins_create,
long_running_operation_timeout=timeout,
)
def check_state(lro):
generic_check_state(
lro=lro,
show_cmd=f"az dt show -n {name} -g {resource_group_name}",
max_retries=MAX_ADT_CREATE_RETRIES
)
rbac_handler = partial(self._rbac_handler, scopes=scopes, role_type=role_type)
create_or_update.add_done_callback(check_state)
create_or_update.add_done_callback(rbac_handler)
return create_or_update
except CloudError as e:
raise e
except ErrorResponseException as err:
handle_service_exception(err)
def _rbac_handler(self, lro, scopes: List[str], role_type: str):
instance = lro.resource().as_dict()
identity = instance.get("identity")
if identity:
identity_type = identity.get("type")
principal_id = identity.get("principal_id")
if (
principal_id
and scopes
and identity_type
and identity_type.lower() == "systemassigned"
):
for scope in scopes:
logger.info(
"Applying rbac assignment: Principal Id: {}, Scope: {}, Role: {}".format(
principal_id, scope, role_type
)
)
self.rbac.assign_role_flex(
principal_id=principal_id,
scope=scope,
role_type=role_type,
)
def list(self):
try:
return self.mgmt_sdk.digital_twins.list()
except ErrorResponseException as e:
handle_service_exception(e)
def list_by_resouce_group(self, resource_group_name: str):
try:
return self.mgmt_sdk.digital_twins.list_by_resource_group(
resource_group_name=resource_group_name
)
except ErrorResponseException as e:
handle_service_exception(e)
def get(self, name: str, resource_group_name: str, wait: bool = False):
try:
return self.mgmt_sdk.digital_twins.get(
resource_name=name, resource_group_name=resource_group_name
)
except ErrorResponseException as e:
handle_service_exception(e)
def find_instance(self, name: str, resource_group_name: Optional[str] = None, wait: bool = False):
if resource_group_name:
return self.get(
name=name, resource_group_name=resource_group_name, wait=wait
)
dt_collection_pager = self.list()
dt_collection = []
try:
while True:
dt_collection.extend(dt_collection_pager.advance_page())
except StopIteration:
pass
compare_name = name.lower()
filter_result = [
instance
for instance in dt_collection
if instance.name.lower() == compare_name
]
if filter_result:
if len(filter_result) > 1:
raise ArgumentUsageError(
"Ambiguous DT instance name. Please include the DT instance resource group."
)
return filter_result[0]
raise ResourceNotFoundError(
"DT instance: '{}' not found by auto-discovery. "
"Provide resource group via -g for direct lookup.".format(name)
)
def get_rg(self, dt_instance):
dt_scope = dt_instance.id
split_decomp = dt_scope.split("/")
res_g = split_decomp[4]
return res_g
def delete(self, name: str, resource_group_name: Optional[str] = None):
target_instance = self.find_instance(
name=name, resource_group_name=resource_group_name
)
if not resource_group_name:
resource_group_name = self.get_rg(target_instance)
try:
return self.mgmt_sdk.digital_twins.delete(
resource_name=name,
resource_group_name=resource_group_name,
)
except ErrorResponseException as e:
handle_service_exception(e)
# RBAC
def get_role_assignments(
self,
name: str,
include_inherited: bool = False,
role_type: Optional[str] = None,
resource_group_name: Optional[str] = None
):
target_instance = self.find_instance(
name=name, resource_group_name=resource_group_name
)
return self.rbac.list_assignments(
dt_scope=target_instance.id,
include_inherited=include_inherited,
role_type=role_type,
)
def assign_role(self, name: str, role_type: str, assignee: str, resource_group_name: Optional[str] = None):
target_instance = self.find_instance(
name=name, resource_group_name=resource_group_name
)
return self.rbac.assign_role(
dt_scope=target_instance.id, assignee=assignee, role_type=role_type
)
def remove_role(
self, name: str, assignee: str, role_type: Optional[str] = None, resource_group_name: Optional[str] = None
):
target_instance = self.find_instance(
name=name, resource_group_name=resource_group_name
)
return self.rbac.remove_role(
dt_scope=target_instance.id, assignee=assignee, role_type=role_type
)
# Identity
def assign_identity(
self,
name: str,
system_identity: Optional[bool] = None,
user_identities: Optional[List[str]] = None,
identity_role: Optional[str] = None,
identity_scopes: Optional[List[str]] = None,
resource_group_name: Optional[str] = None
):
target_instance = self.find_instance(
name=name, resource_group_name=resource_group_name
)
if not resource_group_name:
resource_group_name = self.get_rg(target_instance)
if bool(identity_role) ^ bool(identity_scopes):
raise RequiredArgumentMissingError(
'At least one scope (--scopes) and one role (--role) required for system-managed identity role assignment.'
)
if not system_identity and not user_identities:
raise RequiredArgumentMissingError(
'No identities provided to assign. Please provide system (--system) or user-assigned identities (--user).'
)
if not target_instance.identity:
target_instance.identity = DigitalTwinsIdentity()
if user_identities:
if not target_instance.identity.user_assigned_identities:
target_instance.identity.user_assigned_identities = {}
for user_identity in user_identities:
identity = target_instance.identity.user_assigned_identities.get(user_identity, {})
target_instance.identity.user_assigned_identities[user_identity] = identity
has_system_identity = target_instance.identity.type in [
IdentityType.system_assigned_user_assigned.value, IdentityType.system_assigned.value
]
if system_identity or has_system_identity:
if target_instance.identity.user_assigned_identities:
target_instance.identity.type = IdentityType.system_assigned_user_assigned.value
else:
target_instance.identity.type = IdentityType.system_assigned.value
else:
if target_instance.identity.user_assigned_identities:
target_instance.identity.type = IdentityType.user_assigned.value
else:
target_instance.identity.type = IdentityType.none.value
try:
update_poller = self.mgmt_sdk.digital_twins.create_or_update(
resource_name=name,
resource_group_name=resource_group_name,
digital_twins_create=target_instance,
long_running_operation_timeout=LRO_TIMER,
)
rbac_handler = partial(self._rbac_handler, scopes=identity_scopes, role_type=identity_role)
update_poller.add_done_callback(rbac_handler)
return update_poller
except CloudError as e:
raise e
except ErrorResponseException as err:
handle_service_exception(err)
def remove_identity(
self,
name: str,
system_identity: Optional[bool] = None,
user_identities: Optional[List[str]] = None,
resource_group_name: Optional[str] = None
):
target_instance = self.find_instance(
name=name, resource_group_name=resource_group_name
)
if not resource_group_name:
resource_group_name = self.get_rg(target_instance)
if not system_identity and user_identities is None:
raise RequiredArgumentMissingError(
'No identities provided to remove. Please provide system (--system) or user-assigned identities (--user).'
)
# Turn off system managed identity
if system_identity:
if target_instance.identity.type not in [
IdentityType.system_assigned.value,
IdentityType.system_assigned_user_assigned.value
]:
raise ArgumentUsageError('Digital Twin {} is not currently using a system-assigned identity'.format(name))
elif target_instance.identity.type == IdentityType.system_assigned_user_assigned.value:
target_instance.identity.type = IdentityType.user_assigned
else:
target_instance.identity.type = IdentityType.none.value
if user_identities and target_instance.identity.user_assigned_identities:
# loop through user_identities to remove
for identity in user_identities:
if not target_instance.identity.user_assigned_identities.get(identity):
raise ArgumentUsageError(
'Digital Twin {0} is not currently using a user-assigned identity with id: {1}'.format(name, identity)
)
del target_instance.identity.user_assigned_identities[identity]
if not target_instance.identity.user_assigned_identities:
target_instance.identity.user_assigned_identities = None
elif isinstance(user_identities, list):
target_instance.identity.user_assigned_identities = None
if target_instance.identity.type in [
IdentityType.system_assigned.value,
IdentityType.system_assigned_user_assigned.value
]:
if getattr(target_instance.identity, 'user_assigned_identities', None):
target_instance.identity.type = IdentityType.system_assigned_user_assigned.value
else:
target_instance.identity.type = IdentityType.system_assigned.value
else:
if getattr(target_instance.identity, 'user_assigned_identities', None):
target_instance.identity.type = IdentityType.user_assigned.value
else:
target_instance.identity.type = IdentityType.none.value
try:
update_poller = self.mgmt_sdk.digital_twins.create_or_update(
resource_name=name,
resource_group_name=resource_group_name,
digital_twins_create=target_instance,
long_running_operation_timeout=LRO_TIMER,
)
return update_poller
except ErrorResponseException as err:
handle_service_exception(err)
def show_identity(self, name: str, resource_group_name: Optional[str] = None):
target_instance = self.find_instance(
name=name, resource_group_name=resource_group_name
)
return target_instance.identity
# Endpoints
def get_endpoint(
self, name: str, endpoint_name: str, resource_group_name: Optional[str] = None, wait: bool = False
):
target_instance = self.find_instance(
name=name, resource_group_name=resource_group_name
)
if not resource_group_name:
resource_group_name = self.get_rg(target_instance)
try:
return self.mgmt_sdk.digital_twins_endpoint.get(
resource_name=target_instance.name,
endpoint_name=endpoint_name,
resource_group_name=resource_group_name,
)
except ErrorResponseException as e:
if wait:
e.status_code = e.response.status_code
raise e
handle_service_exception(e)
def list_endpoints(self, name: str, resource_group_name: Optional[str] = None):
target_instance = self.find_instance(
name=name, resource_group_name=resource_group_name
)
if not resource_group_name:
resource_group_name = self.get_rg(target_instance)
try:
return self.mgmt_sdk.digital_twins_endpoint.list(
resource_name=target_instance.name,
resource_group_name=resource_group_name,
)
except ErrorResponseException as e:
handle_service_exception(e)
# TODO: Polling issue related to mismatched status codes.
def delete_endpoint(self, name: str, endpoint_name: str, resource_group_name: Optional[str] = None):
target_instance = self.find_instance(
name=name, resource_group_name=resource_group_name
)
if not resource_group_name:
resource_group_name = self.get_rg(target_instance)
try:
return self.mgmt_sdk.digital_twins_endpoint.delete(
resource_name=target_instance.name,
endpoint_name=endpoint_name,
resource_group_name=resource_group_name,
)
except ErrorResponseException as e:
handle_service_exception(e)
def add_endpoint(
self,
name: str,
endpoint_name: str,
endpoint_resource_type: str,
endpoint_resource_name: str,
endpoint_resource_group: Optional[str] = None,
endpoint_resource_policy: Optional[str] = None,
endpoint_resource_namespace: Optional[str] = None,
endpoint_subscription: Optional[str] = None,
dead_letter_uri: Optional[str] = None,
dead_letter_secret: Optional[str] = None,
resource_group_name: Optional[str] = None,
timeout: int = 20,
auth_type: Optional[str] = None,
system_identity: bool = False,
user_identity: Optional[str] = None,
):
from azext_iot.digitaltwins.common import ADTEndpointType
if system_identity and user_identity:
raise MutuallyExclusiveArgumentError(
"Only one type of identity is permitted for endpoint creation."
)
identity = SYSTEM_IDENTITY if system_identity else user_identity
# do not break users who are still using auth_type to make identity based endpoints
if not identity and auth_type == ADTEndpointAuthType.identitybased.value:
identity = SYSTEM_IDENTITY
# Determine auth type from identity
auth_type = ADTEndpointAuthType.identitybased.value if identity else ADTEndpointAuthType.keybased.value
requires_namespace = [
ADTEndpointType.eventhub.value,
ADTEndpointType.servicebus.value,
]
if endpoint_resource_type in requires_namespace:
if not endpoint_resource_namespace:
raise RequiredArgumentMissingError(
"Endpoint resources of type {} require a namespace.".format(
" or ".join(map(str, requires_namespace))
)
)
if (
auth_type == ADTEndpointAuthType.keybased.value
and not endpoint_resource_policy
):
raise RequiredArgumentMissingError(
"Endpoint resources of type {} require a policy name when using Key based integration.".format(
" or ".join(map(str, requires_namespace))
)
)
if dead_letter_uri and auth_type == ADTEndpointAuthType.keybased.value:
raise RequiredArgumentMissingError(
"Use --deadletter-sas-uri to support deadletter for a Key based endpoint."
)
if dead_letter_secret and auth_type == ADTEndpointAuthType.identitybased.value:
raise RequiredArgumentMissingError(
"Use --deadletter-uri to support deadletter for an Identity based endpoint."
)
target_instance = self.find_instance(
name=name, resource_group_name=resource_group_name
)
if not resource_group_name:
resource_group_name = self.get_rg(target_instance)
endpoint_resource_group = endpoint_resource_group or resource_group_name
from azext_iot.digitaltwins.providers.endpoint.builders import build_endpoint
properties = build_endpoint(
endpoint_resource_type=endpoint_resource_type,
endpoint_resource_name=endpoint_resource_name,
endpoint_resource_group=endpoint_resource_group,
endpoint_subscription=endpoint_subscription,
endpoint_resource_namespace=endpoint_resource_namespace,
endpoint_resource_policy=endpoint_resource_policy,
auth_type=auth_type,
dead_letter_secret=dead_letter_secret,
dead_letter_uri=dead_letter_uri,
identity=identity,
)
try:
return self.mgmt_sdk.digital_twins_endpoint.create_or_update(
resource_name=target_instance.name,
resource_group_name=resource_group_name,
endpoint_name=endpoint_name,
properties=properties,
long_running_operation_timeout=timeout,
)
except ErrorResponseException as e:
handle_service_exception(e)
def get_private_link(self, name: str, link_name: str, resource_group_name: Optional[str] = None):
target_instance = self.find_instance(
name=name, resource_group_name=resource_group_name
)
if not resource_group_name:
resource_group_name = self.get_rg(target_instance)
try:
return self.mgmt_sdk.private_link_resources.get(
resource_group_name=resource_group_name,
resource_name=name,
resource_id=link_name,
raw=True,
).response.json()
except ErrorResponseException as e:
handle_service_exception(e)
def list_private_links(self, name: str, resource_group_name: Optional[str] = None):
target_instance = self.find_instance(
name=name, resource_group_name=resource_group_name
)
if not resource_group_name:
resource_group_name = self.get_rg(target_instance)
try:
# This resource is not paged though it may have been the intent.
link_collection = self.mgmt_sdk.private_link_resources.list(
resource_group_name=resource_group_name, resource_name=name, raw=True
).response.json()
return link_collection.get("value", [])
except ErrorResponseException as e:
handle_service_exception(e)
def set_private_endpoint_conn(
self,
name: str,
conn_name: str,
status: str,
description: str,
actions_required: Optional[str] = None,
group_ids: Optional[List[str]] = None,
resource_group_name: Optional[str] = None,
):
target_instance = self.find_instance(
name=name, resource_group_name=resource_group_name
)
if not resource_group_name:
resource_group_name = self.get_rg(target_instance)
try:
return self.mgmt_sdk.private_endpoint_connections.create_or_update(
resource_group_name=resource_group_name,
resource_name=name,
private_endpoint_connection_name=conn_name,
properties={
"privateLinkServiceConnectionState": {
"status": status,
"description": description,
"actions_required": actions_required,
},
"groupIds": group_ids,
},
)
except ErrorResponseException as e:
handle_service_exception(e)
def get_private_endpoint_conn(
self,
name: str,
conn_name: str,
resource_group_name: Optional[str] = None,
wait: bool = False
):
target_instance = self.find_instance(
name=name, resource_group_name=resource_group_name
)
if not resource_group_name:
resource_group_name = self.get_rg(target_instance)
try:
return self.mgmt_sdk.private_endpoint_connections.get(
resource_group_name=resource_group_name,
resource_name=name,
private_endpoint_connection_name=conn_name
)
except ErrorResponseException as e:
if wait:
e.status_code = e.response.status_code
raise e
handle_service_exception(e)
def list_private_endpoint_conns(self, name: str, resource_group_name: Optional[str] = None):
target_instance = self.find_instance(
name=name, resource_group_name=resource_group_name
)
if not resource_group_name:
resource_group_name = self.get_rg(target_instance)
try:
# This resource is not paged though it may have been the intent.
endpoint_collection = self.mgmt_sdk.private_endpoint_connections.list(
resource_group_name=resource_group_name, resource_name=name, raw=True
).response.json()
return endpoint_collection.get("value", [])
except ErrorResponseException as e:
handle_service_exception(e)
def delete_private_endpoint_conn(self, name: str, conn_name: str, resource_group_name: Optional[str] = None):
target_instance = self.find_instance(
name=name, resource_group_name=resource_group_name
)
if not resource_group_name:
resource_group_name = self.get_rg(target_instance)
try:
return self.mgmt_sdk.private_endpoint_connections.delete(
resource_group_name=resource_group_name,
resource_name=name,
private_endpoint_connection_name=conn_name
)
except ErrorResponseException as e:
handle_service_exception(e)
def create_adx_data_connection(
self,
name: str,
conn_name: str,
adx_cluster_name: str,
adx_database_name: str,
eh_namespace: str,
eh_entity_path: str,
adx_table_name: str = ADX_DEFAULT_TABLE,
adx_twin_lifecycle_events_table_name: Optional[str] = None,
adx_relationship_lifecycle_events_table_name: Optional[str] = None,
adx_resource_group: Optional[str] = None,
adx_subscription: Optional[str] = None,
eh_consumer_group: str = DEFAULT_CONSUMER_GROUP,
eh_resource_group: Optional[str] = None,
eh_subscription: Optional[str] = None,
user_identity: Optional[str] = None,
resource_group_name: Optional[str] = None,
record_property_and_item_removals: bool = False,
yes: bool = False,
):
target_instance = self.find_instance(
name=name, resource_group_name=resource_group_name
)
if not resource_group_name:
resource_group_name = self.get_rg(target_instance)
subscription = target_instance.id.split("/")[2]
if len(conn_name) <= 2:
raise InvalidArgumentValueError(
"The connection name must have a length greater than 2"
)
adx_resource_group = adx_resource_group if adx_resource_group else resource_group_name
eh_resource_group = eh_resource_group if eh_resource_group else resource_group_name
adx_subscription = adx_subscription if adx_subscription else subscription
eh_subscription = eh_subscription if eh_subscription else subscription
from azext_iot.digitaltwins.providers.connection.builders import build_adx_connection_properties
properties = build_adx_connection_properties(
adx_cluster_name=adx_cluster_name,
adx_database_name=adx_database_name,
adx_table_name=adx_table_name,
adx_twin_lifecycle_events_table_name=adx_twin_lifecycle_events_table_name,
adx_relationship_lifecycle_events_table_name=adx_relationship_lifecycle_events_table_name,
adx_resource_group=adx_resource_group,
adx_subscription=adx_subscription,
dt_instance=target_instance,
eh_namespace=eh_namespace,
eh_entity_path=eh_entity_path,
eh_consumer_group=eh_consumer_group,
eh_resource_group=eh_resource_group,
eh_subscription=eh_subscription,
identity=user_identity or SYSTEM_IDENTITY,
record_property_and_item_removals=record_property_and_item_removals,
yes=yes,
)
try:
def check_state(lro):
generic_check_state(
lro=lro,
show_cmd="az dt data-history show -n {} -g {} --cn {}".format(
name, resource_group_name, conn_name
),
max_retries=MAX_ADT_DH_CREATE_RETRIES
)
create_or_update = self.mgmt_sdk.time_series_database_connections.create_or_update(
resource_group_name=resource_group_name,
resource_name=name,
time_series_database_connection_name=conn_name,
properties=properties
)
create_or_update.add_done_callback(check_state)
return create_or_update
except ErrorResponseException as e:
handle_service_exception(e)
def get_data_connection(
self, name: str, conn_name: str, resource_group_name: Optional[str] = None, wait: bool = False
):
target_instance = self.find_instance(
name=name, resource_group_name=resource_group_name
)
if not resource_group_name:
resource_group_name = self.get_rg(target_instance)
try:
return self.mgmt_sdk.time_series_database_connections.get(
resource_group_name=resource_group_name,
resource_name=name,
time_series_database_connection_name=conn_name,
)
except ErrorResponseException as e:
if wait:
e.status_code = e.response.status_code
raise e
handle_service_exception(e)
def list_data_connection(self, name: str, resource_group_name: Optional[str] = None):
target_instance = self.find_instance(
name=name, resource_group_name=resource_group_name
)
if not resource_group_name:
resource_group_name = self.get_rg(target_instance)
try:
return self.mgmt_sdk.time_series_database_connections.list(
resource_group_name=resource_group_name,
resource_name=name,
)
except ErrorResponseException as e:
handle_service_exception(e)
def delete_data_connection(
self,
name: str,
conn_name: str,
cleanup_connection_artifacts: bool = False,
resource_group_name: Optional[str] = None
):
target_instance = self.find_instance(
name=name, resource_group_name=resource_group_name
)
if not resource_group_name:
resource_group_name = self.get_rg(target_instance)
try:
return self.mgmt_sdk.time_series_database_connections.delete(
resource_group_name=resource_group_name,
resource_name=name,
time_series_database_connection_name=conn_name,
cleanup_connection_artifacts=cleanup_connection_artifacts,
)
except ErrorResponseException as e:
handle_service_exception(e)