azext_edge/edge/util/common.py (182 lines of code) (raw):

# coding=utf-8 # ---------------------------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License file in the project root for license information. # ---------------------------------------------------------------------------------------------- """ common: Defines common utility functions and components. """ import base64 import json import logging import re from typing import Dict, List, Optional from knack.log import get_logger logger = get_logger(__name__) def parse_kvp_nargs(kvp_nargs: List[str]) -> dict: """ Parses kvp nargs into a dict handling values of null and empty string. """ result = {} if not kvp_nargs: return result for item in kvp_nargs: key, sep, value = item.partition("=") result[key] = value if sep else None return result def assemble_nargs_to_dict(hash_list: List[str]) -> Dict[str, str]: result = {} if not hash_list: return result for hash in hash_list: if "=" not in hash: logger.warning( "Skipping processing of '%s', input format is key=value | key='value value'.", hash, ) continue split_hash = hash.split("=", 1) result[split_hash[0]] = split_hash[1] for key in result: if not result.get(key): logger.warning( "No value assigned to key '%s', input format is key=value | key='value value'.", key, ) return result def build_query(cmd, subscription_id: Optional[str] = None, custom_query: Optional[str] = None, **kwargs): url = "/providers/Microsoft.ResourceGraph/resources?api-version=2022-10-01" subscriptions = [subscription_id] if subscription_id else [] payload = {"subscriptions": subscriptions, "query": "Resources ", "options": {}} # TODO: add more query options as they pop up if kwargs.get("name"): payload["query"] += f'| where name =~ "{kwargs.get("name")}" ' if kwargs.get("resource_group"): payload["query"] += f'| where resourceGroup =~ "{kwargs.get("resource_group")}" ' if kwargs.get("location"): payload["query"] += f'| where location =~ "{kwargs.get("location")}" ' if kwargs.get("type"): payload["query"] += f'| where type =~ "{kwargs.get("type")}" ' if custom_query: payload["query"] += custom_query payload["query"] += "| project id, location, name, resourceGroup, properties, tags, type, subscriptionId" if kwargs.get("additional_project"): payload["query"] += f', {kwargs.get("additional_project")}' return _process_raw_request(cmd, url, "POST", payload) def _process_raw_request(cmd, url: str, method: str, payload: Optional[dict] = None, keyword: str = "data"): # since I don't want to download the resourcegraph sdk - we are stuck with this # note that we are trying to limit dependencies from azure.cli.core.util import send_raw_request result = [] skip_token = "sentinel" while skip_token: try: body = json.dumps(payload) if payload is not None else None res = send_raw_request(cli_ctx=cmd.cli_ctx, url=url, method=method, body=body) except Exception as e: raise e if not res.content: return json_response = res.json() result.extend(json_response[keyword]) skip_token = json_response.get("$skipToken") if skip_token: if not payload: payload = {"options": {}} if "options" not in payload: payload["options"] = {} payload["options"]["$skipToken"] = skip_token return result def get_timestamp_now_utc(format: str = "%Y-%m-%dT%H:%M:%S") -> str: from datetime import datetime, timezone timestamp = datetime.now(timezone.utc).strftime(format) return timestamp def set_log_level(log_name: str, log_level: int = logging.DEBUG): lgr = logging.getLogger(log_name) lgr.setLevel(log_level) def generate_secret(byte_length=32): """ Generate cryptographically secure secret. """ import secrets token_bytes = secrets.token_bytes(byte_length) return base64.b64encode(token_bytes).decode("utf8") def url_safe_hash_phrase(phrase: str) -> str: from hashlib import sha256 return sha256(phrase.encode("utf8")).hexdigest() def url_safe_random_chars(count: int) -> str: import secrets token = "" while len(token) < count: _t = secrets.token_urlsafe() _t = _t.replace("-", "") _t = _t.replace("_", "") token += _t return token[:count] def ensure_azure_namespace_path(): """ Run prior to importing azure namespace packages (azure.*) to ensure the extension root path is configured for package import. """ import os import sys from azure.cli.core.extension import get_extension_path from ...constants import EXTENSION_NAME ext_path = get_extension_path(EXTENSION_NAME) if not ext_path: return ext_azure_dir = os.path.join(ext_path, "azure") if os.path.isdir(ext_azure_dir): import azure if getattr(azure, "__path__", None) and ext_azure_dir not in azure.__path__: # _NamespacePath /w PEP420 if isinstance(azure.__path__, list): azure.__path__.insert(0, ext_azure_dir) else: azure.__path__.append(ext_azure_dir) if sys.path and sys.path[0] != ext_path: sys.path.insert(0, ext_path) def run_host_command(command: str, shell_mode: bool = True): from shlex import quote, split from subprocess import run if not command: raise ValueError("command value is required.") logger.debug("Running host command: %s", command) command = quote(command) split_command = split(command) try: return run(split_command, capture_output=True, check=False, shell=shell_mode) except FileNotFoundError: pass def is_env_flag_enabled(env_flag_key: str) -> bool: from os import getenv return getenv(env_flag_key, "false").lower() in ["true", "1", "y"] def is_enabled_str(value: Optional[str]) -> bool: """ Converts an intended property str value (such as from k8s enum) to bool """ if isinstance(value, str): return value.lower() == "enabled" return False def should_continue_prompt(confirm_yes: Optional[bool] = None, context: str = "Deletion") -> bool: from rich.prompt import Confirm if not confirm_yes and not Confirm.ask("Continue?"): logger.warning(f"{context} cancelled.") return False return True def insert_newlines(s: str, every: int = 79) -> str: return "\n".join(s[i : i + every] for i in range(0, len(s), every)) def parse_dot_notation(pairs: List[str]) -> dict: """ ["a.b=value1", "a.c.d=value2"] -> {"a": {"b": "value1", "c": {"d": "value2"}}} """ result = {} for pair in pairs: path, sep, value = pair.partition("=") if not sep: continue keys = path.strip().split(".") current = result for key in keys[:-1]: if not isinstance(current.get(key), dict): current[key] = {} current = current[key] current[keys[-1]] = value.strip() return result def upsert_by_discriminator(initial: List[dict], disc_key: str, config: dict) -> List[dict]: disc = config.get(disc_key) for i, d in enumerate(initial): if d.get(disc_key) == disc: initial[i] = config return initial initial.append(config) return initial def chunk_list(data: list, chunk_len: int, data_size: int = 1024, size_unit: str = "kb") -> List[list]: if size_unit.lower() == "mb": data_size *= 1024 result = [] current_chunk = [] for item in data: current_chunk.append(item) serialized_size = len(json.dumps(current_chunk).encode("utf-8")) / 1024 # convert bytes to kb if len(current_chunk) > chunk_len or serialized_size > data_size: current_chunk.pop() result.append(current_chunk) current_chunk = [item] if current_chunk: result.append(current_chunk) return result def to_safe_filename(name: str) -> str: return re.sub(r"[^\w\-.]", "_", name).strip(".")