azext_iot/common/utility.py (475 lines of code) (raw):

# coding=utf-8
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

"""
utility: Defines common utility functions and components.

"""

import ast
import base64
import isodate
import json
import os
import sys
import re
import hmac
import hashlib

from typing import Any, Optional, List, Dict
from threading import Event, Thread
from datetime import datetime
from knack.log import get_logger
from azure.cli.core.azclierror import (
    CLIInternalError,
    FileOperationError,
    InvalidArgumentValueError,
)

logger = get_logger(__name__)


def parse_entity(entity, filter_none=False):
    """
    Function creates a dict of object attributes.

    Only public attributes (names not starting with "_") are considered, and an
    attribute is skipped when "__call__" appears in dir() of its value (i.e. methods
    and other callable instances are excluded).

    Args:
        entity (object): object to extract attributes from.
        filter_none (bool): when True, skip attributes whose value is falsy
            (None, but also 0, False, "" and empty containers).

    Returns:
        result (dict): a dictionary of attributes from the function input.
    """
    result = {}
    attributes = [attr for attr in dir(entity) if not attr.startswith("_")]
    for attribute in attributes:
        value = getattr(entity, attribute, None)
        if filter_none and not value:
            continue
        # dir()-based check (rather than callable()) is kept deliberately: for class
        # objects dir() does not list the metaclass __call__, so classes are retained.
        if "__call__" not in dir(value):
            result[attribute] = value
    return result


def evaluate_literal(literal, expected):
    """
    Function to provide safe evaluation of code literal.

    Args:
        literal (): code literal
        expected (class, type, tuple): expected resulting class, type or tuple of literal evaluation.

    Returns:
        result (string, number, tuple, list, dict, boolean, None): the evaluated
        literal, or None when evaluation fails for any reason or the result is not
        an instance of `expected`.
    """
    # Safe evaluation: ast.literal_eval only accepts Python literal syntax.
    try:
        result = ast.literal_eval(literal)
        if not isinstance(result, expected):
            return None
        return result
    except Exception:
        return None


def verify_transform(subject, mapping):
    """
    Determines if a key from mapping exists in subject and if so
    verifies that subject[k] is of type mapping[k]

    Each key of `mapping` is evaluated as a JMESPath expression against `subject`.

    Raises:
        AttributeError: when an expression resolves to None (missing property).
        TypeError: when the resolved value is not an instance of mapping[k].
    """
    import jmespath

    for k in mapping.keys():
        result = jmespath.search(k, subject)

        if result is None:
            raise AttributeError('The property "{}" is required'.format(k))
        if not isinstance(result, mapping[k]):
            supplemental_info = ""
            if mapping[k] == dict:
                wiki_link = "https://github.com/Azure/azure-iot-cli-extension/wiki/Tips"
                supplemental_info = "Review inline JSON examples here --> {}".format(
                    wiki_link
                )

            raise TypeError(
                'The property "{}" must be of {} but is {}. Input: {}. {}'.format(
                    k, str(mapping[k]), str(type(result)), result, supplemental_info
                )
            )


def validate_key_value_pairs(string):
    """
    Function to validate key-value pairs in the format: a=b;c=d

    Segments without "=" are ignored; each remaining segment is split on its
    first "=", so values may themselves contain "=".

    Args:
        string (str): semicolon delimited string of key/value pairs.

    Returns (dict, None): a dictionary of key value pairs, or None for empty input.
    """
    result = None
    if string:
        kv_list = [x for x in string.split(";") if "=" in x]  # key-value pairs
        result = dict(x.split("=", 1) for x in kv_list)
    return result


def process_json_arg(
    content: str, argument_name: str = "content", preserve_order=False
):
    """
    Primary processor of json input

    If `content` is an existing file path, the file is read first. The text is then
    parsed by shell_safe_json_parse (JSON, falling back to a Python literal).

    Raises:
        CLIInternalError: when the content cannot be parsed; the message names the
            argument and, when applicable, the source file.
    """

    json_from_file = None

    if os.path.exists(content):
        json_from_file = content
        content = read_file_content(content)

    try:
        return shell_safe_json_parse(content, preserve_order)
    except CLIInternalError as ex:
        # A value ending in a file extension that failed to parse was most likely
        # meant to be a path that does not exist.
        if looks_like_file(content):
            logger.warning(
                "The json payload for argument '%s' looks like its intended from a file. "
                "Please ensure the file path is correct.",
                argument_name,
            )

        file_content_error = "from file: '{}' ".format(json_from_file)
        raise CLIInternalError(
            "Failed to parse json {}for argument '{}' with exception:\n {}".format(
                file_content_error if json_from_file else "", argument_name, ex
            )
        ) from ex


_file_location_error = "{0} file not found - Please ensure the path '{1}' is correct."
_file_parse_error = "Failed to parse {0} file located at '{1}' with exception:\n{2}"


def process_yaml_arg(path: str) -> Dict[str, Any]:
    """
    Primary processor of yaml file input

    Raises:
        FileOperationError: when `path` does not exist.
        InvalidArgumentValueError: when the file cannot be opened or parsed.
    """

    if not os.path.exists(path):
        raise FileOperationError(
            _file_location_error.format("YAML", path)
        )

    try:
        import yaml

        with open(path, "rb") as f:
            # SafeLoader: only plain YAML types are constructed, no arbitrary objects.
            return yaml.load(f, Loader=yaml.SafeLoader)
    except Exception as ex:
        raise InvalidArgumentValueError(
            _file_parse_error.format("YAML", path, ex)
        ) from ex


def process_toml_arg(path: str) -> Dict[str, Any]:
    """
    Primary processor of TOML file input

    Raises:
        FileOperationError: when `path` does not exist.
        InvalidArgumentValueError: when the file cannot be opened or parsed.
    """

    if not os.path.exists(path):
        raise FileOperationError(
            _file_location_error.format("TOML", path)
        )

    try:
        import tomli

        # tomli.load requires a binary file handle.
        with open(path, "rb") as f:
            return tomli.load(f)
    except Exception as ex:
        raise InvalidArgumentValueError(
            _file_parse_error.format("TOML", path, ex)
        ) from ex


def shell_safe_json_parse(json_or_dict_string, preserve_order=False):
    """
    Allows the passing of JSON or Python dictionary strings. This is needed because certain
    JSON strings in CMD shell are not received in main's argv. This allows the user to specify
    the alternative notation, which does not have this problem (but is technically not JSON).

    Args:
        json_or_dict_string (str): JSON text or a Python literal.
        preserve_order (bool): when True, JSON objects are parsed into OrderedDicts.

    Raises:
        CLIInternalError: wrapping the original JSON error when neither JSON nor
            Python-literal parsing succeeds.
    """
    try:
        if not preserve_order:
            return json.loads(json_or_dict_string)
        from collections import OrderedDict

        return json.loads(json_or_dict_string, object_pairs_hook=OrderedDict)
    except ValueError as json_ex:
        try:
            return ast.literal_eval(json_or_dict_string)
        except SyntaxError:
            raise CLIInternalError(json_ex)
        except ValueError as ex:
            # log the exception which could be a python dict parsing error.
            logger.debug(ex)
            # raise json_ex error which is more readable and likely.
            raise CLIInternalError(json_ex)


def read_file_content(file_path, allow_binary=False):
    """
    Read a file as text, trying several encodings in order.

    Args:
        file_path (str): path of the file to read.
        allow_binary (bool): when no text encoding works, return the raw bytes
            base64-encoded (as a str) instead of failing.

    Returns:
        str: the decoded file text (or base64 text when allow_binary applies).

    Raises:
        FileOperationError: when the content cannot be decoded. Errors raised while
            opening the file (e.g. a missing path) are not caught by the decode loop.
    """
    from codecs import open as codecs_open

    # Note, always put 'utf-8-sig' first, so that BOM in WinOS won't cause trouble.
    for encoding in ["utf-8-sig", "utf-8", "utf-16", "utf-16le", "utf-16be"]:
        try:
            with codecs_open(file_path, encoding=encoding) as f:
                logger.debug("Attempting to read file %s as %s", file_path, encoding)
                return f.read()
        except (UnicodeError, UnicodeDecodeError):
            pass

    if allow_binary:
        try:
            with open(file_path, "rb") as input_file:
                logger.debug("Attempting to read file %s as binary", file_path)
                return base64.b64encode(input_file.read()).decode("utf-8")
        except Exception:  # pylint: disable=broad-except
            pass
    raise FileOperationError(
        "Failed to decode file {} - unknown decoding".format(file_path)
    )


def trim_from_start(s, substring):
    """Trims a substring from the target string (if it exists) returning the trimmed string.
    Otherwise returns original target string."""
    if s.startswith(substring):
        s = s[len(substring) :]
    return s


def validate_min_python_version(major, minor, error_msg=None, exit_on_fail=True):
    """
    If python version does not match AT LEAST requested values, will throw non 0 exit code.

    Args:
        major (int), minor (int): minimum required interpreter version.
        error_msg (str): message passed to sys.exit; a default message is used if omitted.
        exit_on_fail (bool): when False, return False instead of exiting.

    Returns:
        bool: True when the running interpreter is at least major.minor.
    """
    version = sys.version_info
    result = False
    if version.major > major:
        return True
    if major == version.major:
        result = version.minor >= minor

    if not result:
        if exit_on_fail:
            msg = (
                error_msg
                if error_msg
                else "Python version {}.{} or higher required for this functionality.".format(
                    major, minor
                )
            )
            sys.exit(msg)

    return result


def unicode_binary_map(target):
    """Decode binary keys and values of map to unicode (UTF-8); other entries are copied as-is."""
    # Assumes no iteritems()
    result = {}

    for k in target:
        key = k
        if isinstance(k, bytes):
            key = str(k, "utf8")
        if isinstance(target[k], bytes):
            result[key] = str(target[k], "utf8")
        else:
            result[key] = target[k]

    return result


def execute_onthread(**kwargs):
    """
    Experimental generic helper for executing methods without return values on a background thread

    Args:
        kwargs: Supported kwargs are 'interval' (int) to specify intervals (seconds) between calls
                'method' (func) to specify method pointer for execution
                'args' (dict) to specify method keyword arguments
                'max_runs' (int) indicate an upper bound on number of executions
                'return_handle' (bool) indicates whether to return a Thread handle

    Returns:
        Event(): Object to set the event cancellation flag
        or if 'return_handle'=True
        Event(), Thread(): Event object to set the cancellation flag, Executing Thread object

    Raises:
        ValueError: when 'method' is not supplied.
    """
    interval = kwargs.get("interval") or 2
    method = kwargs.get("method")
    # 'args' is splatted as keyword arguments (method(**method_args)), so the empty
    # default must be a mapping; an empty list made every run raise TypeError.
    method_args = kwargs.get("args") or {}
    max_runs = kwargs.get("max_runs")
    handle = kwargs.get("return_handle")

    if not method:
        raise ValueError('kwarg "method" required for execution')

    cancellation_token = Event()

    def method_wrap(max_runs=None):
        runs = 0
        # Event.wait returns True once the token is set, ending the loop; otherwise it
        # times out after `interval` seconds and the method runs again.
        while not cancellation_token.wait(interval):
            if max_runs and runs >= max_runs:
                break
            method(**method_args)
            runs += 1

    op = Thread(target=method_wrap, args=(max_runs,))
    op.start()

    if handle:
        return cancellation_token, op

    return cancellation_token


def url_encode_dict(d):
    """URL-encode a mapping (or sequence of pairs) into a query string."""
    from urllib.parse import urlencode

    return urlencode(d)


def url_encode_str(s, plus=False):
    """URL-quote a string; with plus=True spaces become '+' (quote_plus), otherwise quote is used."""
    from urllib.parse import quote, quote_plus

    return quote_plus(s) if plus else quote(s)


def test_import_and_version(package, expected_version):
    """
    Used to determine if a package dependency, installed with metadata,
    is at least the expected version. This utility will not work for packages
    that are installed without metadata.

    Returns:
        bool: False when the package metadata cannot be found.
    """

    from importlib import metadata
    from packaging.version import parse

    try:
        return parse(metadata.version(package)) >= parse(expected_version)
    except metadata.PackageNotFoundError:
        return False


def unpack_pnp_http_error(e):
    """
    Unpack an msrest error payload, unwrapping a top-level "error" object and
    dropping its "stackTrace" entry when present.
    """
    error = unpack_msrest_error(e)
    if isinstance(error, dict):
        if error.get("error"):
            error = error["error"]
        # The unwrapped "error" value is not guaranteed to be a dict (it may be a
        # plain string), so re-check before using dict methods on it.
        if isinstance(error, dict) and error.get("stackTrace"):
            error.pop("stackTrace")
    return error


def unpack_msrest_error(e):
    """
    Obtains full response text from an msrest error

    Reads e.response.text (calling it when it is a method) and parses it as JSON.

    Returns:
        The parsed JSON, the raw text when it is not valid JSON, or str(e) when the
        result is empty/falsy.
    """
    op_err = None
    err_txt = ""
    try:
        text = e.response.text
        err_txt = text() if callable(text) else text
        op_err = json.loads(err_txt)
    except (ValueError, TypeError):
        op_err = err_txt
    if not op_err:
        return str(e)
    return op_err


def handle_service_exception(e):
    """
    Used to unpack service error messages and status codes,
    and raise the correct azclierror class.

    For more info on CLI error handling guidelines, see
    https://github.com/Azure/azure-cli/blob/dev/doc/error_handling_guidelines.md

    Raises:
        BadRequestError (400), UnauthorizedError (401), ForbiddenError (403),
        ResourceNotFoundError (404), AzureInternalError (5xx), otherwise
        AzureResponseError (including when the status code is missing/None).
    """
    from azure.cli.core.azclierror import (
        AzureInternalError,
        AzureResponseError,
        BadRequestError,
        ForbiddenError,
        ResourceNotFoundError,
        UnauthorizedError,
    )

    err = unpack_msrest_error(e)
    op_status = getattr(e.response, "status_code", -1)

    # Generic error if the status_code is explicitly None
    if not op_status:
        raise AzureResponseError(err)
    if op_status == 400:
        raise BadRequestError(err)
    if op_status == 401:
        raise UnauthorizedError(err)
    if op_status == 403:
        raise ForbiddenError(err)
    if op_status == 404:
        raise ResourceNotFoundError(err)
    # Any 5xx error should throw an AzureInternalError
    if 500 <= op_status < 600:
        raise AzureInternalError(err)
    # Otherwise, fail with generic service error
    raise AzureResponseError(err)


def dict_transform_lower_case_key(d):
    """Converts a dictionary to an identical one with all lower case keys"""
    return {k.lower(): v for k, v in d.items()}


def calculate_millisec_since_unix_epoch_utc(offset_seconds: int = 0):
    """Return the current UTC time, shifted by `offset_seconds`, as integer milliseconds since the Unix epoch."""
    now = datetime.utcnow()
    epoch = datetime.utcfromtimestamp(0)
    return int(1000 * ((now - epoch).total_seconds() + offset_seconds))


def init_monitoring(
    cmd,
    timeout,
    properties,
    enqueued_time,
    repair,
    yes,
    message_count: Optional[int] = None,
):
    """
    Validate and normalize monitoring arguments.

    Returns:
        tuple: (enqueued_time, properties, timeout, output, message_count) where
        timeout has been multiplied by 1000 (presumably seconds -> milliseconds),
        properties is a set of lower-cased names, output defaults to "json" and
        enqueued_time defaults to "now" in milliseconds since the Unix epoch.

    Raises:
        InvalidArgumentValueError: for a negative timeout or a negative message count.
    """
    from azext_iot.common.deps import ensure_uamqp

    if timeout < 0:
        raise InvalidArgumentValueError(
            "Monitoring timeout must be 0 (inf) or greater."
        )
    timeout = timeout * 1000

    # NOTE(review): 0 is falsy and therefore not rejected here; only negative counts are.
    if message_count and message_count <= 0:
        raise InvalidArgumentValueError("Message count must be greater than 0.")

    config = cmd.cli_ctx.config
    output = cmd.cli_ctx.invocation.data.get("output", None)
    if not output:
        output = "json"
    # Presumably verifies/installs the uamqp dependency (prompting unless `yes`) — see deps module.
    ensure_uamqp(config, yes, repair)

    if not properties:
        properties = []
    properties = {key.lower() for key in properties}

    if not enqueued_time:
        enqueued_time = calculate_millisec_since_unix_epoch_utc()
    return (enqueued_time, properties, timeout, output, message_count)


def dict_clean(d):
    """Remove None from dictionary (recursively through nested dict values)."""
    if not isinstance(d, dict):
        return d
    return {k: dict_clean(v) for k, v in d.items() if v is not None}


def looks_like_file(element):
    """Return True when `element` (case-insensitively) ends with one of a fixed set of file extensions."""
    element = element.lower()
    return element.endswith(
        (
            ".log",
            ".rtf",
            ".txt",
            ".json",
            ".yaml",
            ".yml",
            ".md",
            ".rst",
            ".doc",
            ".docx",
            ".html",
            ".htm",
            ".py",
            ".java",
            ".ts",
            ".js",
            ".cs",
        )
    )


class ISO8601Validator:
    """Syntax checks for ISO 8601 strings, backed by the isodate parsers."""

    @staticmethod
    def _parses(parser, to_validate) -> bool:
        # Validity means "the parser did not raise". Testing the truthiness of the
        # parsed value instead misreports zero durations such as "PT0S"
        # (timedelta(0) is falsy) as invalid.
        try:
            parser(to_validate)
        except Exception:
            return False
        return True

    def is_iso8601_date(self, to_validate) -> bool:
        """Return True if `to_validate` parses as an ISO 8601 date."""
        return self._parses(isodate.parse_date, to_validate)

    def is_iso8601_datetime(self, to_validate: str) -> bool:
        """Return True if `to_validate` parses as an ISO 8601 datetime."""
        return self._parses(isodate.parse_datetime, to_validate)

    def is_iso8601_duration(self, to_validate: str) -> bool:
        """Return True if `to_validate` parses as an ISO 8601 duration (including zero durations)."""
        return self._parses(isodate.parse_duration, to_validate)

    def is_iso8601_time(self, to_validate: str) -> bool:
        """Return True if `to_validate` parses as an ISO 8601 time."""
        return self._parses(isodate.parse_time, to_validate)


def ensure_iothub_sdk_min_version(min_ver):
    """Return True if the installed azure-mgmt-iothub SDK version is at least `min_ver`."""
    from packaging import version

    try:
        from azure.mgmt.iothub import __version__ as iot_sdk_version
    except ImportError:
        from azure.mgmt.iothub._version import VERSION as iot_sdk_version

    return version.parse(iot_sdk_version) >= version.parse(min_ver)


def ensure_iotdps_sdk_min_version(min_ver):
    """Return True if the installed azure-mgmt-iothubprovisioningservices SDK version is at least `min_ver`."""
    from packaging import version

    try:
        from azure.mgmt.iothubprovisioningservices import __version__ as iot_sdk_version
    except ImportError:
        from azure.mgmt.iothubprovisioningservices._version import (
            VERSION as iot_sdk_version,
        )

    return version.parse(iot_sdk_version) >= version.parse(min_ver)


def scantree(path):
    """
    Recursively yield os.DirEntry objects for every non-directory entry under `path`.

    Directories are recursed into only when they are real directories
    (is_dir(follow_symlinks=False)); symlinks to directories are yielded as entries.
    """
    for entry in os.scandir(path):
        if entry.is_dir(follow_symlinks=False):
            yield from scantree(entry.path)
        else:
            yield entry


def find_between(s, start, end):
    """
    Return the text after the first occurrence of `start`, cut at the next `end`
    (or at the next `start`, whichever comes first).

    Raises IndexError when `start` does not occur in `s`.
    """
    return (s.split(start))[1].split(end)[0]


def valid_hostname(host_name):
    """
    Approximate validation
    Reference: https://en.wikipedia.org/wiki/Hostname

    The total length must be at most 253 characters and every dot-separated label
    must be 1-63 letters, digits or hyphens, not starting or ending with a hyphen
    (so a trailing dot, which yields an empty label, is rejected).
    """
    if len(host_name) > 253:
        return False

    valid_label = re.compile(r"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE)
    label_parts = host_name.split(".")
    return all(valid_label.match(label) for label in label_parts)


def compute_device_key(primary_key, registration_id):
    """
    Compute device SAS key
    Args:
        primary_key: Primary group SAS token to compute device keys
        registration_id: Registration ID is alphanumeric, lowercase, and may contain hyphens.
    Returns:
        device key: base64-encoded (bytes) HMAC-SHA256 of the UTF-8 registration_id,
        keyed with the base64-decoded primary_key.
    """
    secret = base64.b64decode(primary_key)
    device_key = base64.b64encode(
        hmac.new(
            secret, msg=registration_id.encode("utf8"), digestmod=hashlib.sha256
        ).digest()
    )
    return device_key


def generate_key(byte_length=32):
    """
    Generate cryptographically secure device key.

    Returns `byte_length` random bytes from the `secrets` module, base64-encoded as str.
    """
    import secrets

    token_bytes = secrets.token_bytes(byte_length)
    return base64.b64encode(token_bytes).decode("utf8")


def ensure_azure_namespace_path():
    """
    Run prior to importing azure namespace packages (azure.*) to ensure the
    extension root path is configured for package import.
    """

    from azure.cli.core.extension import get_extension_path
    from azext_iot.constants import EXTENSION_NAME

    ext_path = get_extension_path(EXTENSION_NAME)
    if not ext_path:
        return

    ext_azure_dir = os.path.join(ext_path, "azure")
    if os.path.isdir(ext_azure_dir):
        import azure

        if (
            getattr(azure, "__path__", None) and ext_azure_dir not in azure.__path__
        ):  # _NamespacePath /w PEP420
            # A plain list supports insertion at the front (highest priority);
            # a namespace path object only supports append.
            if isinstance(azure.__path__, list):
                azure.__path__.insert(0, ext_azure_dir)
            else:
                azure.__path__.append(ext_azure_dir)

    if sys.path and sys.path[0] != ext_path:
        sys.path.insert(0, ext_path)

    return


def is_valid_dtmi(dtmi):
    """Checks validity of a DTMI

    :param str dtmi: DTMI
    :returns: Boolean indicating if DTMI is valid
    :rtype: bool
    """
    pattern = re.compile(
        "^dtmi:[A-Za-z](?:[A-Za-z0-9_]*[A-Za-z0-9])?(?::[A-Za-z](?:[A-Za-z0-9_]*[A-Za-z0-9])?)*;[1-9][0-9]{0,8}$"
    )
    return bool(pattern.match(dtmi))


def generate_storage_account_sas_token(
    storage_cstring: str,
    expiry_in_hours: int = 1,
    read: bool = False,
    write: bool = False,
    create: bool = False,
    update: bool = False,
    add: bool = False,
    list: bool = False,
    delete: bool = False,
):
    """
    Generate an account-level SAS token for the storage account in `storage_cstring`.

    The token covers object-level resources with the requested permission flags and
    expires `expiry_in_hours` hours from now (UTC). Note: the `list` parameter shadows
    the builtin inside this function; it is part of the public keyword interface.
    """
    from datetime import datetime, timedelta

    ensure_azure_namespace_path()
    from azure.storage.blob import ResourceTypes, AccountSasPermissions, generate_account_sas, BlobServiceClient

    blob_service_client = BlobServiceClient.from_connection_string(conn_str=storage_cstring)

    sas_token = generate_account_sas(
        blob_service_client.account_name,
        account_key=blob_service_client.credential.account_key,
        resource_types=ResourceTypes(object=True),
        permission=AccountSasPermissions(
            read=read,
            write=write,
            create=create,
            update=update,
            add=add,
            list=list,
            delete=delete
        ),
        expiry=datetime.utcnow() + timedelta(hours=expiry_in_hours)
    )

    return sas_token


def assemble_nargs_to_dict(hash_list: List[str]) -> Dict[str, str]:
    """
    Build a dict from "key=value" strings (each split on its first "=").

    Items without "=" are skipped with a warning, a warning is logged for keys whose
    value is empty, and later duplicates of a key overwrite earlier ones.
    """
    result = {}
    if not hash_list:
        return result
    for pair in hash_list:
        if "=" not in pair:
            logger.warning(
                "Skipping processing of '%s', input format is key=value | key='value value'.",
                pair,
            )
            continue
        split_pair = pair.split("=", 1)
        result[split_pair[0]] = split_pair[1]
    for key in result:
        if not result.get(key):
            logger.warning(
                "No value assigned to key '%s', input format is key=value | key='value value'.",
                key,
            )
    return result