azext_edge/edge/providers/rpsaas/adr/asset_endpoint_profiles.py (377 lines of code) (raw):

# coding=utf-8
# ----------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License file in the project root for license information.
# ----------------------------------------------------------------------------------------------

import json
from rich.console import Console
from typing import TYPE_CHECKING, Dict, Iterable, Optional

from knack.log import get_logger
from azure.cli.core.azclierror import (
    InvalidArgumentValueError,
    MutuallyExclusiveArgumentError,
    RequiredArgumentMissingError,
    FileOperationError,
)

from .user_strings import (
    AUTH_REF_MISMATCH_ERROR,
    GENERAL_AUTH_REF_MISMATCH_ERROR,
    MISSING_USERPASS_REF_ERROR,
    REMOVED_CERT_REF_MSG,
    REMOVED_USERPASS_REF_MSG,
)
from ....util.az_client import get_registry_mgmt_client, wait_for_terminal_state, REGISTRY_API_VERSION
from ....util.queryable import Queryable
from ....common import AEPAuthModes, AEPTypes

if TYPE_CHECKING:
    from ....vendor.clients.deviceregistrymgmt.operations import (
        AssetEndpointProfilesOperations as AEPOperations
    )

console = Console()
logger = get_logger(__name__)
AEP_RESOURCE_TYPE = "Microsoft.DeviceRegistry/assetEndpointProfiles"


# TODO: soul searching to see if I should combine with assets class
class AssetEndpointProfiles(Queryable):
    """Provider for create/read/update/delete/query operations on asset endpoint profiles."""

    def __init__(self, cmd):
        """Set up the device registry management client and its asset endpoint profile operations."""
        super().__init__(cmd=cmd)
        self.deviceregistry_mgmt_client = get_registry_mgmt_client(
            subscription_id=self.default_subscription_id,
            api_version=REGISTRY_API_VERSION
        )
        self.ops: "AEPOperations" = self.deviceregistry_mgmt_client.asset_endpoint_profiles

    def create(
        self,
        asset_endpoint_profile_name: str,
        endpoint_profile_type: str,
        instance_name: str,
        resource_group_name: str,
        target_address: str,
        certificate_reference: Optional[str] = None,
        instance_resource_group: Optional[str] = None,
        instance_subscription: Optional[str] = None,
        location: Optional[str] = None,
        password_reference: Optional[str] = None,
        username_reference: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        **kwargs
    ):
        """
        Create (or replace) an asset endpoint profile and wait for the operation to finish.

        The extended location is resolved from the given instance; when no instance resource
        group is given, `resource_group_name` is used. When no credential references are
        supplied the authentication method is set to anonymous. For the OPC UA endpoint type
        the additional configuration is built from the connector arguments in `kwargs`;
        otherwise an `additional_configuration` entry in `kwargs` (file path or inline JSON)
        is used if present.

        Raises the argument errors produced by `_build_opcua_config`,
        `_process_additional_configuration` and `_process_authentication`.
        """
        from .helpers import get_extended_location
        extended_location = get_extended_location(
            cmd=self.cmd,
            instance_name=instance_name,
            instance_resource_group=instance_resource_group or resource_group_name,
            instance_subscription=instance_subscription
        )
        # cluster_location is only used as the fallback location; it must not be sent
        # as part of the extended location payload.
        cluster_location = extended_location.pop("cluster_location")

        auth_mode = None
        if not any([username_reference, password_reference, certificate_reference]):
            auth_mode = AEPAuthModes.anonymous.value

        # Properties
        properties = {"endpointProfileType": endpoint_profile_type}
        configuration = None
        if endpoint_profile_type == AEPTypes.opcua.value:
            configuration = _build_opcua_config(**kwargs)
        elif "additional_configuration" in kwargs:
            # custom type
            configuration = _process_additional_configuration(kwargs["additional_configuration"])
        _update_properties(
            properties,
            target_address=target_address,
            auth_mode=auth_mode,
            username_reference=username_reference,
            password_reference=password_reference,
            certificate_reference=certificate_reference,
            additional_configuration=configuration
        )

        aep_body = {
            "extendedLocation": extended_location,
            "location": location or cluster_location,
            "properties": properties,
            "tags": tags,
        }

        with console.status(f"Creating {asset_endpoint_profile_name}..."):
            poller = self.ops.begin_create_or_replace(
                resource_group_name,
                asset_endpoint_profile_name,
                resource=aep_body
            )
            return wait_for_terminal_state(poller, **kwargs)

    def delete(self, asset_endpoint_profile_name: str, resource_group_name: str, **kwargs):
        """Delete an asset endpoint profile and wait for the operation to finish."""
        # Fetch the profile first so the cluster connectivity check runs before the delete is issued.
        self.show(
            asset_endpoint_profile_name=asset_endpoint_profile_name,
            resource_group_name=resource_group_name,
            check_cluster=True
        )

        with console.status(f"Deleting {asset_endpoint_profile_name}..."):
            poller = self.ops.begin_delete(
                resource_group_name,
                asset_endpoint_profile_name,
            )
            return wait_for_terminal_state(poller, **kwargs)

    def show(
        self, asset_endpoint_profile_name: str, resource_group_name: str, check_cluster: bool = False
    ) -> dict:
        """
        Return the asset endpoint profile with the given name.

        When `check_cluster` is set, the cluster connectivity check from `.helpers` is run
        against the retrieved profile before it is returned.
        """
        asset_endpoint = self.ops.get(
            resource_group_name=resource_group_name,
            asset_endpoint_profile_name=asset_endpoint_profile_name
        )
        if check_cluster:
            from .helpers import check_cluster_connectivity
            check_cluster_connectivity(self.cmd, asset_endpoint)
        return asset_endpoint

    def list(self, resource_group_name: Optional[str] = None, discovered: bool = False) -> Iterable[dict]:
        """
        List asset endpoint profiles in a resource group, or in the whole subscription
        when no resource group is given.

        `discovered` is accepted for interface compatibility but is not used here.
        """
        if resource_group_name:
            return self.ops.list_by_resource_group(resource_group_name=resource_group_name)
        return self.ops.list_by_subscription()

    # TODO: unit test
    def query_asset_endpoint_profiles(
        self,
        asset_endpoint_profile_name: Optional[str] = None,
        auth_mode: Optional[str] = None,
        custom_query: Optional[str] = None,
        endpoint_profile_type: Optional[str] = None,
        instance_name: Optional[str] = None,
        instance_resource_group: Optional[str] = None,
        location: Optional[str] = None,
        resource_group_name: Optional[str] = None,
        target_address: Optional[str] = None,
    ) -> dict:
        """
        Query asset endpoint profiles through the resource graph.

        `custom_query`, when given, replaces the generated filter clauses entirely.
        When an instance name and/or instance resource group is given, the result is
        restricted to profiles whose custom location matches that of the matching instance(s).
        """
        query_body = custom_query or _build_query_body(
            asset_endpoint_profile_name=asset_endpoint_profile_name,
            auth_mode=auth_mode,
            endpoint_profile_type=endpoint_profile_type,
            location=location,
            resource_group_name=resource_group_name,
            target_address=target_address
        )
        query = f"Resources | where type =~\"{AEP_RESOURCE_TYPE}\" " + query_body
        if any([instance_name, instance_resource_group]):
            instance_query = "Resources | where type =~ 'microsoft.iotoperations/instances' "
            if instance_name:
                instance_query += f"| where name =~ \"{instance_name}\""
            if instance_resource_group:
                instance_query += f"| where resourceGroup =~ \"{instance_resource_group}\""

            # fetch the custom location + join on innerunique. Then remove the extra customLocation1 generated
            query = f"{instance_query} | extend customLocation = tostring(extendedLocation.name) "\
                f"| project customLocation | join kind=innerunique ({query}) on customLocation "\
                "| project-away customLocation1"
        return self.query(query=query)

    def update(
        self,
        asset_endpoint_profile_name: str,
        resource_group_name: str,
        target_address: Optional[str] = None,
        auth_mode: Optional[str] = None,
        username_reference: Optional[str] = None,
        password_reference: Optional[str] = None,
        certificate_reference: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        **kwargs
    ):
        """
        Update an existing asset endpoint profile and wait for the operation to finish.

        The current profile is fetched (with the cluster connectivity check), modified in
        place with the provided values and sent back with create-or-replace. Tags are only
        replaced when a non-empty `tags` value is given.

        Raises the argument errors produced by `_process_authentication`.
        """
        # get the asset endpoint profile
        original_aep = self.show(
            asset_endpoint_profile_name,
            resource_group_name=resource_group_name,
            check_cluster=True
        )
        if tags:
            original_aep["tags"] = tags

        # modify the asset endpoint profile
        properties = original_aep.get("properties", {})
        _update_properties(
            properties,
            target_address=target_address,
            auth_mode=auth_mode,
            username_reference=username_reference,
            password_reference=password_reference,
            certificate_reference=certificate_reference,
        )
        # use this over update since we want to make sure we get the tags in
        with console.status(f"Updating {asset_endpoint_profile_name}..."):
            poller = self.ops.begin_create_or_replace(
                resource_group_name,
                asset_endpoint_profile_name,
                original_aep
            )
            return wait_for_terminal_state(poller, **kwargs)


# Helpers
def _assert_above_min(param: str, value: int, minimum: int = 0) -> str:
    """Return an error line if `value` is below `minimum`, otherwise an empty string."""
    if value < minimum:
        return f"The parameter {param} needs to be at least {minimum}.\n"
    return ""


def _raise_if_connector_error(connector_type: str, error_msg: str):
    """Raise InvalidArgumentValueError with the collected messages if `error_msg` is non-empty."""
    if error_msg:
        raise InvalidArgumentValueError(
            f"The following {connector_type} connector arguments are invalid:\n{error_msg}"
        )


# TODO: use jsonschema lib
def _build_opcua_config(
    original_config: Optional[str] = None,
    application_name: Optional[str] = None,
    auto_accept_untrusted_server_certs: Optional[bool] = None,
    default_publishing_interval: Optional[int] = None,
    default_sampling_interval: Optional[int] = None,
    default_queue_size: Optional[int] = None,
    keep_alive: Optional[int] = None,
    run_asset_discovery: Optional[str] = None,
    session_timeout: Optional[int] = None,
    session_keep_alive: Optional[int] = None,
    session_reconnect_period: Optional[int] = None,
    session_reconnect_exponential_back_off: Optional[int] = None,
    security_policy: Optional[str] = None,
    security_mode: Optional[str] = None,
    sub_max_items: Optional[int] = None,
    sub_life_time: Optional[int] = None,
    **_
) -> str:
    """
    Build the OPC UA connector additional configuration as a JSON string.

    Starts from `original_config` (a JSON string) when given, otherwise from an empty
    configuration, and overlays every argument that was provided. Numeric arguments are
    considered provided when they are not None, so an explicit 0 is validated against the
    argument's minimum and written to the configuration instead of being silently dropped.
    All validation failures are collected and reported together.

    Raises InvalidArgumentValueError if any numeric argument is below its minimum.
    """
    config = json.loads(original_config) if original_config else {}
    error_msg = ""

    def _ensure_section(name: str, *values) -> None:
        # Create the nested section only when something will be written into it and the
        # existing configuration has no (non-empty) section of that name.
        if any(value is not None for value in values) and not config.get(name):
            config[name] = {}

    if application_name:
        config["applicationName"] = application_name
    if keep_alive is not None:
        error_msg += _assert_above_min("--keep-alive", keep_alive)
        config["keepAliveMilliseconds"] = keep_alive
    if run_asset_discovery is not None:
        config["runAssetDiscovery"] = run_asset_discovery

    # defaults
    _ensure_section("defaults", default_publishing_interval, default_sampling_interval, default_queue_size)
    if default_publishing_interval is not None:
        error_msg += _assert_above_min("--default-publishing-int", default_publishing_interval, -1)
        config["defaults"]["publishingIntervalMilliseconds"] = default_publishing_interval
    if default_sampling_interval is not None:
        error_msg += _assert_above_min("--default-sampling-int", default_sampling_interval, -1)
        config["defaults"]["samplingIntervalMilliseconds"] = default_sampling_interval
    if default_queue_size is not None:
        error_msg += _assert_above_min("--default-queue-size", default_queue_size, 0)
        config["defaults"]["queueSize"] = default_queue_size

    # session
    _ensure_section(
        "session",
        session_timeout,
        session_reconnect_period,
        session_keep_alive,
        session_reconnect_exponential_back_off,
    )
    if session_timeout is not None:
        error_msg += _assert_above_min("--session-timeout", session_timeout)
        config["session"]["timeoutMilliseconds"] = session_timeout
    if session_keep_alive is not None:
        error_msg += _assert_above_min("--session-keep-alive", session_keep_alive)
        config["session"]["keepAliveIntervalMilliseconds"] = session_keep_alive
    if session_reconnect_period is not None:
        error_msg += _assert_above_min("--session-reconnect-period", session_reconnect_period)
        config["session"]["reconnectPeriodMilliseconds"] = session_reconnect_period
    if session_reconnect_exponential_back_off is not None:
        error_msg += _assert_above_min("--session-reconnect-backoff", session_reconnect_exponential_back_off, -1)
        config["session"]["reconnectExponentialBackOffMilliseconds"] = session_reconnect_exponential_back_off

    # subscription
    _ensure_section("subscription", sub_life_time, sub_max_items)
    if sub_life_time is not None:
        error_msg += _assert_above_min("--subscription-life-time", sub_life_time)
        config["subscription"]["lifeTimeMilliseconds"] = sub_life_time
    if sub_max_items is not None:
        error_msg += _assert_above_min("--subscription-max-items", sub_max_items, 1)
        config["subscription"]["maxItems"] = sub_max_items

    # security
    if any([
        auto_accept_untrusted_server_certs is not None, security_mode, security_policy
    ]) and not config.get("security"):
        config["security"] = {}
    if auto_accept_untrusted_server_certs is not None:
        config["security"]["autoAcceptUntrustedServerCertificates"] = auto_accept_untrusted_server_certs
    if security_mode:
        config["security"]["securityMode"] = security_mode
    if security_policy:
        # only the policy suffix is passed in; the stored value is the full policy URI
        config["security"]["securityPolicy"] = "http://opcfoundation.org/UA/SecurityPolicy#" + security_policy

    _raise_if_connector_error(connector_type="OPCUA", error_msg=error_msg)
    return json.dumps(config)


def _build_query_body(
    asset_endpoint_profile_name: Optional[str] = None,
    auth_mode: Optional[str] = None,
    endpoint_profile_type: Optional[str] = None,
    location: Optional[str] = None,
    resource_group_name: Optional[str] = None,
    target_address: Optional[str] = None,
) -> str:
    """
    Build the filter and projection clauses of the asset endpoint profile query.

    One case-insensitive `where` clause is added per provided argument, followed by a
    fixed projection of the returned columns.
    """
    query_body = ""
    if resource_group_name:
        query_body += f"| where resourceGroup =~ \"{resource_group_name}\""
    if location:
        query_body += f"| where location =~ \"{location}\""
    if asset_endpoint_profile_name:
        query_body += f"| where name =~ \"{asset_endpoint_profile_name}\""
    if auth_mode:
        query_body += f"| where properties.authentication.method =~ \"{auth_mode}\""
    if endpoint_profile_type:
        query_body += f"| where properties.endpointProfileType =~ \"{endpoint_profile_type}\""
    if target_address:
        query_body += f"| where properties.targetAddress =~ \"{target_address}\""

    query_body += "| extend customLocation = tostring(extendedLocation.name) "\
        "| extend provisioningState = properties.provisioningState "\
        "| project id, customLocation, location, name, resourceGroup, provisioningState, tags, "\
        "type, subscriptionId "
    return query_body


def _process_additional_configuration(configuration: str) -> Optional[str]:
    """
    Resolve an additional configuration argument into a validated JSON string.

    The argument is first treated as a file path; if reading it raises FileOperationError
    it is treated as inline JSON instead. Returns None when no configuration is given.

    Raises InvalidArgumentValueError if the file is empty or the content is not valid JSON.
    """
    from ....util import read_file_content

    inline_json = False
    if not configuration:
        return None

    try:
        logger.debug("Processing additional configuration.")
        configuration = read_file_content(configuration)
        if not configuration:
            raise InvalidArgumentValueError("Given file is empty.")
    except FileOperationError:
        inline_json = True
        logger.debug("Given additional configuration is not a file.")

    # make sure it is an actual json
    try:
        json.loads(configuration)
        return configuration
    except json.JSONDecodeError as e:
        error_msg = "Additional configuration is not a valid JSON. "
        if inline_json:
            error_msg += "For examples of valid JSON formating, please see https://aka.ms/inline-json-examples "
        # chain the decode error so the original parse position stays available in tracebacks
        raise InvalidArgumentValueError(
            f"{error_msg}\n{e.msg}"
        ) from e


def _process_authentication(
    auth_mode: Optional[str] = None,
    auth_props: Optional[Dict[str, str]] = None,
    certificate_reference: Optional[str] = None,
    password_reference: Optional[str] = None,
    username_reference: Optional[str] = None
) -> Dict[str, str]:
    """
    Apply the requested authentication settings to `auth_props` and return it.

    `auth_props` (the existing authentication properties, if any) is modified in place.
    Credentials belonging to a method other than the resulting one are removed and a
    warning is logged when that happens.

    Raises:
        MutuallyExclusiveArgumentError: certificate and username/password references are
            mixed, or the given references do not match the requested `auth_mode`.
        RequiredArgumentMissingError: only one of the username/password references is given.
    """
    if not auth_props:
        auth_props = {}
    # add checking for ensuring auth mode is set with proper params
    if certificate_reference and (username_reference or password_reference):
        raise MutuallyExclusiveArgumentError(AUTH_REF_MISMATCH_ERROR)

    if certificate_reference and auth_mode in [None, AEPAuthModes.certificate.value]:
        auth_props["method"] = AEPAuthModes.certificate.value
        auth_props["x509Credentials"] = {"certificateSecretName": certificate_reference}
        if auth_props.pop("usernamePasswordCredentials", None):
            logger.warning(REMOVED_USERPASS_REF_MSG)
    elif (username_reference or password_reference) and auth_mode in [None, AEPAuthModes.userpass.value]:
        auth_props["method"] = AEPAuthModes.userpass.value
        user_creds = auth_props.get("usernamePasswordCredentials", {})
        # NOTE(review): both references are overwritten unconditionally, so an existing
        # stored reference is not reused when only one of the two is passed - confirm intended.
        user_creds["usernameSecretName"] = username_reference
        user_creds["passwordSecretName"] = password_reference
        if not all([user_creds["usernameSecretName"], user_creds["passwordSecretName"]]):
            raise RequiredArgumentMissingError(MISSING_USERPASS_REF_ERROR)
        auth_props["usernamePasswordCredentials"] = user_creds
        if auth_props.pop("x509Credentials", None):
            logger.warning(REMOVED_CERT_REF_MSG)
    elif auth_mode == AEPAuthModes.anonymous.value and not any(
        [certificate_reference, username_reference, password_reference]
    ):
        auth_props["method"] = AEPAuthModes.anonymous.value
        if auth_props.pop("x509Credentials", None):
            logger.warning(REMOVED_CERT_REF_MSG)
        if auth_props.pop("usernamePasswordCredentials", None):
            logger.warning(REMOVED_USERPASS_REF_MSG)
    elif any([auth_mode, certificate_reference, username_reference, password_reference]):
        # something was requested but it matches none of the valid combinations above
        raise MutuallyExclusiveArgumentError(GENERAL_AUTH_REF_MISMATCH_ERROR)

    return auth_props


def _update_properties(
    properties,
    target_address: Optional[str] = None,
    additional_configuration: Optional[str] = None,
    auth_mode: Optional[str] = None,
    username_reference: Optional[str] = None,
    password_reference: Optional[str] = None,
    certificate_reference: Optional[str] = None,
):
    """
    Update the asset endpoint profile `properties` dict in place with the provided values.

    Only truthy arguments are applied. Authentication is re-processed only when at least
    one authentication related argument is given.
    """
    if additional_configuration:
        properties["additionalConfiguration"] = additional_configuration
    if target_address:
        properties["targetAddress"] = target_address
    if any([auth_mode, username_reference, password_reference, certificate_reference]):
        auth_props = properties.get("authentication", {})
        properties["authentication"] = _process_authentication(
            auth_props=auth_props,
            auth_mode=auth_mode,
            certificate_reference=certificate_reference,
            username_reference=username_reference,
            password_reference=password_reference
        )