azext_edge/edge/providers/rpsaas/adr/asset_endpoint_profiles.py (377 lines of code) (raw):
# coding=utf-8
# ----------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License file in the project root for license information.
# ----------------------------------------------------------------------------------------------
import json
from rich.console import Console
from typing import TYPE_CHECKING, Dict, Iterable, Optional
from knack.log import get_logger
from azure.cli.core.azclierror import (
InvalidArgumentValueError,
MutuallyExclusiveArgumentError,
RequiredArgumentMissingError,
FileOperationError,
)
from .user_strings import (
AUTH_REF_MISMATCH_ERROR,
GENERAL_AUTH_REF_MISMATCH_ERROR,
MISSING_USERPASS_REF_ERROR,
REMOVED_CERT_REF_MSG,
REMOVED_USERPASS_REF_MSG,
)
from ....util.az_client import get_registry_mgmt_client, wait_for_terminal_state, REGISTRY_API_VERSION
from ....util.queryable import Queryable
from ....common import AEPAuthModes, AEPTypes
if TYPE_CHECKING:
from ....vendor.clients.deviceregistrymgmt.operations import (
AssetEndpointProfilesOperations as AEPOperations
)
# ARM resource type used to filter asset endpoint profiles in queries.
AEP_RESOURCE_TYPE = "Microsoft.DeviceRegistry/assetEndpointProfiles"
# Shared rich console (used for progress spinners) and module logger.
console = Console()
logger = get_logger(__name__)
# TODO: consider whether this class should be combined with the Assets class.
class AssetEndpointProfiles(Queryable):
    """Operations on Device Registry asset endpoint profiles.

    Wraps the ``asset_endpoint_profiles`` operations group of the device registry
    management client; query support comes from the ``Queryable`` base class.
    """

    def __init__(self, cmd):
        super().__init__(cmd=cmd)
        self.deviceregistry_mgmt_client = get_registry_mgmt_client(
            subscription_id=self.default_subscription_id,
            api_version=REGISTRY_API_VERSION
        )
        self.ops: "AEPOperations" = self.deviceregistry_mgmt_client.asset_endpoint_profiles

    def create(
        self,
        asset_endpoint_profile_name: str,
        endpoint_profile_type: str,
        instance_name: str,
        resource_group_name: str,
        target_address: str,
        certificate_reference: Optional[str] = None,
        instance_resource_group: Optional[str] = None,
        instance_subscription: Optional[str] = None,
        location: Optional[str] = None,
        password_reference: Optional[str] = None,
        username_reference: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        **kwargs
    ):
        """Create (or replace) an asset endpoint profile attached to an IoT Operations instance.

        The extended location is resolved from the given instance. ``location`` defaults to the
        ``cluster_location`` reported by that lookup. When no credential reference is given, the
        profile is created with anonymous authentication.

        For the OPCUA profile type the connector configuration is built from the OPCUA keyword
        arguments in ``kwargs``. For any other type, ``kwargs["additional_configuration"]`` (a file
        path or inline JSON) is used when present. ``kwargs`` is also forwarded to
        ``wait_for_terminal_state``.

        :returns: Whatever ``wait_for_terminal_state`` returns for the create poller.
        """
        from .helpers import get_extended_location

        extended_location = get_extended_location(
            cmd=self.cmd,
            instance_name=instance_name,
            instance_resource_group=instance_resource_group or resource_group_name,
            instance_subscription=instance_subscription
        )
        # cluster_location is removed from the extendedLocation payload and only used as the
        # fallback resource location.
        cluster_location = extended_location.pop("cluster_location")

        # No credential references at all -> anonymous auth.
        auth_mode = None
        if not any([username_reference, password_reference, certificate_reference]):
            auth_mode = AEPAuthModes.anonymous.value

        # Properties
        properties = {"endpointProfileType": endpoint_profile_type}
        configuration = None
        if endpoint_profile_type == AEPTypes.opcua.value:
            configuration = _build_opcua_config(**kwargs)
        elif "additional_configuration" in kwargs:  # custom type
            configuration = _process_additional_configuration(kwargs["additional_configuration"])

        _update_properties(
            properties,
            target_address=target_address,
            auth_mode=auth_mode,
            username_reference=username_reference,
            password_reference=password_reference,
            certificate_reference=certificate_reference,
            additional_configuration=configuration
        )

        aep_body = {
            "extendedLocation": extended_location,
            "location": location or cluster_location,
            "properties": properties,
            "tags": tags,
        }

        with console.status(f"Creating {asset_endpoint_profile_name}..."):
            poller = self.ops.begin_create_or_replace(
                resource_group_name,
                asset_endpoint_profile_name,
                resource=aep_body
            )
            return wait_for_terminal_state(poller, **kwargs)

    def delete(self, asset_endpoint_profile_name: str, resource_group_name: str, **kwargs):
        """Delete an asset endpoint profile after confirming it exists and checking cluster connectivity.

        ``kwargs`` is forwarded to ``wait_for_terminal_state``.
        """
        self.show(
            asset_endpoint_profile_name=asset_endpoint_profile_name,
            resource_group_name=resource_group_name,
            check_cluster=True
        )
        with console.status(f"Deleting {asset_endpoint_profile_name}..."):
            poller = self.ops.begin_delete(
                resource_group_name,
                asset_endpoint_profile_name,
            )
            return wait_for_terminal_state(poller, **kwargs)

    def show(
        self, asset_endpoint_profile_name: str, resource_group_name: str, check_cluster: bool = False
    ) -> dict:
        """Fetch an asset endpoint profile.

        :param check_cluster: When True, also run ``check_cluster_connectivity`` on the result.
        :returns: The profile as returned by the management client.
        """
        asset_endpoint = self.ops.get(
            resource_group_name=resource_group_name, asset_endpoint_profile_name=asset_endpoint_profile_name
        )
        if check_cluster:
            from .helpers import check_cluster_connectivity
            check_cluster_connectivity(self.cmd, asset_endpoint)
        return asset_endpoint

    def list(self, resource_group_name: Optional[str] = None, discovered: bool = False) -> Iterable[dict]:
        """List profiles in a resource group, or in the whole subscription when no group is given.

        ``discovered`` is accepted for interface compatibility but is not used here.
        """
        if resource_group_name:
            return self.ops.list_by_resource_group(resource_group_name=resource_group_name)
        return self.ops.list_by_subscription()

    # TODO: unit test
    def query_asset_endpoint_profiles(
        self,
        asset_endpoint_profile_name: Optional[str] = None,
        auth_mode: Optional[str] = None,
        custom_query: Optional[str] = None,
        endpoint_profile_type: Optional[str] = None,
        instance_name: Optional[str] = None,
        instance_resource_group: Optional[str] = None,
        location: Optional[str] = None,
        resource_group_name: Optional[str] = None,
        target_address: Optional[str] = None,
    ) -> dict:
        """Query asset endpoint profiles with optional filters.

        ``custom_query`` replaces the generated filter/projection body entirely. When an instance
        name or resource group is given, the results are restricted to profiles whose custom
        location matches that of the matching ``microsoft.iotoperations/instances`` resources.
        """
        query_body = custom_query or _build_query_body(
            asset_endpoint_profile_name=asset_endpoint_profile_name,
            auth_mode=auth_mode,
            endpoint_profile_type=endpoint_profile_type,
            location=location,
            resource_group_name=resource_group_name,
            target_address=target_address
        )

        query = f"Resources | where type =~\"{AEP_RESOURCE_TYPE}\" " + query_body
        if any([instance_name, instance_resource_group]):
            instance_query = "Resources | where type =~ 'microsoft.iotoperations/instances' "
            if instance_name:
                instance_query += f"| where name =~ \"{instance_name}\""
            if instance_resource_group:
                instance_query += f"| where resourceGroup =~ \"{instance_resource_group}\""

            # fetch the custom location + join on innerunique. Then remove the extra customLocation1 generated
            query = f"{instance_query} | extend customLocation = tostring(extendedLocation.name) "\
                f"| project customLocation | join kind=innerunique ({query}) on customLocation "\
                "| project-away customLocation1"
        return self.query(query=query)

    def update(
        self,
        asset_endpoint_profile_name: str,
        resource_group_name: str,
        target_address: Optional[str] = None,
        auth_mode: Optional[str] = None,
        username_reference: Optional[str] = None,
        password_reference: Optional[str] = None,
        certificate_reference: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        **kwargs
    ):
        """Update an existing profile's target address, authentication and/or tags.

        The current profile is fetched (with a cluster connectivity check), modified locally and
        re-submitted via ``begin_create_or_replace``. Tags are replaced only when a non-empty
        ``tags`` dict is supplied. ``kwargs`` is forwarded to ``wait_for_terminal_state``.
        """
        # get the asset
        original_aep = self.show(
            asset_endpoint_profile_name,
            resource_group_name=resource_group_name,
            check_cluster=True
        )
        if tags:
            original_aep["tags"] = tags

        # modify the asset endpoint profile; setdefault keeps the edited dict attached to the
        # payload even if the fetched profile had no "properties" key.
        properties = original_aep.setdefault("properties", {})
        _update_properties(
            properties,
            target_address=target_address,
            auth_mode=auth_mode,
            username_reference=username_reference,
            password_reference=password_reference,
            certificate_reference=certificate_reference,
        )

        # use this over update since we want to make sure we get the tags in
        with console.status(f"Updating {asset_endpoint_profile_name}..."):
            poller = self.ops.begin_create_or_replace(
                resource_group_name,
                asset_endpoint_profile_name,
                original_aep
            )
            return wait_for_terminal_state(poller, **kwargs)
# Helpers
def _assert_above_min(param: str, value: int, minimum: int = 0) -> str:
if value < minimum:
return f"The parameter {param} needs to be at least {minimum}.\n"
return ""
def _raise_if_connector_error(connector_type: str, error_msg: str):
if error_msg:
raise InvalidArgumentValueError(
f"The following {connector_type} connector arguments are invalid:\n{error_msg}"
)
# TODO: use jsonschema lib
def _build_opcua_config(
    original_config: Optional[str] = None,
    application_name: Optional[str] = None,
    auto_accept_untrusted_server_certs: Optional[bool] = None,
    default_publishing_interval: Optional[int] = None,
    default_sampling_interval: Optional[int] = None,
    default_queue_size: Optional[int] = None,
    keep_alive: Optional[int] = None,
    run_asset_discovery: Optional[str] = None,
    session_timeout: Optional[int] = None,
    session_keep_alive: Optional[int] = None,
    session_reconnect_period: Optional[int] = None,
    session_reconnect_exponential_back_off: Optional[int] = None,
    security_policy: Optional[str] = None,
    security_mode: Optional[str] = None,
    sub_max_items: Optional[int] = None,
    sub_life_time: Optional[int] = None,
    **_
) -> str:
    """Build the OPCUA connector ``additionalConfiguration`` JSON string.

    Starts from ``original_config`` (a JSON string) when given and overlays every supplied
    option. Numeric options are applied whenever they are not None, so an explicit ``0`` is
    validated and written rather than silently ignored. Nested sections (``defaults``,
    ``session``, ``subscription``, ``security``) are created on first use; an existing
    non-empty section is extended in place.

    Unrecognized keyword arguments are ignored.

    :returns: The resulting configuration serialized with ``json.dumps``.
    :raises InvalidArgumentValueError: If any numeric option is below its allowed minimum
        (all violations are reported together).
    """
    config = json.loads(original_config) if original_config else {}
    error_msg = ""

    def _section(key: str) -> dict:
        # Reuse an existing non-empty section; replace a missing/empty/null one with a new dict.
        if not config.get(key):
            config[key] = {}
        return config[key]

    if application_name:
        config["applicationName"] = application_name
    if keep_alive is not None:
        error_msg += _assert_above_min("--keep-alive", keep_alive)
        config["keepAliveMilliseconds"] = keep_alive
    if run_asset_discovery is not None:
        config["runAssetDiscovery"] = run_asset_discovery

    # defaults
    if default_publishing_interval is not None:
        error_msg += _assert_above_min("--default-publishing-int", default_publishing_interval, -1)
        _section("defaults")["publishingIntervalMilliseconds"] = default_publishing_interval
    if default_sampling_interval is not None:
        error_msg += _assert_above_min("--default-sampling-int", default_sampling_interval, -1)
        _section("defaults")["samplingIntervalMilliseconds"] = default_sampling_interval
    if default_queue_size is not None:
        error_msg += _assert_above_min("--default-queue-size", default_queue_size, 0)
        _section("defaults")["queueSize"] = default_queue_size

    # session
    if session_timeout is not None:
        error_msg += _assert_above_min("--session-timeout", session_timeout)
        _section("session")["timeoutMilliseconds"] = session_timeout
    if session_keep_alive is not None:
        error_msg += _assert_above_min("--session-keep-alive", session_keep_alive)
        _section("session")["keepAliveIntervalMilliseconds"] = session_keep_alive
    if session_reconnect_period is not None:
        error_msg += _assert_above_min("--session-reconnect-period", session_reconnect_period)
        _section("session")["reconnectPeriodMilliseconds"] = session_reconnect_period
    if session_reconnect_exponential_back_off is not None:
        error_msg += _assert_above_min("--session-reconnect-backoff", session_reconnect_exponential_back_off, -1)
        _section("session")["reconnectExponentialBackOffMilliseconds"] = session_reconnect_exponential_back_off

    # subscription
    if sub_life_time is not None:
        error_msg += _assert_above_min("--subscription-life-time", sub_life_time)
        _section("subscription")["lifeTimeMilliseconds"] = sub_life_time
    if sub_max_items is not None:
        error_msg += _assert_above_min("--subscription-max-items", sub_max_items, 1)
        _section("subscription")["maxItems"] = sub_max_items

    # security
    if auto_accept_untrusted_server_certs is not None:
        _section("security")["autoAcceptUntrustedServerCertificates"] = auto_accept_untrusted_server_certs
    if security_mode:
        _section("security")["securityMode"] = security_mode
    if security_policy:
        # The CLI takes the short policy name; the service expects the full OPC Foundation URI.
        _section("security")["securityPolicy"] = "http://opcfoundation.org/UA/SecurityPolicy#" + security_policy

    _raise_if_connector_error(connector_type="OPCUA", error_msg=error_msg)
    return json.dumps(config)
def _build_query_body(
asset_endpoint_profile_name: Optional[str] = None,
auth_mode: Optional[str] = None,
endpoint_profile_type: Optional[str] = None,
location: Optional[str] = None,
resource_group_name: Optional[str] = None,
target_address: Optional[str] = None,
) -> str:
query_body = ""
if resource_group_name:
query_body += f"| where resourceGroup =~ \"{resource_group_name}\""
if location:
query_body += f"| where location =~ \"{location}\""
if asset_endpoint_profile_name:
query_body += f"| where name =~ \"{asset_endpoint_profile_name}\""
if auth_mode:
query_body += f"| where properties.authentication.method =~ \"{auth_mode}\""
if endpoint_profile_type:
query_body += f"| where properties.endpointProfileType =~ \"{endpoint_profile_type}\""
if target_address:
query_body += f"| where properties.targetAddress =~ \"{target_address}\""
query_body += "| extend customLocation = tostring(extendedLocation.name) "\
"| extend provisioningState = properties.provisioningState "\
"| project id, customLocation, location, name, resourceGroup, provisioningState, tags, "\
"type, subscriptionId "
return query_body
def _process_additional_configuration(configuration: str) -> Optional[str]:
from ....util import read_file_content
inline_json = False
if not configuration:
return
try:
logger.debug("Processing additional configuration.")
configuration = read_file_content(configuration)
if not configuration:
raise InvalidArgumentValueError("Given file is empty.")
except FileOperationError:
inline_json = True
logger.debug("Given additional configuration is not a file.")
# make sure it is an actual json
try:
json.loads(configuration)
return configuration
except json.JSONDecodeError as e:
error_msg = "Additional configuration is not a valid JSON. "
if inline_json:
error_msg += "For examples of valid JSON formating, please see https://aka.ms/inline-json-examples "
raise InvalidArgumentValueError(
f"{error_msg}\n{e.msg}"
)
def _process_authentication(
    auth_mode: Optional[str] = None,
    auth_props: Optional[dict] = None,
    certificate_reference: Optional[str] = None,
    password_reference: Optional[str] = None,
    username_reference: Optional[str] = None
) -> dict:
    """Validate requested auth settings and merge them into an ``authentication`` payload.

    Depending on the inputs, one of the following is applied to ``auth_props``:

    * certificate: ``method`` is set to the certificate mode and ``x509Credentials`` is
      written; an existing ``usernamePasswordCredentials`` entry is removed with a warning.
    * username/password: ``method`` is set to the userpass mode and
      ``usernamePasswordCredentials`` is written; an existing ``x509Credentials`` entry is
      removed with a warning.
    * anonymous: ``method`` is set to anonymous and both credential entries are removed
      (with a warning for each one that existed).

    ``auth_props`` is modified in place when given, and is also returned. When no mode and no
    references are supplied, it is returned unchanged (an empty dict if it was not given).

    :param auth_mode: Requested auth mode, or None to infer it from the references.
    :param auth_props: Existing ``authentication`` dict to update, if any.
    :param certificate_reference: Secret name for the x509 certificate.
    :param password_reference: Secret name for the password.
    :param username_reference: Secret name for the username.
    :returns: The updated authentication dict.
    :raises MutuallyExclusiveArgumentError: If a certificate reference is combined with
        username/password references, or the mode does not match the references supplied.
    :raises RequiredArgumentMissingError: If the userpass mode is selected but either the
        username or the password reference is missing.
    """
    if not auth_props:
        auth_props = {}
    # add checking for ensuring auth mode is set with proper params
    if certificate_reference and (username_reference or password_reference):
        raise MutuallyExclusiveArgumentError(AUTH_REF_MISMATCH_ERROR)

    if certificate_reference and auth_mode in [None, AEPAuthModes.certificate.value]:
        auth_props["method"] = AEPAuthModes.certificate.value
        auth_props["x509Credentials"] = {"certificateSecretName": certificate_reference}
        # Switching to certificate auth: drop any stale username/password credentials.
        if auth_props.pop("usernamePasswordCredentials", None):
            logger.warning(REMOVED_USERPASS_REF_MSG)
    elif (username_reference or password_reference) and auth_mode in [None, AEPAuthModes.userpass.value]:
        auth_props["method"] = AEPAuthModes.userpass.value
        user_creds = auth_props.get("usernamePasswordCredentials", {})
        # NOTE(review): both names are overwritten from the arguments (possibly with None), so an
        # update must pass both references even when only one changes - confirm this is intended.
        user_creds["usernameSecretName"] = username_reference
        user_creds["passwordSecretName"] = password_reference
        if not all([user_creds["usernameSecretName"], user_creds["passwordSecretName"]]):
            raise RequiredArgumentMissingError(MISSING_USERPASS_REF_ERROR)
        auth_props["usernamePasswordCredentials"] = user_creds
        # Switching to username/password auth: drop any stale certificate credentials.
        if auth_props.pop("x509Credentials", None):
            logger.warning(REMOVED_CERT_REF_MSG)
    elif auth_mode == AEPAuthModes.anonymous.value and not any(
        [certificate_reference, username_reference, password_reference]
    ):
        auth_props["method"] = AEPAuthModes.anonymous.value
        # Anonymous auth carries no credentials; remove both kinds if present.
        if auth_props.pop("x509Credentials", None):
            logger.warning(REMOVED_CERT_REF_MSG)
        if auth_props.pop("usernamePasswordCredentials", None):
            logger.warning(REMOVED_USERPASS_REF_MSG)
    elif any([auth_mode, certificate_reference, username_reference, password_reference]):
        # Something was requested, but the mode and references do not agree
        # (e.g. certificate mode with a username reference, or anonymous with any reference).
        raise MutuallyExclusiveArgumentError(GENERAL_AUTH_REF_MISMATCH_ERROR)
    return auth_props
def _update_properties(
properties,
target_address: Optional[str] = None,
additional_configuration: Optional[str] = None,
auth_mode: Optional[str] = None,
username_reference: Optional[str] = None,
password_reference: Optional[str] = None,
certificate_reference: Optional[str] = None,
):
if additional_configuration:
properties["additionalConfiguration"] = additional_configuration
if target_address:
properties["targetAddress"] = target_address
if any([auth_mode, username_reference, password_reference, certificate_reference]):
auth_props = properties.get("authentication", {})
properties["authentication"] = _process_authentication(
auth_props=auth_props,
auth_mode=auth_mode,
certificate_reference=certificate_reference,
username_reference=username_reference,
password_reference=password_reference
)