azext_iot/digitaltwins/providers/connection/builders.py (277 lines of code) (raw):

# coding=utf-8
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

import json
from typing import Optional
from azure.cli.core.azclierror import BadRequestError, CLIInternalError, ManualInterrupt
from azext_iot.common.embedded_cli import EmbeddedCLI
from azext_iot.sdk.digitaltwins.controlplane.models import (
    AzureDataExplorerConnectionProperties,
    ManagedIdentityReference,
    DigitalTwinsIdentityType,
)
from knack.log import get_logger
from knack.prompting import prompt_y_n
from azext_iot.digitaltwins.common import (
    DT_INSTANCE,
    USER_ASSIGNED_IDENTITY,
    DT_SYS_IDENTITY_ERROR,
    DT_UAI_IDENTITY_ERROR,
    ERROR_PREFIX,
    FINISHED_CHECK_RESOURCE_LOG_MSG,
    ADX_ROLE_MSG,
    RBAC_ROLE_MSG,
    SYSTEM_IDENTITY,
    TRY_ADD_ROLE_LOG_MSG,
    FINISHED_ADD_ROLE_LOG_MSG,
    SKIP_ADD_ROLE_MSG,
    FAIL_RBAC_MSG,
    FAIL_GENERIC_MSG,
    ABORT_MSG,
    ADD_ROLE_INPUT_MSG,
    CONT_INPUT_MSG,
    DEFAULT_CONSUMER_GROUP
)

logger = get_logger(__name__)


class AdxConnectionValidator(object):
    """Validate the resources backing an Azure Data Explorer connection.

    All work happens in the constructor. It checks that the requested identity
    is enabled on the Digital Twins instance and that the ADX cluster and
    database and the Event Hub (and consumer group) exist. It also offers to
    grant the identity the roles it needs on those resources.

    On success the instance exposes:
        adx_resource_id, adx_cluster_uri, adx_location,
        eh_consumer_group, eh_endpoint_uri, eh_namespace_resource_id
    """

    def __init__(
        self,
        adx_cluster_name: str,
        adx_database_name: str,
        adx_resource_group: str,
        adx_subscription: str,
        dt_instance,
        eh_namespace: str,
        eh_entity_path: str,
        eh_resource_group: str,
        eh_subscription: str,
        eh_consumer_group: str = DEFAULT_CONSUMER_GROUP,
        identity: str = SYSTEM_IDENTITY,
        yes: bool = False,
    ):
        """Run all validations and role assignments.

        :param dt_instance: Digital Twins instance model; its ``identity`` and
            ``name`` attributes are read.
        :param identity: SYSTEM_IDENTITY, or a key of
            ``dt_instance.identity.user_assigned_identities`` (presumably the
            user-assigned identity's resource ID — confirm against callers).
        :param yes: when True, role assignments are attempted without asking.
        :raises BadRequestError: if the requested identity is not enabled on
            the instance.
        :raises CLIInternalError: if the ADX cluster, ADX database or a
            non-default Event Hub consumer group cannot be retrieved.
        :raises ManualInterrupt: if a role assignment fails and the user
            declines to continue.
        """
        self.cli = EmbeddedCLI()
        self.yes = yes
        self.dt = dt_instance

        # Check that the identity is associated with the dt
        if identity == SYSTEM_IDENTITY and (
            not self.dt.identity or not self.dt.identity.principal_id
        ):
            raise BadRequestError(DT_SYS_IDENTITY_ERROR)
        elif identity != SYSTEM_IDENTITY and (
            not self.dt.identity
            or not self.dt.identity.user_assigned_identities
            or identity not in self.dt.identity.user_assigned_identities
        ):
            raise BadRequestError(DT_UAI_IDENTITY_ERROR)

        # Resolve the identity to its principal (object) id; role assignments are made against it.
        principal_id = (
            self.dt.identity.principal_id
            if identity == SYSTEM_IDENTITY
            else self.dt.identity.user_assigned_identities[identity].principal_id
        )

        # Populate adx_cluster_uri, adx_location, adx_resource_id and perform checks
        self.validate_adx(
            adx_cluster_name=adx_cluster_name,
            adx_database_name=adx_database_name,
            adx_resource_group=adx_resource_group,
            adx_subscription=adx_subscription,
            principal_id=principal_id,
        )
        self.eh_consumer_group = eh_consumer_group
        # Populate eh_endpoint_uri, eh_namespace_resource_id and perform checks
        self.validate_eventhub(
            eh_namespace=eh_namespace,
            eh_entity_path=eh_entity_path,
            eh_resource_group=eh_resource_group,
            eh_subscription=eh_subscription,
            eh_consumer_group=eh_consumer_group,
            identity=identity,
            principal_id=principal_id,
        )

    def validate_eventhub(
        self,
        eh_namespace: str,
        eh_entity_path: str,
        eh_resource_group: str,
        eh_subscription: str,
        eh_consumer_group: str,
        identity: str,
        principal_id: str
    ):
        """Validate the Event Hub and grant the identity data-owner access to it.

        Sets ``eh_endpoint_uri`` and ``eh_namespace_resource_id``.

        :raises CLIInternalError: if a non-default consumer group cannot be
            retrieved. Errors from EventHubEndpointBuilder propagate unchanged.
        """
        # Local import, kept as in the original (presumably to avoid an import cycle — verify).
        from azext_iot.digitaltwins.providers.endpoint.builders import EventHubEndpointBuilder

        eh_endpoint = EventHubEndpointBuilder(
            endpoint_resource_name=eh_entity_path,
            endpoint_resource_group=eh_resource_group,
            endpoint_resource_namespace=eh_namespace,
            endpoint_resource_policy=None,
            endpoint_subscription=eh_subscription,
            identity=identity
        )
        eh_endpoint.error_prefix = ERROR_PREFIX + " find"
        self.eh_endpoint_uri = eh_endpoint.build_identity_based().endpoint_uri

        # Run check only if the consumer group is not the default. Default consumer group will always be present.
        if eh_consumer_group.lower() != DEFAULT_CONSUMER_GROUP.lower():
            eh_consumer_group_op = self.cli.invoke(
                "eventhubs eventhub consumer-group show -n {} --eventhub-name {} --namespace-name {} -g {}".format(
                    eh_consumer_group,
                    eh_entity_path,
                    eh_namespace,
                    eh_resource_group,
                ),
                subscription=eh_subscription,
            )
            if not eh_consumer_group_op.success():
                raise CLIInternalError("{} retrieve Event Hub Consumer Group.".format(ERROR_PREFIX))

        self.eh_namespace_resource_id = (
            "/subscriptions/{}/resourceGroups/{}/providers/Microsoft.EventHub/namespaces/{}".format(
                eh_subscription,
                eh_resource_group,
                eh_namespace,
            )
        )
        logger.debug(FINISHED_CHECK_RESOURCE_LOG_MSG.format("Event Hub"))

        self.add_dt_role_assignment(
            role="Azure Event Hubs Data Owner",
            resource_id=f"{self.eh_namespace_resource_id}/eventhubs/{eh_entity_path}",
            principal_id=principal_id
        )

    def validate_adx(
        self,
        adx_cluster_name: str,
        adx_database_name: str,
        adx_resource_group: str,
        adx_subscription: str,
        principal_id: str,
    ):
        """Validate the ADX cluster and database and grant the identity access.

        Sets ``adx_resource_id``, ``adx_cluster_uri`` and ``adx_location``
        (lower-cased with spaces removed).

        :raises CLIInternalError: if the cluster or database cannot be retrieved.
        """
        # Full query-string fragment, reused for the GET calls and the addPrincipals POST.
        api_version = "api-version=2021-01-01"
        self.adx_resource_id = (
            "/subscriptions/{}/resourceGroups/{}/providers/Microsoft.Kusto/clusters/{}".format(
                adx_subscription, adx_resource_group, adx_cluster_name
            )
        )

        adx_cluster_op = self.cli.invoke(
            "rest --method get --url {}?{}".format(
                self.adx_resource_id, api_version
            )
        )
        if not adx_cluster_op.success():
            raise CLIInternalError("{} retrieve Cluster.".format(ERROR_PREFIX))

        adx_cluster_meta = adx_cluster_op.as_json()
        self.adx_cluster_uri = adx_cluster_meta["properties"]["uri"]
        self.adx_location = adx_cluster_meta["location"].lower().replace(" ", "")

        adx_database_op = self.cli.invoke(
            "rest --method get --url {}/databases/{}?{}".format(
                self.adx_resource_id, adx_database_name, api_version
            )
        )
        if not adx_database_op.success():
            raise CLIInternalError("{} retrieve Database.".format(ERROR_PREFIX))
        logger.debug(FINISHED_CHECK_RESOURCE_LOG_MSG.format("Azure Data Explorer"))

        # Two grants are made: an Azure RBAC role on the database resource, and
        # a database-level principal added through the ADX addPrincipals API.
        self.add_dt_role_assignment(
            role="Contributor",
            resource_id=f"{self.adx_resource_id}/databases/{adx_database_name}",
            principal_id=principal_id
        )
        self.add_adx_principal(adx_database_name, api_version, principal_id)

    def _get_assignee(self, principal_id: str) -> str:
        """Return the display label for the identity that owns ``principal_id``."""
        if self.dt.identity.principal_id == principal_id:
            return DT_INSTANCE
        return USER_ASSIGNED_IDENTITY

    def add_dt_role_assignment(self, role: str, resource_id: str, principal_id: str):
        """Assign an Azure RBAC ``role`` to ``principal_id`` at scope ``resource_id``.

        Asks for confirmation unless ``self.yes`` is set. If the assignment
        fails, prints the command for manual execution and asks whether to
        continue.

        :raises ManualInterrupt: if the assignment fails and the user declines
            to continue.
        """
        role_str = RBAC_ROLE_MSG.format(role, self._get_assignee(principal_id), resource_id)
        logger.debug(TRY_ADD_ROLE_LOG_MSG.format(role_str))
        if not (self.yes or prompt_y_n(msg=ADD_ROLE_INPUT_MSG.format(role_str), default="y")):
            print(SKIP_ADD_ROLE_MSG.format(role_str))
            return

        role_command = (
            "role assignment create --role '{}' --assignee-object-id {} "
            "--assignee-principal-type ServicePrincipal --scope {}".format(
                role, principal_id, resource_id
            )
        )
        role_op = self.cli.invoke(role_command)
        if not role_op.success():
            print(FAIL_RBAC_MSG.format(role_str, role_command))
            if not prompt_y_n(msg=CONT_INPUT_MSG, default="n"):
                raise ManualInterrupt(ABORT_MSG)
            # The user chose to continue past the failure; the role was NOT
            # added, so do not log success (matches add_adx_principal).
            return

        logger.debug(FINISHED_ADD_ROLE_LOG_MSG.format(role_str))

    def add_adx_principal(self, adx_database_name: str, api_version: str, principal_id: str):
        """Add ``principal_id`` as an "Admin" "App" principal on the ADX database.

        Asks for confirmation unless ``self.yes`` is set. If the call fails,
        asks whether to continue.

        :param api_version: full query-string fragment, e.g. "api-version=2021-01-01".
        :raises ManualInterrupt: if the call fails and the user declines to continue.
        """
        role_str = ADX_ROLE_MSG.format(self._get_assignee(principal_id), adx_database_name)
        logger.debug(TRY_ADD_ROLE_LOG_MSG.format(role_str))
        if not (self.yes or prompt_y_n(msg=ADD_ROLE_INPUT_MSG.format(role_str), default="y")):
            print(SKIP_ADD_ROLE_MSG.format(role_str))
            return

        # NOTE(review): the principal (object) id is sent as "appId"; confirm
        # this is what the addPrincipals API expects for managed identities.
        database_admin_op = self.cli.invoke(
            "rest --method POST --url {}/databases/{}/addPrincipals?{} -b '{}'".format(
                self.adx_resource_id,
                adx_database_name,
                api_version,
                json.dumps({
                    "value": [{
                        "role": "Admin",
                        "name": self.dt.name,
                        "type": "App",
                        "appId": principal_id,
                    }]
                })
            )
        )
        if not database_admin_op.success():
            print(FAIL_GENERIC_MSG.format(role_str))
            if not prompt_y_n(msg=CONT_INPUT_MSG, default="n"):
                raise ManualInterrupt(ABORT_MSG)
            return

        logger.debug(FINISHED_ADD_ROLE_LOG_MSG.format(role_str))


def build_adx_connection_properties(
    adx_cluster_name: str,
    adx_database_name: str,
    dt_instance,
    eh_namespace: str,
    eh_entity_path: str,
    adx_table_name: Optional[str] = None,
    adx_twin_lifecycle_events_table_name: Optional[str] = None,
    adx_relationship_lifecycle_events_table_name: Optional[str] = None,
    adx_resource_group: Optional[str] = None,
    adx_subscription: Optional[str] = None,
    eh_resource_group: Optional[str] = None,
    eh_subscription: Optional[str] = None,
    eh_consumer_group: str = DEFAULT_CONSUMER_GROUP,
    identity: str = SYSTEM_IDENTITY,
    record_property_and_item_removals: bool = False,
    yes: bool = False,
):
    """Validate the ADX/Event Hub resources and build the connection properties.

    Runs AdxConnectionValidator (which may prompt and create role
    assignments), then packages the resolved URIs and resource IDs.

    :param identity: SYSTEM_IDENTITY for the system-assigned identity;
        otherwise the user-assigned identity key, passed through as
        ``user_assigned_identity`` on the identity reference.
    :returns: AzureDataExplorerConnectionProperties for the connection.
    :raises BadRequestError, CLIInternalError, ManualInterrupt: propagated
        from AdxConnectionValidator.
    """
    validator = AdxConnectionValidator(
        adx_cluster_name=adx_cluster_name,
        adx_database_name=adx_database_name,
        adx_resource_group=adx_resource_group,
        adx_subscription=adx_subscription,
        dt_instance=dt_instance,
        eh_namespace=eh_namespace,
        eh_entity_path=eh_entity_path,
        eh_consumer_group=eh_consumer_group,
        eh_resource_group=eh_resource_group,
        eh_subscription=eh_subscription,
        identity=identity,
        yes=yes,
    )

    # Keep the caller's identity string intact; build the SDK reference separately.
    if identity == SYSTEM_IDENTITY:
        identity_reference = ManagedIdentityReference(
            type=DigitalTwinsIdentityType.system_assigned.value
        )
    else:
        identity_reference = ManagedIdentityReference(
            type=DigitalTwinsIdentityType.user_assigned.value,
            user_assigned_identity=identity
        )

    return AzureDataExplorerConnectionProperties(
        adx_resource_id=validator.adx_resource_id,
        adx_endpoint_uri=validator.adx_cluster_uri,
        adx_database_name=adx_database_name,
        adx_table_name=adx_table_name,
        adx_twin_lifecycle_events_table_name=adx_twin_lifecycle_events_table_name,
        adx_relationship_lifecycle_events_table_name=adx_relationship_lifecycle_events_table_name,
        event_hub_endpoint_uri=validator.eh_endpoint_uri,
        event_hub_entity_path=eh_entity_path,
        event_hub_consumer_group=eh_consumer_group,
        event_hub_namespace_resource_id=validator.eh_namespace_resource_id,
        identity=identity_reference,
        record_property_and_item_removals=record_property_and_item_removals
    )