azext_iot/common/base_discovery.py (180 lines of code) (raw):
# coding=utf-8
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from abc import ABC, abstractmethod
from azure.cli.core.azclierror import ResourceNotFoundError
from azure.core.exceptions import HttpResponseError
from knack.log import get_logger
from azext_iot.common.shared import AuthenticationTypeDataplane
from typing import Any, Dict, List
from types import SimpleNamespace
from azext_iot.common.utility import valid_hostname
# Module-level logger, named after this module via ``__name__``.
logger = get_logger(__name__)
# Message used when no policy granting the required rights can be found.
# Placeholders: {0} resource type, {1} resource name, {2} subscription id,
# {3} human-readable list of the required rights.
POLICY_ERROR_TEMPLATE = (
    "Unable to discover a privileged policy for {0}: {1}, in subscription {2}. "
    "When interfacing with an {0}, the IoT extension requires any single policy with "
    "{3} rights."
)
def _format_policy_set(inputs: set) -> str:
inputs = list(f"'{x}'" for x in inputs)
if len(inputs) == 1:
return inputs[0]
elif len(inputs) == 2:
return inputs[0] + " and " + inputs[1]
inputs[-1] = "and " + inputs[-1]
return ", ".join(inputs)
# Abstract base class
class BaseDiscovery(ABC):
    """BaseDiscovery to support resource and policy auto discovery.

    Eliminates the need to provide the resource group and policy name to
    find a specific target resource.

    :ivar cmd: The cmd object
    :vartype cmd:
    :ivar client: The client object; ``None`` until ``_initialize_client``
        creates it.
    :vartype client:
    :ivar sub_id: Subscription id; starts as the placeholder ``"unknown"``.
    :vartype sub_id: str
    :ivar resource_type: Type of the resources the client fetches. Used to abstract
        error messages.
    :vartype resource_type: DiscoveryResourceType
    :ivar necessary_rights_set: Set of rights a single policy must grant for the
        IoT extension to use it against the target resource.
    :vartype necessary_rights_set: Set[str]
    """

    def __init__(self, cmd, necessary_rights_set: set = None, resource_type: str = None):
        self.cmd = cmd
        self.client = None
        self.sub_id = "unknown"
        self.resource_type = resource_type
        self.necessary_rights_set = necessary_rights_set

    @abstractmethod
    def _initialize_client(self):
        """Creates the client if not created already."""
        pass

    @abstractmethod
    def _make_kwargs(self, **kwargs) -> Dict[str, Any]:
        """Returns the correct kwargs for the client operations."""
        pass

    def get_resources(self, rg: str = None) -> List:
        """
        Returns a list of all raw resources that are present within the subscription (and
        resource group if provided).

        The resources are the raw data returned by the client and will be used to build
        target objects.

        :param rg: Resource Group
        :type rg: str

        :return: List of resources
        :rtype: List
        """
        self._initialize_client()

        if rg:
            resource_pager = self.client.list_by_resource_group(resource_group_name=rg)
        else:
            resource_pager = self.client.list_by_subscription()

        # Flatten every page of the pager into a single list.
        resource_list = []
        for page in resource_pager.by_page():
            resource_list.extend(page)
        return resource_list

    def get_policies(self, resource_name: str, rg: str) -> List:
        """
        Returns a list of all policies for a given resource in a given resource group.

        :param resource_name: Resource Name
        :type resource_name: str
        :param rg: Resource Group
        :type rg: str

        :return: List of policies
        :rtype: List
        """
        self._initialize_client()

        policy_pager = self.client.list_keys(
            **self._make_kwargs(resource_name=resource_name, resource_group_name=rg)
        )

        # Flatten every page of the pager into a single list.
        policy_list = []
        for page in policy_pager.by_page():
            policy_list.extend(page)
        return policy_list

    def find_resource(self, resource_name: str, rg: str = None):
        """
        Returns the resource with the given resource_name.

        If the resource group is not provided, will look through all resources within the
        subscription and return first match (names are compared case-insensitively).
        This functionality will only work for resource types that require unique names
        within the subscription.

        Raises ResourceNotFoundError if no resource is found.

        :param resource_name: Resource Name
        :type resource_name: str
        :param rg: Resource Group
        :type rg: str

        :return: Resource
        :rtype: dict representing self.resource_type
        """
        self._initialize_client()

        if rg:
            try:
                return self.client.get(
                    **self._make_kwargs(
                        resource_name=resource_name, resource_group_name=rg
                    )
                )
            except Exception as e:
                # Any client failure is surfaced as "not found"; the original error
                # is chained so the underlying cause stays visible when debugging.
                raise ResourceNotFoundError(
                    "Unable to find {}: {} in resource group: {}".format(
                        self.resource_type, resource_name, rg
                    )
                ) from e

        wanted_name = resource_name.lower()
        for resource in self.get_resources():
            if resource.name.lower() == wanted_name:
                return resource

        raise ResourceNotFoundError(
            "Unable to find {}: {} in current subscription {}.".format(
                self.resource_type, resource_name, self.sub_id
            )
        )

    def find_policy(self, resource_name: str, rg: str, policy_name: str = "auto"):
        """
        Returns the policy with the policy_name for the given resource.

        If the policy name is not provided (i.e. left as "auto"), will look through all
        policies for the given resource and return the first usable policy (the first
        policy whose rights include every entry of ``necessary_rights_set``).

        Raises ResourceNotFoundError if no usable policy is found.

        :param resource_name: Resource Name
        :type resource_name: str
        :param rg: Resource Group
        :type rg: str
        :param policy_name: Policy Name
        :type policy_name: str

        :return: Policy
        :rtype: policy
        """
        self._initialize_client()

        if policy_name.lower() != "auto":
            return self.client.get_keys_for_key_name(
                **self._make_kwargs(
                    resource_name=resource_name,
                    resource_group_name=rg,
                    key_name=policy_name
                )
            )

        for policy in self.get_policies(resource_name=resource_name, rg=rg):
            # Rights arrive as one comma separated string; strip each entry so the
            # match does not depend on the exact whitespace after the commas.
            rights_set = {right.strip() for right in policy.rights.split(",")}
            if self.necessary_rights_set.issubset(rights_set):
                logger.info(
                    "Using policy '%s' for %s interaction.", policy.key_name, self.resource_type
                )
                return policy

        raise ResourceNotFoundError(
            POLICY_ERROR_TEMPLATE.format(
                self.resource_type,
                resource_name,
                self.sub_id,
                _format_policy_set(self.necessary_rights_set)
            )
        )

    @classmethod
    @abstractmethod
    def get_target_by_cstring(cls, connection_string):
        """Returns target information needed from a connection string."""
        pass

    def get_target(
        self, resource_name: str, resource_group_name: str = None, **kwargs
    ) -> Dict[str, str]:
        """
        Returns a dictionary of the given resource's connection string parts to be used
        by the extension.

        This function finds the target resource and builds up a dictionary of connection
        string parts needed for IoT extension operation. In future iteration we will
        return a 'Target' object rather than dict but that will be better served aligning
        with vNext pattern for Iot Hub/DPS.

        If the resource group is not provided, will look through all resources within the
        subscription and return first match. This functionality will only work for
        resource types that require unique names within the subscription.

        If the policy name is not provided, will look through all policies for the given
        resource and return the first usable policy (the first policy that the IoT
        extension can use).

        Raises ResourceNotFoundError if no resource is found.

        :param resource_name: Resource Name. A leading "https://" or "http://" is
            stripped, so a hostname or URL may be given.
        :type resource_name: str
        :param resource_group_name: Resource Group. Falls back to the ``rg`` keyword
            when not provided.
        :type resource_group_name: str

        :keyword str login: Connection string for the target resource
        :keyword str rg: Resource Group, used when resource_group_name is not set
        :keyword str key_type: Key type to use in connection string construction
        :keyword auth_type: Authentication Type for the Dataplane
        :paramtype auth_type: AuthenticationTypeDataplane
        :keyword str policy_name: Policy name to use
        :keyword force_find_resource: When truthy, the AAD path always looks the
            resource up instead of building the target straight from a hostname.

        :return: Resource
        :rtype: dict representing self.resource_type
        """
        cstring = kwargs.get("login")
        if cstring:
            return self.get_target_by_cstring(connection_string=cstring)

        resource_group_name = resource_group_name or kwargs.get("rg")

        # Accept a full URL by dropping the scheme (matched case-insensitively).
        for prefix in ("https://", "http://"):
            if resource_name.lower().startswith(prefix):
                resource_name = resource_name[len(prefix):]
                break

        # key_type is always handed to _build_target explicitly, so it must not be
        # forwarded a second time through **kwargs (that raises TypeError for
        # duplicate keyword arguments).
        forwarded_kwargs = {
            name: value for name, value in kwargs.items() if name != "key_type"
        }

        auth_type = kwargs.get("auth_type", AuthenticationTypeDataplane.key.value)
        if auth_type == AuthenticationTypeDataplane.login.value:
            logger.info("Using AAD access token for %s interaction.", self.resource_type)
            if all([not kwargs.get("force_find_resource"), valid_hostname(resource_name), "." in resource_name]):
                return self._build_target_from_hostname(
                    resource_hostname=resource_name
                )

            resource = self.find_resource(resource_name=resource_name, rg=resource_group_name)

            # No real policy is looked up on this path; the login marker is used as a
            # placeholder for the policy name and both keys.
            policy = SimpleNamespace()
            policy.key_name = AuthenticationTypeDataplane.login.value
            policy.primary_key = AuthenticationTypeDataplane.login.value
            policy.secondary_key = AuthenticationTypeDataplane.login.value

            return self._build_target(
                resource=resource,
                policy=policy,
                key_type="primary",
                **forwarded_kwargs
            )

        # A hostname was given: keep only its first label as the resource name.
        if "." in resource_name:
            resource_name = resource_name.split(".")[0]
        resource = self.find_resource(resource_name=resource_name, rg=resource_group_name)

        key_type = kwargs.get("key_type", "primary")
        policy_name = kwargs.get("policy_name", "auto")
        rg = resource.additional_properties.get("resourcegroup")

        resource_policy = self.find_policy(
            resource_name=resource.name, rg=rg, policy_name=policy_name,
        )

        return self._build_target(
            resource=resource,
            policy=resource_policy,
            key_type=key_type,
            **forwarded_kwargs
        )

    def get_targets(self, resource_group_name: str = None, **kwargs) -> List[Dict[str, str]]:
        """
        Returns a list of targets (dicts representing a resource's connection string parts)
        that are usable by the extension within the subscription (and resource group if
        provided).

        Resources whose target cannot be built (HttpResponseError or
        ResourceNotFoundError) are skipped with a warning.

        :param resource_group_name: Resource Group
        :type resource_group_name: str

        :return: Resources
        :rtype: list[dict]
        """
        targets = []
        resources = self.get_resources(rg=resource_group_name)
        if resources:
            for resource in resources:
                try:
                    targets.append(
                        self.get_target(
                            resource_name=resource.name,
                            resource_group_name=resource.additional_properties.get("resourcegroup"),
                            **kwargs
                        )
                    )
                except (HttpResponseError, ResourceNotFoundError) as e:
                    logger.warning("Could not access %s. %s", resource.name, e)

        return targets

    @abstractmethod
    def _build_target_from_hostname(self, resource_hostname):
        """Returns target information needed from a hostname."""
        pass

    @abstractmethod
    def _build_target(self, resource, policy, key_type=None, **kwargs):
        """Returns a dictionary representing the resource connection string parts to
        be used by the IoT extension."""
        pass