azext_iot/common/utility.py (475 lines of code) (raw):
# coding=utf-8
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
"""
utility: Defines common utility functions and components.
"""
import ast
import base64
import isodate
import json
import os
import sys
import re
import hmac
import hashlib
from typing import Any, Optional, List, Dict
from threading import Event, Thread
from datetime import datetime
from knack.log import get_logger
from azure.cli.core.azclierror import (
CLIInternalError,
FileOperationError,
InvalidArgumentValueError,
)
logger = get_logger(__name__)
def parse_entity(entity, filter_none=False):
    """
    Function creates a dict of object attributes.
    Args:
        entity (object): object to extract attributes from.
        filter_none (bool): when True, attributes with falsy values are omitted.
    Returns:
        result (dict): a dictionary of attributes from the function input.
    """
    result = {}
    # Public attributes only; dunder/private names are implementation detail.
    for attribute in (a for a in dir(entity) if not a.startswith("_")):
        value = getattr(entity, attribute, None)
        if filter_none and not value:
            continue
        # callable() is the idiomatic check for methods/functions, replacing
        # the previous '"__call__" not in dir(value)' probe.
        if not callable(value):
            result[attribute] = value
    return result
def evaluate_literal(literal, expected):
    """
    Safely evaluate a code literal, returning it only when it matches the
    expected type(s).
    Args:
        literal: code literal to evaluate.
        expected (class, type, tuple): expected resulting class, type or
            tuple of types for the evaluation result.
    Returns:
        The evaluated value (string, number, tuple, list, dict, boolean)
        or None when evaluation fails or the type does not match.
    """
    try:
        # ast.literal_eval never executes arbitrary code - safe by design.
        parsed = ast.literal_eval(literal)
        return parsed if isinstance(parsed, expected) else None
    except Exception:
        return None
def verify_transform(subject, mapping):
    """
    Determines if each key from mapping exists in subject (via jmespath
    search) and, if found, verifies the value is of the mapped type.
    Raises:
        AttributeError: when a required property is missing.
        TypeError: when a property has an unexpected type.
    """
    import jmespath

    for key, expected_type in mapping.items():
        found = jmespath.search(key, subject)
        if found is None:
            raise AttributeError('The property "{}" is required'.format(key))
        if isinstance(found, expected_type):
            continue
        supplemental_info = ""
        if expected_type == dict:
            # Inline JSON is a common source of type mistakes - point at docs.
            wiki_link = "https://github.com/Azure/azure-iot-cli-extension/wiki/Tips"
            supplemental_info = "Review inline JSON examples here --> {}".format(
                wiki_link
            )
        raise TypeError(
            'The property "{}" must be of {} but is {}. Input: {}. {}'.format(
                key, str(expected_type), str(type(found)), found, supplemental_info
            )
        )
def validate_key_value_pairs(string):
    """
    Validate and parse key-value pairs in the format: a=b;c=d
    Args:
        string (str): semicolon delimited string of key/value pairs.
    Returns (dict, None): a dictionary of key value pairs, or None for
        empty input.
    """
    if not string:
        return None
    # Segments without '=' are silently dropped; only the first '=' splits.
    pairs = (segment.split("=", 1) for segment in string.split(";") if "=" in segment)
    return dict(pairs)
def process_json_arg(
    content: str, argument_name: str = "content", preserve_order=False
):
    """Primary processor of json input.

    Accepts either inline JSON/python-dict text or a path to a file
    containing it. Raises CLIInternalError when parsing fails.
    """
    json_from_file = None
    # When the argument resolves to an existing path, load the file contents.
    if os.path.exists(content):
        json_from_file = content
        content = read_file_content(content)
    try:
        return shell_safe_json_parse(content, preserve_order)
    except CLIInternalError as ex:
        if looks_like_file(content):
            # Most likely a mistyped file path rather than inline JSON.
            logger.warning(
                "The json payload for argument '%s' looks like its intended from a file. "
                "Please ensure the file path is correct.",
                argument_name,
            )
        file_hint = "from file: '{}' ".format(json_from_file) if json_from_file else ""
        raise CLIInternalError(
            "Failed to parse json {}for argument '{}' with exception:\n {}".format(
                file_hint, argument_name, ex
            )
        )
# Shared error-message templates for the file-based argument processors
# (YAML/TOML) below: {0} = file kind, {1} = path, {2} = exception detail.
_file_location_error = "{0} file not found - Please ensure the path '{1}' is correct."
_file_parse_error = "Failed to parse {0} file located at '{1}' with exception:\n{2}"
def process_yaml_arg(path: str) -> Dict[str, Any]:
    """Primary processor of yaml file input.

    Raises:
        FileOperationError: when the path does not exist.
        InvalidArgumentValueError: when the file cannot be parsed as YAML.
    """
    if not os.path.exists(path):
        raise FileOperationError(_file_location_error.format("YAML", path))
    try:
        # Import inside the try so an ImportError also surfaces as a parse error.
        import yaml

        with open(path, "rb") as stream:
            return yaml.load(stream, Loader=yaml.SafeLoader)
    except Exception as ex:
        raise InvalidArgumentValueError(_file_parse_error.format("YAML", path, ex))
def process_toml_arg(path: str) -> Dict[str, Any]:
    """Primary processor of TOML file input.

    Raises:
        FileOperationError: when the path does not exist.
        InvalidArgumentValueError: when the file cannot be parsed as TOML.
    """
    if not os.path.exists(path):
        raise FileOperationError(_file_location_error.format("TOML", path))
    try:
        # Import inside the try so an ImportError also surfaces as a parse error.
        import tomli

        with open(path, "rb") as stream:
            return tomli.load(stream)
    except Exception as ex:
        raise InvalidArgumentValueError(_file_parse_error.format("TOML", path, ex))
def shell_safe_json_parse(json_or_dict_string, preserve_order=False):
    """Allows the passing of JSON or Python dictionary strings. This is needed because certain
    JSON strings in CMD shell are not received in main's argv. This allows the user to specify
    the alternative notation, which does not have this problem (but is technically not JSON)."""
    try:
        if preserve_order:
            from collections import OrderedDict

            return json.loads(json_or_dict_string, object_pairs_hook=OrderedDict)
        return json.loads(json_or_dict_string)
    except ValueError as json_ex:
        # Fall back to python literal syntax (e.g. single-quoted dicts).
        try:
            return ast.literal_eval(json_or_dict_string)
        except SyntaxError:
            # Re-raise the json error: it is the more readable/likely cause.
            raise CLIInternalError(json_ex)
        except ValueError as ex:
            logger.debug(ex)  # could be a python dict parsing error
            raise CLIInternalError(json_ex)
def read_file_content(file_path, allow_binary=False):
    """Read a file as text, trying several unicode encodings; optionally
    fall back to returning base64-encoded binary content.

    Raises:
        FileOperationError: when no decoding succeeds.
    """
    from codecs import open as codecs_open

    # 'utf-8-sig' must come first so a BOM written on Windows is consumed.
    for encoding in ("utf-8-sig", "utf-8", "utf-16", "utf-16le", "utf-16be"):
        try:
            with codecs_open(file_path, encoding=encoding) as f:
                logger.debug("Attempting to read file %s as %s", file_path, encoding)
                return f.read()
        except (UnicodeError, UnicodeDecodeError):
            pass
    if allow_binary:
        try:
            with open(file_path, "rb") as input_file:
                logger.debug("Attempting to read file %s as binary", file_path)
                return base64.b64encode(input_file.read()).decode("utf-8")
        except Exception:  # pylint: disable=broad-except
            pass
    raise FileOperationError(
        "Failed to decode file {} - unknown decoding".format(file_path)
    )
def trim_from_start(s, substring):
    """Trims a substring from the target string (if it exists) returning the trimmed string.
    Otherwise returns original target string."""
    return s[len(substring):] if s.startswith(substring) else s
def validate_min_python_version(major, minor, error_msg=None, exit_on_fail=True):
    """If python version does not match AT LEAST requested values, will throw non 0 exit code."""
    version = sys.version_info
    # A strictly greater major always satisfies the requirement.
    if version.major > major:
        return True
    satisfied = version.major == major and version.minor >= minor
    if not satisfied and exit_on_fail:
        sys.exit(
            error_msg
            or "Python version {}.{} or higher required for this functionality.".format(
                major, minor
            )
        )
    return satisfied
def unicode_binary_map(target):
    """Decode binary keys and values of a map to unicode (str); other
    values pass through unchanged."""
    def _decode(value):
        return str(value, "utf8") if isinstance(value, bytes) else value

    return {_decode(key): _decode(target[key]) for key in target}
def execute_onthread(**kwargs):
    """
    Experimental generic helper for executing methods without return values on a background thread

    Args:
        kwargs: Supported kwargs are 'interval' (int) to specify intervals between calls
                'method' (func) to specify method pointer for execution
                'args' (dict) to specify method keyword arguments
                'max_runs' (int) indicate an upper bound on number of executions
                'return_handle' (bool) indicates whether to return a Thread handle

    Returns:
        Event(): Object to set the event cancellation flag
        or if 'return_handle'=True
        Event(), Thread(): Event object to set the cancellation flag, Executing Thread object

    Raises:
        ValueError: when 'method' is not supplied.
    """
    interval = kwargs.get("interval") or 2
    method = kwargs.get("method")
    # BUGFIX: default must be a dict - the method call below expands it with
    # '**'; the previous list default raised TypeError when 'args' was omitted.
    method_args = kwargs.get("args") or {}
    max_runs = kwargs.get("max_runs")
    handle = kwargs.get("return_handle")
    if not method:
        raise ValueError('kwarg "method" required for execution')
    cancellation_token = Event()

    def method_wrap(max_runs=None):
        runs = 0
        # wait() doubles as both the sleep between calls and the cancel check.
        while not cancellation_token.wait(interval):
            if max_runs and runs >= max_runs:
                break
            method(**method_args)
            runs += 1

    op = Thread(target=method_wrap, args=(max_runs,))
    op.start()
    if handle:
        return cancellation_token, op
    return cancellation_token
def url_encode_dict(d):
    """URL-encode a dict into a query string ('key=value' pairs joined by '&').

    The Python 2 compatibility import fallback was removed; this extension
    runs on Python 3 only, where urlencode lives in urllib.parse.
    """
    from urllib.parse import urlencode

    return urlencode(d)
def url_encode_str(s, plus=False):
    """Percent-encode a string.

    Args:
        s (str): value to encode.
        plus (bool): when True use quote_plus (spaces become '+'),
            otherwise quote (spaces become '%20').

    The Python 2 compatibility import fallback was removed; this extension
    runs on Python 3 only, where both helpers live in urllib.parse.
    """
    from urllib.parse import quote, quote_plus

    return quote_plus(s) if plus else quote(s)
def test_import_and_version(package, expected_version):
"""
Used to determine if a package dependency, installed with metadata,
is at least the expected version. This utility will not work for packages
that are installed without metadata.
"""
from importlib import metadata
from packaging.version import parse
try:
return parse(metadata.version(package)) >= parse(expected_version)
except metadata.PackageNotFoundError:
return False
def unpack_pnp_http_error(e):
    """Extract the inner 'error' payload from a PnP HTTP error, dropping
    any stack trace; falls back to the raw unpacked error otherwise."""
    error = unpack_msrest_error(e)
    if not isinstance(error, dict):
        return error
    inner = error.get("error")
    if inner:
        error = inner
        # Stack traces are service-internal noise for CLI users.
        if error.get("stackTrace"):
            error.pop("stackTrace")
    return error
def unpack_msrest_error(e):
    """Obtains full response text from an msrest error; returns parsed JSON
    when possible, the raw text otherwise, or str(e) when empty."""
    from typing import Callable

    response_text = ""
    try:
        raw = e.response.text
        # Some client versions expose text as a callable rather than a str.
        response_text = raw() if isinstance(raw, Callable) else raw
        unpacked = json.loads(response_text)
    except (ValueError, TypeError):
        unpacked = response_text
    if not unpacked:
        return str(e)
    return unpacked
def handle_service_exception(e):
    """
    Used to unpack service error messages and status codes,
    and raise the correct azclierror class.
    For more info on CLI error handling guidelines, see
    https://github.com/Azure/azure-cli/blob/dev/doc/error_handling_guidelines.md
    """
    from azure.cli.core.azclierror import (
        AzureInternalError,
        AzureResponseError,
        BadRequestError,
        ForbiddenError,
        ResourceNotFoundError,
        UnauthorizedError,
    )

    err = unpack_msrest_error(e)
    op_status = getattr(e.response, "status_code", -1)
    # Generic error if the status_code is explicitly None (or 0)
    if not op_status:
        raise AzureResponseError(err)
    error_by_status = {
        400: BadRequestError,
        401: UnauthorizedError,
        403: ForbiddenError,
        404: ResourceNotFoundError,
    }
    if op_status in error_by_status:
        raise error_by_status[op_status](err)
    # Any 5xx error should throw an AzureInternalError
    if 500 <= op_status < 600:
        raise AzureInternalError(err)
    # Otherwise, fail with generic service error
    raise AzureResponseError(err)
def dict_transform_lower_case_key(d):
    """Converts a dictionary to an identical one with all lower case keys"""
    result = {}
    for key, value in d.items():
        result[key.lower()] = value
    return result
def calculate_millisec_since_unix_epoch_utc(offset_seconds: int = 0):
    """Return current Unix epoch time in milliseconds (UTC), optionally
    shifted forward by offset_seconds.

    datetime.utcnow()/utcfromtimestamp() are deprecated (Python 3.12);
    the timestamp() of an aware UTC datetime yields the same epoch seconds.
    """
    from datetime import timezone

    now = datetime.now(timezone.utc)
    return int(1000 * (now.timestamp() + offset_seconds))
def init_monitoring(
    cmd,
    timeout,
    properties,
    enqueued_time,
    repair,
    yes,
    message_count: Optional[int] = None,
):
    """Validate and normalize event-monitoring options and ensure the
    uamqp dependency is installed.

    Returns:
        tuple: (enqueued_time, properties, timeout, output, message_count)
    """
    from azext_iot.common.deps import ensure_uamqp

    if timeout < 0:
        raise InvalidArgumentValueError(
            "Monitoring timeout must be 0 (inf) or greater."
        )
    if message_count and message_count <= 0:
        raise InvalidArgumentValueError("Message count must be greater than 0.")
    # Callers provide seconds; downstream consumers expect milliseconds.
    timeout_ms = timeout * 1000
    config = cmd.cli_ctx.config
    output = cmd.cli_ctx.invocation.data.get("output", None) or "json"
    ensure_uamqp(config, yes, repair)
    normalized_properties = {key.lower() for key in (properties or [])}
    if not enqueued_time:
        enqueued_time = calculate_millisec_since_unix_epoch_utc()
    return (enqueued_time, normalized_properties, timeout_ms, output, message_count)
def dict_clean(d):
    """Recursively remove None-valued entries from a dictionary; non-dict
    inputs are returned unchanged."""
    if not isinstance(d, dict):
        return d
    return {key: dict_clean(value) for key, value in d.items() if value is not None}
def looks_like_file(element):
    """Heuristic: True when the string ends with a common text/code file
    extension (case-insensitive)."""
    known_suffixes = (
        ".log",
        ".rtf",
        ".txt",
        ".json",
        ".yaml",
        ".yml",
        ".md",
        ".rst",
        ".doc",
        ".docx",
        ".html",
        ".htm",
        ".py",
        ".java",
        ".ts",
        ".js",
        ".cs",
    )
    return element.lower().endswith(known_suffixes)
class ISO8601Validator:
    """Validators that report whether a string parses as an ISO 8601
    date, datetime, duration, or time."""

    @staticmethod
    def _parses(parser, to_validate) -> bool:
        # Any parse failure (or non-string input) means the value is invalid.
        try:
            return bool(parser(to_validate))
        except Exception:
            return False

    def is_iso8601_date(self, to_validate) -> bool:
        """True when to_validate is a valid ISO 8601 date."""
        return self._parses(isodate.parse_date, to_validate)

    def is_iso8601_datetime(self, to_validate: str) -> bool:
        """True when to_validate is a valid ISO 8601 datetime."""
        return self._parses(isodate.parse_datetime, to_validate)

    def is_iso8601_duration(self, to_validate: str) -> bool:
        """True when to_validate is a valid ISO 8601 duration."""
        return self._parses(isodate.parse_duration, to_validate)

    def is_iso8601_time(self, to_validate: str) -> bool:
        """True when to_validate is a valid ISO 8601 time."""
        return self._parses(isodate.parse_time, to_validate)
def ensure_iothub_sdk_min_version(min_ver):
    """True when the installed azure-mgmt-iothub SDK is at least min_ver."""
    try:
        from azure.mgmt.iothub import __version__ as sdk_version
    except ImportError:
        # Older SDK layouts expose VERSION via the _version module instead.
        from azure.mgmt.iothub._version import VERSION as sdk_version
    from packaging import version

    return version.parse(sdk_version) >= version.parse(min_ver)
def ensure_iotdps_sdk_min_version(min_ver):
    """True when the installed azure-mgmt-iothubprovisioningservices SDK is
    at least min_ver."""
    try:
        from azure.mgmt.iothubprovisioningservices import __version__ as sdk_version
    except ImportError:
        # Older SDK layouts expose VERSION via the _version module instead.
        from azure.mgmt.iothubprovisioningservices._version import (
            VERSION as sdk_version,
        )
    from packaging import version

    return version.parse(sdk_version) >= version.parse(min_ver)
def scantree(path):
    """Recursively yield os.DirEntry objects for every file under path
    (directory symlinks are not followed)."""
    for entry in os.scandir(path):
        if not entry.is_dir(follow_symlinks=False):
            yield entry
        else:
            yield from scantree(entry.path)
def find_between(s, start, end):
    """Return the text between the first 'start' marker and the next 'end'
    marker (raises IndexError when 'start' is absent)."""
    # split-based on purpose: [1] is bounded by the *second* 'start'
    # occurrence as well, which str.partition would not replicate.
    after_start = s.split(start)[1]
    return after_start.split(end)[0]
def valid_hostname(host_name):
    """
    Approximate validation
    Reference: https://en.wikipedia.org/wiki/Hostname
    """
    # RFC limit: whole name <= 253 chars.
    if len(host_name) > 253:
        return False
    # Each dot-separated label: 1-63 alphanumeric/hyphen chars,
    # no leading or trailing hyphen.
    label_pattern = re.compile(r"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE)
    for label in host_name.split("."):
        if not label_pattern.match(label):
            return False
    return True
def compute_device_key(primary_key, registration_id):
    """
    Compute device SAS key
    Args:
        primary_key: Primary group SAS token (base64) to compute device keys
        registration_id: Registration ID is alphanumeric, lowercase, and may contain hyphens.
    Returns:
        device key (bytes, base64-encoded HMAC-SHA256 digest)
    """
    decoded_secret = base64.b64decode(primary_key)
    digest = hmac.new(
        decoded_secret,
        msg=registration_id.encode("utf8"),
        digestmod=hashlib.sha256,
    ).digest()
    return base64.b64encode(digest)
def generate_key(byte_length=32):
    """
    Generate a cryptographically secure device key: byte_length random
    bytes, base64-encoded and returned as str.
    """
    import secrets

    return base64.b64encode(secrets.token_bytes(byte_length)).decode("utf8")
def ensure_azure_namespace_path():
    """
    Run prior to importing azure namespace packages (azure.*) to ensure the
    extension root path is configured for package import.

    Side effects: may mutate azure.__path__ and sys.path so the extension's
    bundled 'azure' namespace directory is resolvable.
    """
    from azure.cli.core.extension import get_extension_path
    from azext_iot.constants import EXTENSION_NAME
    ext_path = get_extension_path(EXTENSION_NAME)
    if not ext_path:
        # Extension not installed in a discoverable location; nothing to do.
        return
    ext_azure_dir = os.path.join(ext_path, "azure")
    if os.path.isdir(ext_azure_dir):
        import azure
        # Add the extension's azure dir to the namespace package search path
        # (only when not already present).
        if (
            getattr(azure, "__path__", None) and ext_azure_dir not in azure.__path__
        ):  # _NamespacePath /w PEP420
            if isinstance(azure.__path__, list):
                # Plain list path: prepend so the extension dir wins lookup.
                azure.__path__.insert(0, ext_azure_dir)
            else:
                # PEP 420 _NamespacePath - presumably supports append only;
                # NOTE(review): insert() availability not guaranteed here.
                azure.__path__.append(ext_azure_dir)
    # Give the extension root top priority for module resolution.
    if sys.path and sys.path[0] != ext_path:
        sys.path.insert(0, ext_path)
    return
def is_valid_dtmi(dtmi):
    """Checks validity of a DTMI
    :param str dtmi: DTMI
    :returns: Boolean indicating if DTMI is valid
    :rtype: bool
    """
    dtmi_pattern = re.compile(
        "^dtmi:[A-Za-z](?:[A-Za-z0-9_]*[A-Za-z0-9])?(?::[A-Za-z](?:[A-Za-z0-9_]*[A-Za-z0-9])?)*;[1-9][0-9]{0,8}$"
    )
    return bool(dtmi_pattern.match(dtmi))
def generate_storage_account_sas_token(
    storage_cstring: str,
    expiry_in_hours: int = 1,
    read: bool = False,
    write: bool = False,
    create: bool = False,
    update: bool = False,
    add: bool = False,
    list: bool = False,  # NOTE: shadows builtin 'list'; kept for caller compatibility
    delete: bool = False,
):
    """
    Generate an account-level SAS token for the storage account identified
    by the given connection string, scoped to object-level resources.

    Args:
        storage_cstring: storage account connection string.
        expiry_in_hours: token validity window starting now (UTC).
        read/write/create/update/add/list/delete: permission flags to grant.

    Returns:
        The SAS token as returned by azure.storage.blob.generate_account_sas.
    """
    from datetime import datetime, timedelta
    ensure_azure_namespace_path()
    from azure.storage.blob import ResourceTypes, AccountSasPermissions, generate_account_sas, BlobServiceClient
    blob_service_client = BlobServiceClient.from_connection_string(conn_str=storage_cstring)
    sas_token = generate_account_sas(
        blob_service_client.account_name,
        account_key=blob_service_client.credential.account_key,
        # object=True: grant access at the blob (object) level only.
        resource_types=ResourceTypes(object=True),
        permission=AccountSasPermissions(
            read=read, write=write, create=create, update=update, add=add, list=list, delete=delete
        ),
        expiry=datetime.utcnow() + timedelta(hours=expiry_in_hours)
    )
    return sas_token
def assemble_nargs_to_dict(hash_list: List[str]) -> Dict[str, str]:
    """Convert a list of 'key=value' strings into a dict.

    Entries without '=' are skipped with a warning; keys left with empty
    values are also warned about. A falsy input yields an empty dict.
    """
    result = {}
    if not hash_list:
        return result
    for entry in hash_list:
        if "=" not in entry:
            logger.warning(
                "Skipping processing of '%s', input format is key=value | key='value value'.",
                entry,
            )
            continue
        # Split only on the first '=' so values may themselves contain '='.
        key, value = entry.split("=", 1)
        result[key] = value
    for key in result:
        if not result.get(key):
            logger.warning(
                "No value assigned to key '%s', input format is key=value | key='value value'.",
                key,
            )
    return result