src/azure-cli-core/azure/cli/core/util.py (988 lines of code) (raw):
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
# pylint: disable=too-many-lines
import base64
import binascii
import getpass
import json
import yaml
import logging
import os
import platform
import re
import ssl
import sys
from urllib.request import urlopen
from knack.log import get_logger
from knack.util import CLIError, to_snake_case, to_camel_case
# Module-level logger used by the helpers in this file.
logger = get_logger(__name__)

# Name of the top-level CLI package and the prefix shared by its component packages.
CLI_PACKAGE_NAME = 'azure-cli'
COMPONENT_PREFIX = 'azure-cli-'

# Recommendation text attached to connection errors that are caused by an SSL verification failure.
SSLERROR_TEMPLATE = ('Certificate verification failed. This typically happens when using Azure CLI behind a proxy '
                     'that intercepts traffic with a self-signed certificate. '
                     # pylint: disable=line-too-long
                     'Please add this certificate to the trusted CA bundle. More info: https://learn.microsoft.com/cli/azure/use-cli-effectively#work-behind-a-proxy.')

# Recommendation text attached to errors caused by an invalid `--query` expression.
QUERY_REFERENCE = ("To learn more about --query, please visit: "
                   "'https://learn.microsoft.com/cli/azure/query-azure-cli'")

# Case-insensitive pattern for an ID of the form
# /subscriptions/<subscription>[/resourceGroups/<resource_group>][/providers/<namespace>/<type>/<name><children>].
_PROXYID_RE = re.compile(
    '(?i)/subscriptions/(?P<subscription>[^/]*)(/resourceGroups/(?P<resource_group>[^/]*))?'
    '(/providers/(?P<namespace>[^/]*)/(?P<type>[^/]*)/(?P<name>[^/]*)(?P<children>.*))?')

# Case-insensitive pattern for one "/<child_type>/<child_name>" pair
# (presumably applied to the `children` group captured above - confirm against the callers).
_CHILDREN_RE = re.compile('(?i)/(?P<child_type>[^/]*)/(?P<child_name>[^/]*)')

# Keys of the persisted VERSIONS session that hold the time of the last update check shown to the user
# and the time the cached latest versions were last refreshed.
_VERSION_CHECK_TIME = 'check_time'
_VERSION_UPDATE_TIME = 'update_time'

# A list of reserved names that cannot be used as admin username of VM
DISALLOWED_USER_NAMES = [
    "administrator", "admin", "user", "user1", "test", "user2",
    "test1", "user3", "admin1", "1", "123", "a", "actuser", "adm",
    "admin2", "aspnet", "backup", "console", "guest",
    "owner", "root", "server", "sql", "support", "support_388945a0",
    "sys", "test2", "test3", "user4", "user5"
]
def handle_exception(ex):  # pylint: disable=too-many-locals, too-many-statements, too-many-branches
    """Translate an arbitrary exception into an AzCLIError, report it, and return the process exit code.

    The exception is matched against a fixed sequence of ``isinstance`` checks. The ORDER of the checks
    matters: several of the tested types are related by inheritance, so the first matching branch wins.
    The resulting ``az_error`` is printed and sent to telemetry through its own methods.

    :param ex: the exception (or KeyboardInterrupt) raised while running a command
    :return: 3 when the resulting error is an ``azclierror.ResourceNotFoundError``, otherwise 1
    :rtype: int
    """
    # For error code, follow guidelines at https://docs.python.org/2/library/sys.html#sys.exit
    from jmespath.exceptions import JMESPathError
    from msrest.exceptions import HttpOperationError, ValidationError, ClientRequestError
    from azure.common import AzureException
    from azure.core.exceptions import AzureError, ServiceRequestError, HttpResponseError
    from requests.exceptions import SSLError, HTTPError
    from azure.cli.core import azclierror
    from msal_extensions.persistence import PersistenceError

    # Prefer an explicit `message` attribute; fall back to the string form of the exception.
    error_msg = getattr(ex, 'message', str(ex))
    exit_code = 1

    if isinstance(ex, azclierror.AzCLIError):
        # Already in the CLI's own error model; report it unchanged.
        az_error = ex

    elif isinstance(ex, JMESPathError):
        error_msg = "Invalid jmespath query supplied for `--query`: {}".format(error_msg)
        az_error = azclierror.InvalidArgumentValueError(error_msg)
        az_error.set_recommendation(QUERY_REFERENCE)

    # SSLError is raised when making HTTP requests with 'requests' lib behind a proxy that intercepts HTTPS traffic.
    # - SSLError is raised when directly calling 'requests' lib, such as MSAL or `az rest`
    # - azure.core.exceptions.ServiceRequestError is raised when indirectly calling 'requests' lib with azure.core,
    #   which wraps the original SSLError
    # Note: `and` binds tighter than `or`, so this reads as `SSLError or (ServiceRequestError and inner SSLError)`.
    elif isinstance(ex, SSLError) or isinstance(ex, ServiceRequestError) and isinstance(ex.inner_exception, SSLError):
        az_error = azclierror.AzureConnectionError(error_msg)
        az_error.set_recommendation(SSLERROR_TEMPLATE)

    elif isinstance(ex, HttpResponseError):
        if extract_common_error_message(ex):
            error_msg = extract_common_error_message(ex)
        status_code = str(getattr(ex, 'status_code', 'Unknown Code'))
        AzCLIErrorType = get_error_type_by_status_code(status_code)
        az_error = AzCLIErrorType(error_msg)

    elif isinstance(ex, ValidationError):
        az_error = azclierror.ValidationError(error_msg)

    elif isinstance(ex, azclierror.HTTPError):
        # For resources that don't support CAE - 401 can't be handled
        if ex.response.status_code == 401 and 'WWW-Authenticate' in ex.response.headers:
            az_error = azclierror.AuthenticationError(ex)
            az_error.set_recommendation("Interactive authentication is needed. Please run:\naz logout\naz login")
        else:
            az_error = azclierror.UnclassifiedUserFault(ex)

    elif isinstance(ex, CLIError):
        # TODO: Fine-grained analysis here
        az_error = azclierror.UnclassifiedUserFault(error_msg)

    elif isinstance(ex, AzureError):
        if extract_common_error_message(ex):
            error_msg = extract_common_error_message(ex)
        AzCLIErrorType = get_error_type_by_azure_error(ex)
        az_error = AzCLIErrorType(error_msg)

    elif isinstance(ex, AzureException):
        if is_azure_connection_error(error_msg):
            az_error = azclierror.AzureConnectionError(error_msg)
        else:
            # TODO: Fine-grained analysis here for Unknown error
            az_error = azclierror.UnknownError(error_msg)

    elif isinstance(ex, ClientRequestError):
        if is_azure_connection_error(error_msg):
            az_error = azclierror.AzureConnectionError(error_msg)
        elif isinstance(ex.inner_exception, SSLError):
            # When msrest encounters SSLError, msrest wraps SSLError in ClientRequestError
            az_error = azclierror.AzureConnectionError(error_msg)
            az_error.set_recommendation(SSLERROR_TEMPLATE)
        else:
            az_error = azclierror.ClientRequestError(error_msg)

    elif isinstance(ex, HttpOperationError):
        # Only the message is used here; the error type is chosen from the HTTP status of the response.
        message, _ = extract_http_operation_error(ex)
        if message:
            error_msg = message
        status_code = str(getattr(ex.response, 'status_code', 'Unknown Code'))
        AzCLIErrorType = get_error_type_by_status_code(status_code)
        az_error = AzCLIErrorType(error_msg)

    elif isinstance(ex, HTTPError):
        status_code = str(getattr(ex.response, 'status_code', 'Unknown Code'))
        AzCLIErrorType = get_error_type_by_status_code(status_code)
        az_error = AzCLIErrorType(error_msg)

    elif isinstance(ex, KeyboardInterrupt):
        error_msg = 'Keyboard interrupt is captured.'
        az_error = azclierror.ManualInterrupt(error_msg)

    elif isinstance(ex, PersistenceError):
        # errno is already in strerror. str(ex) gives duplicated errno.
        az_error = azclierror.CLIInternalError(ex.strerror)
        # Known errno values are pointed at the GitHub issue that tracks them.
        if ex.errno == 0:
            az_error.set_recommendation(
                "Please report to us via Github: https://github.com/Azure/azure-cli/issues/20278")
        elif ex.errno == -2146893813:
            az_error.set_recommendation(
                "Please report to us via Github: https://github.com/Azure/azure-cli/issues/20231")
        elif ex.errno == -2146892987:
            az_error.set_recommendation(
                "Please report to us via Github: https://github.com/Azure/azure-cli/issues/21010")

    else:
        # Anything unrecognized is reported as an internal error together with its traceback.
        error_msg = "The command failed with an unexpected error. Here is the traceback:"
        az_error = azclierror.CLIInternalError(error_msg)
        az_error.set_exception_trace(ex)
        az_error.set_recommendation(
            "To check existing issues, please visit: https://github.com/Azure/azure-cli/issues")

    # "Resource not found" is the only error class that gets its own exit code.
    if isinstance(az_error, azclierror.ResourceNotFoundError):
        exit_code = 3

    az_error.print_error()
    az_error.send_telemetry()

    return exit_code
def extract_common_error_message(ex):
    """Build a message from ``ex.args[0]`` plus any ``ex.args[0].error.details`` entries.

    Extraction is best effort: any failure while reading the exception stops the process and
    whatever has been collected so far is returned.

    :param ex: the exception to inspect
    :return: the first argument of the exception followed by one line per detail, or None when the
             exception has no arguments
    """
    collected = None
    try:
        collected = ex.args[0]
        # Each detail is appended on its own line.
        for extra_line in ex.args[0].error.details:
            collected += ('\n' + extra_line)
    except Exception:  # pylint: disable=broad-except
        # Missing attributes, empty args, non-string details...: keep what we have.
        pass
    return collected
def extract_http_operation_error(ex):
    """Extract an error message and error code from the JSON body of an HTTP operation error.

    The body is read from ``ex.response.text``. Supported shapes:

    - a JSON object holding an ``error``/``Error`` object with ``code``/``Code`` and
      ``message``/``Message`` members: returns ``("<code>: <message>", <code>)``
    - a JSON object whose ``error``/``Error`` member is not an object: that member is the message
    - a JSON string: the string itself is the message

    Any other body (invalid JSON, missing body, JSON array/number/null) yields no message instead of
    raising, because this function runs while another error is already being reported.

    :param ex: an exception carrying a ``response`` attribute
    :return: tuple ``(error_msg, status_code)``; ``error_msg`` is None when nothing could be extracted
             and ``status_code`` defaults to ``'Unknown Code'``
    """
    error_msg = None
    status_code = 'Unknown Code'

    try:
        payload = json.loads(getattr(ex.response, 'text', None))
    except (ValueError, TypeError):
        # ValueError: body is not valid JSON. TypeError: there is no textual body at all.
        return error_msg, status_code

    if isinstance(payload, str):
        error = payload
    elif isinstance(payload, dict):
        error = payload.get('error', payload.get('Error', None))
    else:
        # Arrays, numbers, booleans and null carry no recognizable error information.
        error = None

    # ARM should use ODATA v4. So should try this first.
    # http://docs.oasis-open.org/odata/odata-json-format/v4.0/os/odata-json-format-v4.0-os.html#_Toc372793091
    if isinstance(error, dict):
        status_code = error.get('code', error.get('Code', 'Unknown Code'))
        # Fall back to the exception itself when the body has no message.
        message = error.get('message', error.get('Message', ex))
        error_msg = "{}: {}".format(status_code, message)
    else:
        error_msg = error

    return error_msg, status_code
def get_error_type_by_azure_error(ex):
    """Map an ``azure.core`` exception to the matching ``azclierror`` class.

    Subclasses are tested BEFORE their base classes. In ``azure.core.exceptions``,
    ``ResourceNotFoundError`` derives from ``HttpResponseError`` and ``ServiceRequestTimeoutError``
    derives from ``ServiceRequestError``; testing the base class first would make the more specific
    branch unreachable.

    :param ex: the exception raised by an azure.core based SDK
    :return: the ``azclierror`` class (not an instance) to raise for this exception
    """
    from azure.core import exceptions
    from azure.cli.core import azclierror

    if isinstance(ex, exceptions.ResourceNotFoundError):
        return azclierror.ResourceNotFoundError
    if isinstance(ex, exceptions.HttpResponseError):
        status_code = str(ex.status_code)
        return get_error_type_by_status_code(status_code)

    # A request that timed out is a connectivity problem, not a malformed client request.
    if isinstance(ex, exceptions.ServiceRequestTimeoutError):
        return azclierror.AzureConnectionError
    if isinstance(ex, exceptions.ServiceRequestError):
        return azclierror.ClientRequestError

    if isinstance(ex, (exceptions.ServiceResponseError, exceptions.ServiceResponseTimeoutError)):
        return azclierror.AzureResponseError

    return azclierror.UnknownError
# pylint: disable=too-many-return-statements
def get_error_type_by_status_code(status_code):
    """Pick the ``azclierror`` class that corresponds to an HTTP status code.

    :param str status_code: the status code as a string, e.g. ``'404'``
    :return: a dedicated class for 400/401/403/404, ``UnclassifiedUserFault`` for any other 4xx,
             ``AzureInternalError`` for 5xx and ``UnknownError`` for everything else
    """
    from azure.cli.core import azclierror

    # Status codes that have a dedicated error class.
    dedicated = (
        ('400', azclierror.BadRequestError),
        ('401', azclierror.UnauthorizedError),
        ('403', azclierror.ForbiddenError),
        ('404', azclierror.ResourceNotFoundError),
    )
    for code, error_type in dedicated:
        if status_code == code:
            return error_type

    # Remaining codes are classified by their family.
    if status_code.startswith('4'):
        return azclierror.UnclassifiedUserFault
    if status_code.startswith('5'):
        return azclierror.AzureInternalError
    return azclierror.UnknownError
def is_azure_connection_error(error_msg):
    """Tell whether an error message describes a connection problem.

    :param str error_msg: the message to inspect (matched case-insensitively)
    :return: True when the message mentions a connection error, a broken or an aborted connection
    :rtype: bool
    """
    lowered = error_msg.lower()
    markers = ('connection error', 'connection broken', 'connection aborted')
    return any(marker in lowered for marker in markers)
# pylint: disable=inconsistent-return-statements
def empty_on_404(ex):
    """Swallow an HTTP 404 and re-raise everything else.

    :param ex: the exception to examine
    :return: None when ``ex`` is an ``HttpResponseError`` whose status code is 404
    :raises: ``ex`` itself in every other case
    """
    from azure.core.exceptions import HttpResponseError

    not_found = isinstance(ex, HttpResponseError) and ex.status_code == 404
    if not not_found:
        raise ex
    return None
def truncate_text(str_to_shorten, width=70, placeholder=' [...]'):
    """Shorten a string so that it is at most ``width`` characters long.

    :param str str_to_shorten: the text to shorten
    :param int width: maximum length of the result; must be greater than 0
    :param str placeholder: marker appended to the text when something was cut off
    :return: the text unchanged when it already fits in ``width``; otherwise its beginning followed by
             ``placeholder``, with a total length of exactly ``width``. When ``placeholder`` itself is
             longer than ``width`` there is no room for it and the text is simply cut at ``width``.
    :raises ValueError: when ``width`` is not positive
    """
    if width <= 0:
        raise ValueError('width must be greater than 0.')

    # Nothing to do for text that already fits; never truncate it just to append the placeholder.
    if len(str_to_shorten) <= width:
        return str_to_shorten

    kept_len = width - len(placeholder)
    if kept_len < 0:
        # A negative slice bound would keep almost the whole text and exceed `width`.
        return str_to_shorten[:width]
    return str_to_shorten[:kept_len] + placeholder
def get_installed_cli_distributions():
    """Return version records for azure-cli, azure-cli-core and azure-cli-telemetry.

    The versions come from the ``__version__`` of the imported packages instead of a scan of
    site-packages with pkg_resources, because importing pkg_resources is slow (~200ms).
    'azure-cli' is reported with the version of azure-cli-core.

    :return: list of objects exposing ``key`` and ``version`` attributes
    """
    from azure.cli.core import __version__ as azure_cli_core_version
    from azure.cli.telemetry import __version__ as azure_cli_telemetry_version

    class VersionItem:  # pylint: disable=too-few-public-methods
        """A mock of pkg_resources.EggInfoDistribution to maintain backward compatibility."""

        def __init__(self, key, version):
            self.key = key
            self.version = version

    known_packages = (
        ('azure-cli', azure_cli_core_version),
        ('azure-cli-core', azure_cli_core_version),
        ('azure-cli-telemetry', azure_cli_telemetry_version),
    )
    return [VersionItem(package, version) for package, version in known_packages]
def get_latest_from_github(package_path='azure-cli'):
    """Read the latest released version of a package from its ``setup.py`` on the main branch on GitHub.

    This is best effort: every failure (missing ``requests``, network error, unexpected status code,
    no ``VERSION`` line) is logged at info level and reported as None.

    :param str package_path: directory of the package below ``src/`` in the Azure/azure-cli repository
    :return: the version string, or None when it could not be determined
    """
    # Built before the `try` so that the failure log below can always refer to it.
    git_url = "https://raw.githubusercontent.com/Azure/azure-cli/main/src/{}/setup.py".format(package_path)
    try:
        import requests
        response = requests.get(git_url, timeout=10)
        if response.status_code != 200:
            logger.info("Failed to fetch the latest version from '%s' with status code '%s' and reason '%s'",
                        git_url, response.status_code, response.reason)
            return None
        for line in response.iter_lines():
            txt = line.decode('utf-8', errors='ignore')
            # Looking for the line of the form: VERSION = "x.y.z"
            if txt.startswith('VERSION'):
                match = re.search(r'VERSION = \"(.*)\"$', txt)
                if match:
                    return match.group(1)
    except Exception as ex:  # pylint: disable=broad-except
        logger.info("Failed to get the latest version from '%s'. %s", git_url, str(ex))
    return None
def _update_latest_from_github(versions):
    """Fill the ``'pypi'`` entry of each component in ``versions`` with the latest version on GitHub.

    :param dict versions: mapping of component name to a dict of versions; updated in place
    :return: tuple ``(versions, success)``; ``success`` is False when GitHub is unreachable or when the
             version of at least one package could not be fetched
    """
    if not check_connectivity(url='https://raw.githubusercontent.com', max_retries=0):
        return versions, False

    all_fetched = True
    for package in ('azure-cli-core', 'azure-cli-telemetry'):
        latest = get_latest_from_github(package)
        if latest:
            component = package.replace(COMPONENT_PREFIX, '')
            versions[component]['pypi'] = latest
        else:
            all_fetched = False

    # The CLI package itself is reported with the latest version of the 'core' component.
    try:
        versions[CLI_PACKAGE_NAME]['pypi'] = versions['core']['pypi']
    except KeyError:
        pass
    return versions, all_fetched
def get_cached_latest_versions(versions=None):
    """Return the latest versions, served from the VERSIONS session cache when possible.

    The cache is used when it was refreshed less than one day ago and was recorded for the same local
    'azure-cli' version. Otherwise the versions are fetched again and the cache is rewritten.

    :param dict versions: locally installed versions; computed when omitted or empty
    :return: tuple ``(versions, success)`` where ``versions`` is a shallow copy
    """
    import datetime
    from azure.cli.core._session import VERSIONS

    if not versions:
        versions = _get_local_versions()

    last_refresh = VERSIONS[_VERSION_UPDATE_TIME]
    if last_refresh:
        refreshed_at = datetime.datetime.strptime(last_refresh, '%Y-%m-%d %H:%M:%S.%f')
        expires_at = refreshed_at + datetime.timedelta(days=1)
        if datetime.datetime.now() < expires_at:
            cached = VERSIONS['versions']
            # A cache recorded for another installed version is stale.
            if cached and cached['azure-cli']['local'] == versions['azure-cli']['local']:
                return cached.copy(), True

    versions, success = _update_latest_from_github(versions)
    VERSIONS['versions'] = versions
    VERSIONS[_VERSION_UPDATE_TIME] = str(datetime.datetime.now())
    return versions.copy(), success
def _get_local_versions():
    """Collect the locally installed versions of the CLI and of its components.

    :return: dict mapping 'azure-cli' and each component name (package name without the 'azure-cli-'
             prefix) to ``{'local': <version>}``
    """
    local_versions = {}
    for dist in get_installed_cli_distributions():
        dist_key = dist.key
        if dist_key == CLI_PACKAGE_NAME:
            local_versions[CLI_PACKAGE_NAME] = {'local': dist.version}
            continue
        if dist_key.startswith(COMPONENT_PREFIX):
            component = dist_key.replace(COMPONENT_PREFIX, '')
            local_versions[component] = {'local': dist.version}
    return local_versions
def get_az_version_string(use_cache=False):  # pylint: disable=too-many-statements
    """Build the multi-line version report and the list of components that have an update available.

    :param bool use_cache: when True the latest versions come from get_cached_latest_versions,
                           otherwise they are fetched with _update_latest_from_github
    :return: tuple ``(version_string, updates_available_components)``. The second item is a list of
             component names whose latest version is newer than the local one, or None when the
             latest versions could not be retrieved.
    """
    from azure.cli.core.extension import get_extensions, EXTENSIONS_DIR, DEV_EXTENSION_SOURCES, EXTENSIONS_SYS_DIR
    from azure.cli.core._environment import get_config_dir
    import io
    output = io.StringIO()
    versions = _get_local_versions()

    # get the latest available versions; they are stored under the 'pypi' key of each entry
    versions, success = get_cached_latest_versions(versions) if use_cache else _update_latest_from_github(versions)
    updates_available_components = []

    def _print(val=''):
        # Everything is accumulated in `output` instead of being written to stdout.
        print(val, file=output)

    def _get_version_string(name, version_dict):
        # One report line: name, local version, and a trailing '*' when a newer version exists.
        from packaging.version import parse  # pylint: disable=import-error,no-name-in-module
        local = version_dict['local']
        pypi = version_dict.get('pypi', None)
        if pypi and parse(pypi) > parse(local):
            return name.ljust(25) + local.rjust(15) + ' *'
        return name.ljust(25) + local.rjust(15)

    # The CLI package is listed first and removed from `versions`, so the loop below sees components only.
    ver_string = _get_version_string(CLI_PACKAGE_NAME, versions.pop(CLI_PACKAGE_NAME))
    if '*' in ver_string:
        updates_available_components.append(CLI_PACKAGE_NAME)
    _print(ver_string)
    _print()
    # sorted() returns a new list, so popping from `versions` while looping is safe.
    for name in sorted(versions.keys()):
        ver_string = _get_version_string(name, versions.pop(name))
        if '*' in ver_string:
            updates_available_components.append(name)
        _print(ver_string)
    _print()
    extensions = get_extensions()
    if extensions:
        _print('Extensions:')
        for ext in extensions:
            if ext.ext_type == 'dev':
                _print(ext.name.ljust(20) + (ext.version or 'Unknown').rjust(20) + ' (dev) ' + ext.path)
            else:
                _print(ext.name.ljust(20) + (ext.version or 'Unknown').rjust(20))
        _print()
    _print('Dependencies:')
    dependencies_versions = get_dependency_versions()
    for k, v in dependencies_versions.items():
        _print(k.ljust(20) + v.rjust(20))
    _print()
    _print("Python location '{}'".format(os.path.abspath(sys.executable)))
    _print("Config directory '{}'".format(get_config_dir()))
    _print("Extensions directory '{}'".format(EXTENSIONS_DIR))
    # The system extensions directory is only mentioned when it exists and is not empty.
    if os.path.isdir(EXTENSIONS_SYS_DIR) and os.listdir(EXTENSIONS_SYS_DIR):
        _print("Extensions system directory '{}'".format(EXTENSIONS_SYS_DIR))
    if DEV_EXTENSION_SOURCES:
        _print("Development extension sources:")
        for source in DEV_EXTENSION_SOURCES:
            _print(' {}'.format(source))
    _print()
    _print('Python ({}) {}'.format(platform.system(), sys.version))
    _print()
    _print('Legal docs and information: aka.ms/AzureCliLegal')
    _print()
    version_string = output.getvalue()

    # if unable to retrieve the latest versions, use sentinel value to flag that
    # we couldn't check for updates
    if not success:
        updates_available_components = None
    return version_string, updates_available_components
def get_az_version_json():
    """Return the installed CLI and extension versions as a plain dictionary.

    :return: dict mapping each CLI distribution key to its version, plus an ``'extensions'`` entry that
             maps extension names to their version (``'Unknown'`` when an extension has none)
    """
    from azure.cli.core.extension import get_extensions

    report = {'extensions': {}}
    report.update((dist.key, dist.version) for dist in get_installed_cli_distributions())

    installed_extensions = get_extensions()
    if installed_extensions:
        extension_versions = report['extensions']
        for extension in installed_extensions:
            extension_versions[extension.name] = extension.version or 'Unknown'
    return report
def get_dependency_versions():
    """Report the versions of the msal and azure-mgmt-resource packages.

    :return: dict with the keys ``'msal'`` and ``'azure-mgmt-resource'``; a value is ``"N/A"`` when the
             corresponding package cannot be imported
    """
    # msal
    try:
        from msal import __version__ as msal_version
    except ImportError:
        msal_version = "N/A"

    # azure-mgmt-resource: the location of VERSION depends on the SDK generation
    try:
        # Track 2 >=15.0.0
        # pylint: disable=protected-access
        from azure.mgmt.resource._version import VERSION as azure_mgmt_resource_version
    except ImportError:
        try:
            # Track 1 <=13.0.0
            from azure.mgmt.resource.version import VERSION as azure_mgmt_resource_version
        except ImportError:
            azure_mgmt_resource_version = "N/A"

    return {
        'msal': msal_version,
        'azure-mgmt-resource': azure_mgmt_resource_version,
    }
def show_updates_available(new_line_before=False, new_line_after=False):
    """Warn about available updates, at most once every 7 days.

    The time of the check is stored in the VERSIONS session; nothing happens when the previous check
    is less than 7 days old.

    :param bool new_line_before: log an empty warning line before the message
    :param bool new_line_after: log an empty warning line after the message
    """
    from azure.cli.core._session import VERSIONS
    import datetime

    last_check = VERSIONS[_VERSION_CHECK_TIME]
    if last_check:
        checked_at = datetime.datetime.strptime(last_check, '%Y-%m-%d %H:%M:%S.%f')
        next_check_due = checked_at + datetime.timedelta(days=7)
        if datetime.datetime.now() < next_check_due:
            return

    _, pending_updates = get_az_version_string(use_cache=True)
    if pending_updates:
        if new_line_before:
            logger.warning("")
        show_updates(pending_updates, only_show_when_updates_available=True)
        if new_line_after:
            logger.warning("")

    # Remember this check, whether or not anything was shown.
    VERSIONS[_VERSION_CHECK_TIME] = str(datetime.datetime.now())
def show_updates(updates_available_components, only_show_when_updates_available=False):
    """Tell the user whether updates are available.

    :param updates_available_components: list of component names that can be updated, an empty list
        when everything is up-to-date, or None when the check could not be performed
    :param bool only_show_when_updates_available: stay silent unless there is at least one update
    """
    # The check itself failed.
    if updates_available_components is None:
        if not only_show_when_updates_available:
            logger.warning('Unable to check if your CLI is up-to-date. Check your internet connection.')
        return

    # Nothing to update.
    if not updates_available_components:
        if not only_show_when_updates_available:
            print('Your CLI is up-to-date.')
        return

    # pylint: disable=too-many-nested-blocks
    if in_cloud_console():
        warning_msg = 'You have %i update(s) available. They will be updated with the next build of Cloud Shell.'
    else:
        warning_msg = "You have %i update(s) available."
        if CLI_PACKAGE_NAME in updates_available_components:
            warning_msg = "{} Consider updating your CLI installation with 'az upgrade'".format(warning_msg)
    logger.warning(warning_msg, len(updates_available_components))
def get_json_object(json_string):
    """ Loads a JSON string as an object and converts all keys to snake case """

    def _snake_keys(node):
        # Dictionaries get their keys converted; lists are walked; everything else is left alone.
        if isinstance(node, dict):
            return {to_snake_case(name): _snake_keys(value) for name, value in node.items()}
        if isinstance(node, list):
            return [_snake_keys(element) for element in node]
        return node

    parsed = shell_safe_json_parse(json_string)
    return _snake_keys(parsed)
def get_file_json(file_path, throw_on_empty=True, preserve_order=False):
    """Read a file and parse its content as JSON (or as a Python dictionary literal).

    :param str file_path: path of the file to read
    :param bool throw_on_empty: when False, an empty file yields None instead of a parsing error
    :param bool preserve_order: forwarded to shell_safe_json_parse
    :return: the parsed content, or None for an empty file when ``throw_on_empty`` is False
    :raises CLIError: when the content cannot be parsed; the parsing error is chained as the cause
    """
    content = read_file_content(file_path)
    if not content and not throw_on_empty:
        return None
    try:
        return shell_safe_json_parse(content, preserve_order)
    except CLIError as ex:
        # Reading file bypasses shell interpretation, so we discard the recommendation for shell quoting.
        raise CLIError("Failed to parse file '{}' with exception:\n{}".format(file_path, ex)) from ex
def get_file_yaml(file_path, throw_on_empty=True):
    """Read a file and parse its content as YAML with ``yaml.safe_load``.

    :param str file_path: path of the file to read
    :param bool throw_on_empty: when True an empty file is an error, otherwise it yields None
    :return: the parsed content, or None for an empty file when ``throw_on_empty`` is False
    :raises CLIError: when the file is empty (and ``throw_on_empty`` is True) or is not valid YAML
    """
    content = read_file_content(file_path)
    if not content:
        if throw_on_empty:
            raise CLIError("Failed to parse file '{}' with exception:\nNo content in the file.".format(file_path))
        return None
    try:
        return yaml.safe_load(content)
    except yaml.YAMLError as ex:
        # YAMLError is the base class of every PyYAML error (scanner, parser, composer, constructor...),
        # so any malformed document is reported with the file name, not only parser errors.
        raise CLIError("Failed to parse file '{}' with exception:\n{}".format(file_path, ex)) from ex
def read_file_content(file_path, allow_binary=False):
    """Read a text file, trying several Unicode encodings in turn.

    :param str file_path: path of the file to read
    :param bool allow_binary: when no encoding fits, return the raw bytes encoded as base64 text
    :return: the decoded text, or the base64 string of the content for a binary file
    :raises CLIError: when the content cannot be decoded (and cannot be returned as binary)
    """
    from codecs import open as codecs_open

    # Note, always put 'utf-8-sig' first, so that BOM in WinOS won't cause trouble.
    candidate_encodings = ('utf-8-sig', 'utf-8', 'utf-16', 'utf-16le', 'utf-16be')
    for candidate in candidate_encodings:
        try:
            with codecs_open(file_path, encoding=candidate) as text_file:
                logger.debug("attempting to read file %s as %s", file_path, candidate)
                return text_file.read()
        except UnicodeError:  # also covers UnicodeDecodeError, which is a subclass
            continue

    if allow_binary:
        try:
            with open(file_path, 'rb') as input_file:
                logger.debug("attempting to read file %s as binary", file_path)
                raw_bytes = input_file.read()
                return base64.b64encode(raw_bytes).decode("utf-8")
        except Exception:  # pylint: disable=broad-except
            pass

    raise CLIError('Failed to decode file {} - unknown decoding'.format(file_path))
def shell_safe_json_parse(json_or_dict_string, preserve_order=False, strict=True):
    """ Allows the passing of JSON or Python dictionary strings. This is needed because certain
    JSON strings in CMD shell are not received in main's argv. This allows the user to specify
    the alternative notation, which does not have this problem (but is technically not JSON).

    :param str json_or_dict_string: the text to parse
    :param bool preserve_order: when True, JSON objects are loaded as ``OrderedDict``
    :param bool strict: forwarded to ``json.loads``
    :return: the parsed object
    :raises InvalidArgumentValueError: when the text is neither valid JSON nor a valid Python literal;
        the JSON parsing error is chained as the cause
    """
    try:
        if not preserve_order:
            return json.loads(json_or_dict_string, strict=strict)
        from collections import OrderedDict
        return json.loads(json_or_dict_string, object_pairs_hook=OrderedDict, strict=strict)
    except ValueError as json_ex:
        # Not valid JSON: fall back to evaluating the text as a Python literal.
        try:
            import ast
            return ast.literal_eval(json_or_dict_string)
        except Exception as ex:
            logger.debug(ex)  # log the exception which could be a python dict parsing error.

            # Echo the string received by CLI. Because the user may intend to provide a file path, we don't decisively
            # say it is a JSON string.
            msg = "Failed to parse string as JSON:\n{}\nError detail: {}".format(json_or_dict_string, json_ex)

            # Recommendation for all shells
            from azure.cli.core.azclierror import InvalidArgumentValueError
            recommendation = "The provided JSON string may have been parsed by the shell. See " \
                             "https://learn.microsoft.com/cli/azure/use-azure-cli-successfully-quoting#json-strings"

            # Recommendation especially for PowerShell
            parent_proc = get_parent_proc_name()
            if parent_proc and parent_proc.lower() in ("powershell.exe", "pwsh.exe"):
                recommendation += "\nPowerShell requires additional quoting rules. See " \
                                  "https://github.com/Azure/azure-cli/blob/dev/doc/quoting-issues-with-powershell.md"

            # Raise from json_ex error which is more likely to be the original error
            raise InvalidArgumentValueError(msg, recommendation=recommendation) from json_ex
def b64encode(s):
    """
    Encodes a string to base64.

    The input is converted to bytes with the latin-1 codec, so it may only contain characters
    in the range U+0000-U+00FF.

    :param str s: latin_1 encoded string
    :return: base64 encoded string
    :rtype: str
    :raises UnicodeEncodeError: when ``s`` contains a character outside latin-1
    """
    # base64.b64encode always returns bytes on Python 3, so the result is always decoded.
    # (The former `encoded is str` test compared the value with the type itself and was never true.)
    return base64.b64encode(s.encode("latin-1")).decode('latin-1')
def b64_to_hex(s):
    """
    Converts a base64 encoded string to the uppercase hexadecimal form of the decoded bytes.

    :param str s: base64 encoded string
    :return: uppercase hex string
    :rtype: str
    """
    decoded = base64.b64decode(s)
    # binascii.hexlify returns bytes made of ASCII hex digits only.
    return binascii.hexlify(decoded).upper().decode("utf-8")
def todict(obj, post_processor=None):
    """
    Recursively turn an object into plain dictionaries, lists and scalars.

    When given, 'post_processor(original_obj, dictionary)' is called for every dictionary that is
    produced and its return value is used in place of that dictionary.
    """
    from datetime import date, time, datetime, timedelta
    from enum import Enum

    def _finish(converted):
        # Give the caller the opportunity to adjust each produced dictionary.
        if post_processor:
            return post_processor(obj, converted)
        return converted

    if isinstance(obj, dict):
        return _finish({key: todict(value, post_processor) for (key, value) in obj.items()})

    if isinstance(obj, list):
        return [todict(element, post_processor) for element in obj]

    if isinstance(obj, Enum):
        return obj.value

    if isinstance(obj, (date, time, datetime)):
        return obj.isoformat()

    if isinstance(obj, timedelta):
        return str(obj)

    # This is the only difference with knack.util.todict because for typespec generated SDKs
    # The base model stores data in obj.__dict__['_data'] instead of in obj.__dict__
    # We need to call obj.as_dict() to extract data for this kind of model
    if hasattr(obj, 'as_dict') and not hasattr(obj, '_attribute_map'):
        return _finish({to_camel_case(key): todict(value, post_processor)
                        for key, value in obj.as_dict().items()})

    if hasattr(obj, '_asdict'):
        return todict(obj._asdict(), post_processor)

    if hasattr(obj, '__dict__'):
        # Callables and underscore-prefixed attributes are not part of the output.
        return _finish({to_camel_case(key): todict(value, post_processor)
                        for key, value in obj.__dict__.items()
                        if not callable(value) and not key.startswith('_')})

    return obj
def random_string(length=16, force_lower=False, digits_only=False):
    """Generate a random string made of digits and, unless ``digits_only`` is set, letters.

    :param int length: number of characters to generate
    :param bool force_lower: restrict letters to lowercase
    :param bool digits_only: use digits only (``force_lower`` is then irrelevant)
    :return: the generated string
    :rtype: str
    """
    from string import ascii_letters, digits, ascii_lowercase
    from random import choice

    if digits_only:
        alphabet = digits
    else:
        letters = ascii_lowercase if force_lower else ascii_letters
        alphabet = digits + letters
    return ''.join(choice(alphabet) for _ in range(length))
def hash_string(value, length=16, force_lower=False):
    """ Generate a deterministic hashed string.

    The string is the SHA-256 hex digest of ``value``, repeated as often as necessary and cut to
    ``length`` characters.

    :param value: bytes, or a string (which is encoded before hashing)
    :param int length: length of the returned string
    :param bool force_lower: lowercase the digest
    """
    import hashlib

    hasher = hashlib.sha256()
    try:
        hasher.update(value)
    except TypeError:
        # Text must be turned into bytes before it can be hashed.
        hasher.update(value.encode())

    hex_digest = hasher.hexdigest()
    if force_lower:
        hex_digest = hex_digest.lower()

    # Double the digest until it is long enough to be cut to the requested length.
    while len(hex_digest) < length:
        hex_digest = hex_digest + hex_digest
    return hex_digest[:length]
def in_cloud_console():
    """Return the value of the ``ACC_CLOUD`` environment variable, or None when it is not set.

    The result is used as a truth value by callers in this module.
    """
    acc_cloud = os.environ.get('ACC_CLOUD')
    return acc_cloud
def get_arg_list(op):
    """Return the parameters of a callable.

    :param op: the callable to inspect
    :return: the ordered mapping of parameter names to ``inspect.Parameter`` objects
    """
    from inspect import signature
    return signature(op).parameters
def is_track2(client_class):
    """ Is this client an autorestv3/track2 one?

    The decision is based on the presence of a positional parameter named "credential" in the
    ``__init__`` of the client class. Could be refined later if necessary.
    """
    import inspect
    init_positional_params = inspect.getfullargspec(client_class.__init__).args
    return "credential" in init_positional_params
# Environment variable that turns connection verification off when set to a non-empty value.
DISABLE_VERIFY_VARIABLE_NAME = "AZURE_CLI_DISABLE_CONNECTION_VERIFICATION"


def should_disable_connection_verify():
    """Tell whether the DISABLE_VERIFY_VARIABLE_NAME environment variable is set to a non-empty value.

    :rtype: bool
    """
    configured_value = os.environ.get(DISABLE_VERIFY_VARIABLE_NAME)
    return bool(configured_value)
def poller_classes():
    """Return the long-running-operation poller classes of msrest, azure.core and AAZ as a tuple.

    The tuple can be passed directly to ``isinstance``.
    """
    from msrest.polling.poller import LROPoller
    from azure.core.polling import LROPoller as AzureCoreLROPoller
    from azure.cli.core.aaz._poller import AAZLROPoller

    known_pollers = (LROPoller, AzureCoreLROPoller, AAZLROPoller)
    return known_pollers
def augment_no_wait_handler_args(no_wait_enabled, handler, handler_args):
    """ Populates handler_args with the appropriate args for no wait

    :param bool no_wait_enabled: whether the command was invoked with no-wait
    :param handler: the callable that is going to be invoked
    :param dict handler_args: keyword arguments for ``handler``; updated in place
    """
    h_args = get_arg_list(handler)
    if 'no_wait' in h_args:
        handler_args['no_wait'] = no_wait_enabled

    # Everything below only applies when no-wait has been requested.
    if not no_wait_enabled:
        return

    if 'raw' in h_args:
        # support autorest 2
        handler_args['raw'] = True
    if 'polling' in h_args:
        # support autorest 3
        handler_args['polling'] = False

    # Support track2 SDK.
    # In track2 SDK, there is no parameter 'polling' in SDK, but just use '**kwargs'.
    # So we check the name of the operation to see if it's a long running operation.
    # The name of long running operation in SDK is like 'begin_xxx_xxx'.
    # Not every callable has a __name__ (functools.partial objects do not), hence getattr.
    op_name = getattr(handler, '__name__', None)
    if op_name and op_name.startswith('begin_'):
        handler_args['polling'] = False
def sdk_no_wait(no_wait, func, *args, **kwargs):
    """Call ``func`` with the given arguments, adding ``polling=False`` when ``no_wait`` is truthy.

    :param no_wait: whether polling must be disabled
    :param func: the callable to invoke
    :return: whatever ``func`` returns
    """
    if no_wait:
        kwargs['polling'] = False
    return func(*args, **kwargs)
def open_page_in_browser(url):
    """Open ``url`` in a browser, using a platform specific launcher where needed.

    The return value depends on the branch that is taken:

    - WSL: the result of waiting for the ``powershell.exe`` process; when ``powershell.exe`` cannot be
      started, the ``webbrowser`` module is used instead
    - macOS: the ``subprocess.Popen`` object of the ``open`` command (it is not waited for)
    - otherwise: the result of ``webbrowser.open``

    :param str url: the address to open
    """
    import subprocess
    import webbrowser
    platform_name, _ = _get_platform_info()

    if is_wsl():  # windows 10 linux subsystem
        try:
            # https://learn.microsoft.com/en-us/powershell/module/microsoft.powershell.core/about/about_powershell_exe
            # Ampersand (&) should be quoted
            # NOTE(review): `url` is embedded verbatim inside the double quotes of the PowerShell command;
            # confirm that callers never pass a URL containing a double quote.
            return subprocess.Popen(
                ['powershell.exe', '-NoProfile', '-Command', 'Start-Process "{}"'.format(url)]).wait()
        except OSError:  # WSL might be too old  # FileNotFoundError introduced in Python 3
            pass
    elif platform_name == 'darwin':
        # handle 2 things:
        # a. On OSX sierra, 'python -m webbrowser -t <url>' emits out "execution error: <url> doesn't
        #    understand the "open location" message"
        # b. Python 2.x can't sniff out the default browser
        return subprocess.Popen(['open', url])

    try:
        return webbrowser.open(url, new=2)  # 2 means: open in a new tab, if possible
    except TypeError:  # See https://bugs.python.org/msg322439
        # The same call is deliberately attempted a second time (see the linked issue).
        return webbrowser.open(url, new=2)
def _get_platform_info():
    """Return the lowercased system name and release reported by ``platform.uname()``.

    :return: tuple ``(system, release)``
    """
    info = platform.uname()
    system_name = info.system.lower()
    release_name = info.release.lower()
    return system_name, release_name
def is_wsl():
    """Tell whether the CLI runs inside Windows Subsystem for Linux.

    :rtype: bool
    """
    platform_name, release = _get_platform_info()
    if platform_name != 'linux':
        return False
    # "Official" way of detecting WSL: https://github.com/Microsoft/WSL/issues/423#issuecomment-221627364
    # Run `uname -a` to get 'release' without python
    #   - WSL 1: '4.4.0-19041-Microsoft'
    #   - WSL 2: '4.19.128-microsoft-standard'
    return 'microsoft' in release
def is_windows():
    """Tell whether the reported platform system name is 'windows'.

    :rtype: bool
    """
    system_name = _get_platform_info()[0]
    return system_name == 'windows'
def is_github_codespaces():
    """Tell whether the ``CODESPACES`` environment variable is exactly ``'true'``.

    :rtype: bool
    """
    # https://docs.github.com/en/codespaces/developing-in-a-codespace/default-environment-variables-for-your-codespace
    codespaces_flag = os.environ.get('CODESPACES')
    return codespaces_flag == 'true'
def can_launch_browser():
    """Tell whether a browser can be opened from the current environment.

    :return: True on every non-Linux platform; on Linux, True when the ``webbrowser`` module finds a
             browser, or when running in WSL with ``powershell.exe`` on the PATH
    :rtype: bool
    """
    import webbrowser

    platform_name, _ = _get_platform_info()
    if platform_name != 'linux':
        # Only Linux may have no browser
        return True

    # Using webbrowser to launch a browser is the preferred way.
    try:
        webbrowser.get()
    except webbrowser.Error:
        # Don't worry. We may still try powershell.exe.
        pass
    else:
        return True

    if not is_wsl():
        return False

    # Docker container running on WSL 2 also shows WSL, but it can't launch a browser.
    # If powershell.exe is on PATH, it can be called to launch a browser.
    import shutil
    return bool(shutil.which("powershell.exe"))
def get_command_type_kwarg(custom_command=False):
    """Return the name of the keyword argument that carries the command type.

    :param custom_command: truthy for custom commands
    :return: ``'custom_command_type'`` for custom commands, ``'command_type'`` otherwise
    :rtype: str
    """
    if custom_command:
        return 'custom_command_type'
    return 'command_type'
def reload_module(module):
    """Reload a module that has already been imported; do nothing for a module that has not.

    :param str module: the name of the module, as it appears in ``sys.modules``
    """
    if module not in sys.modules:
        return
    # reloading the imported module to update
    from importlib import reload
    reload(sys.modules[module])
def get_default_admin_username():
    """Return the login name of the current user, or 'azureuser' when it cannot be used.

    The login name is replaced by 'azureuser' (and a warning is logged) when it cannot be determined
    or when it is one of DISALLOWED_USER_NAMES (compared case-insensitively).

    :rtype: str
    """
    try:
        username = getpass.getuser()
    except (KeyError, OSError, ImportError):
        # getpass.getuser() raises OSError since Python 3.13; earlier versions could raise other
        # exceptions, such as KeyError, or ImportError when the pwd module is unavailable.
        username = None
    if username is None or username.lower() in DISALLOWED_USER_NAMES:
        logger.warning('Default username %s is a reserved username. Use azureuser instead.', username)
        username = 'azureuser'
    return username
def _find_child(parent, *args, **kwargs):
    """Walk down nested collections, selecting at each level the item whose key matches an argument.

    ``path`` and ``key_path`` are dot separated and are consumed in lockstep with ``args``: for each
    triple, the collection named by the path component is read from the current object and the item
    whose key attribute equals the argument (case-insensitively) becomes the current object.
    Components without a counterpart in the other sequences are ignored.

    :param parent: the object to start from
    :param args: one value per level
    :keyword str path: dot separated names of the collection attributes
    :keyword str key_path: dot separated names of the key attribute of the items of each collection
    :return: the item found at the last level (``parent`` itself when there is no level)
    :raises CLIError: when a collection or an item cannot be found
    """
    # tuple structure (path, key, dest)
    collection_names = kwargs.get('path').split('.')
    key_names = kwargs.get('key_path').split('.')

    node = parent
    for collection_name, key_name, wanted in zip(collection_names, key_names, args):
        node = getattr(node, collection_name, None)
        if node is None:
            raise CLIError("collection '{}' not found".format(collection_name))

        for candidate in node:
            if getattr(candidate, key_name).lower() == wanted.lower():
                node = candidate
                break
        else:
            raise CLIError("item '{}' not found in {}".format(wanted, collection_name))
    return node
def find_child_item(parent, *args, **kwargs):
    """Find an item nested below ``parent`` by following ``path`` and matching ``key_path`` keys.

    :param parent: the object to start from
    :param args: one value per level, matched case-insensitively against the key of the items
    :keyword str path: dot separated names of the collection attributes
    :keyword str key_path: dot separated names of the key attribute of the items of each collection
    :return: the matching item
    :raises CLIError: when args, path and key_path do not have the same number of components, or when
                      a collection or an item cannot be found
    """
    path = kwargs.get('path', '')
    key_path = kwargs.get('key_path', '')

    # An empty string has no component (''.split('.') would count one).
    path_len = len(path.split('.')) if path else 0
    key_len = len(key_path.split('.')) if key_path else 0

    # All three counts must agree. A chained `a != b != c` only detects the case where BOTH adjacent
    # pairs differ, and lets e.g. (2, 2, 1) through.
    if not len(args) == path_len == key_len:
        raise CLIError('command authoring error: args, path and key_path must have equal number of components.')
    return _find_child(parent, *args, path=path, key_path=key_path)
def find_child_collection(parent, *args, **kwargs):
    """Find a collection nested below ``parent``.

    The leading components of ``path`` are resolved like in find_child_item (one argument and one key
    per component); the last component of ``path`` names the collection to return.

    :param parent: the object to start from
    :param args: one value per level, matched case-insensitively against the key of the items
    :keyword str path: dot separated names of the collection attributes, plus the collection of interest
    :keyword str key_path: dot separated names of the key attribute of the items of each collection
    :return: the collection named by the last component of ``path``
    :raises CLIError: when the number of components is inconsistent, or when a collection or an item
                      cannot be found
    """
    path = kwargs.get('path', '')
    key_path = kwargs.get('key_path', '')

    # An empty string has no component (''.split('.') would count one).
    arg_len = len(args)
    key_len = len(key_path.split('.')) if key_path else 0
    path_len = len(path.split('.')) if path else 0

    # Both requirements must hold, so failing EITHER of them is an authoring error.
    if arg_len != key_len or path_len != arg_len + 1:
        raise CLIError('command authoring error: args and key_path must have equal number of components, and '
                       'path must have one extra component (the path to the collection of interest.')

    # The extra path component has no matching argument and is therefore not consumed here.
    parent = _find_child(parent, *args, path=path, key_path=key_path)
    collection_path = path.split('.')[-1]
    collection = getattr(parent, collection_path, None)
    if collection is None:
        raise CLIError("collection '{}' not found".format(collection_path))
    return collection
def check_connectivity(url='https://azure.microsoft.com', max_retries=5, timeout=1):
    """Check whether ``url`` can be reached by sending it a HEAD request.

    :param str url: the address to probe
    :param int max_retries: number of retries configured on the HTTP adapter
    :param timeout: timeout of the request, in seconds
    :return: True when the request completed, False on a connection error or a timeout
    :rtype: bool
    """
    import requests
    import timeit

    started_at = timeit.default_timer()
    reachable = None
    try:
        with requests.Session() as session:
            retrying_adapter = requests.adapters.HTTPAdapter(max_retries=max_retries)
            session.mount(url, retrying_adapter)
            session.head(url, timeout=timeout)
            reachable = True
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as ex:
        logger.info('Connectivity problem detected.')
        logger.debug(ex)
        reachable = False

    elapsed = timeit.default_timer() - started_at
    logger.debug('Connectivity check: %s sec', elapsed)
    return reachable
def send_raw_request(cli_ctx, method, url, headers=None, uri_parameters=None,  # pylint: disable=too-many-locals,too-many-branches,too-many-statements
                     body=None, skip_authorization_header=False, resource=None, output_file=None,
                     generated_client_request_id_name='x-ms-client-request-id'):
    """Build and send a single HTTP request with ``requests`` and return the response.

    :param cli_ctx: CLI context. This function reads ``cli_ctx.data`` ('command', 'safe_params',
        'subscription_id') and ``cli_ctx.cloud.endpoints``.
    :param method: HTTP method passed to ``requests.Request``.
    :param url: Request URL. A value without '://' is appended to the cloud's ``resource_manager``
        endpoint. A literal '{subscriptionId}' is replaced with a subscription ID.
    :param headers: Iterable of strings. Each entry is parsed with ``shell_safe_json_parse`` and merged
        into the headers; if that raises ``CLIError`` the entry is split as ``KEY=VALUE``.
    :param uri_parameters: Iterable of strings, parsed the same way as ``headers`` and sent as
        query parameters.
    :param body: Request body. If it parses as JSON it is re-serialized with ``json.dumps`` and
        'Content-Type: application/json' is set when no Content-Type header was given; otherwise
        it is sent unchanged.
    :param skip_authorization_header: When true, no token is acquired. It is forced to True if an
        'Authorization' header is supplied in ``headers``.
    :param resource: Resource passed to ``Profile.get_raw_token``. When falsy it is derived from ``url``.
    :param output_file: If truthy, path of a file the response content is written to in binary mode.
    :param generated_client_request_id_name: Name of the header that receives a freshly generated
        ``uuid4``. A falsy value disables the header.
    :return: The ``requests`` response object.
    :raises HTTPError: (``azure.cli.core.azclierror.HTTPError``) if ``response.ok`` is false.
    """
    import uuid
    from requests import Session, Request
    from requests.structures import CaseInsensitiveDict

    # Collect the header entries into one case-insensitive mapping.
    result = CaseInsensitiveDict()
    for s in headers or []:
        try:
            temp = shell_safe_json_parse(s)
            result.update(temp)
        except CLIError:
            # The entry did not parse as JSON, so treat it as KEY=VALUE (split on the first '=').
            key, value = s.split('=', 1)
            result[key] = value
    headers = result

    # If Authorization header is already provided, don't bother with the token
    if 'Authorization' in headers:
        skip_authorization_header = True

    # Handle User-Agent
    agents = [get_az_rest_user_agent()]

    # Borrow AZURE_HTTP_USER_AGENT from msrest
    # https://github.com/Azure/msrest-for-python/blob/4cc8bc84e96036f03b34716466230fb257e27b36/msrest/pipeline/universal.py#L70
    _ENV_ADDITIONAL_USER_AGENT = 'AZURE_HTTP_USER_AGENT'
    if _ENV_ADDITIONAL_USER_AGENT in os.environ:
        agents.append(os.environ[_ENV_ADDITIONAL_USER_AGENT])

    # Custom User-Agent provided as command argument
    if 'User-Agent' in headers:
        agents.append(headers['User-Agent'])
    headers['User-Agent'] = ' '.join(agents)

    # Report the final User-Agent string to telemetry as well.
    from azure.cli.core.telemetry import set_user_agent
    set_user_agent(headers['User-Agent'])

    if generated_client_request_id_name:
        headers[generated_client_request_id_name] = str(uuid.uuid4())

    # try to figure out the correct content type
    if body:
        try:
            body_object = shell_safe_json_parse(body)
            # Make sure Unicode characters are escaped as ASCII by utilizing the default ensure_ascii=True kwarg
            # of json.dumps, since http.client by default encodes the body as latin-1:
            # https://github.com/python/cpython/blob/3.10/Lib/http/client.py#L164
            # https://github.com/python/cpython/blob/3.10/Lib/http/client.py#L1324-L1327
            body = json.dumps(body_object)
            if 'Content-Type' not in headers:
                headers['Content-Type'] = 'application/json'
        except Exception:  # pylint: disable=broad-except
            # Deliberate best effort: a body that is not JSON is sent exactly as given.
            pass

    # add telemetry
    headers['CommandName'] = cli_ctx.data['command']
    if cli_ctx.data.get('safe_params'):
        headers['ParameterSetName'] = ' '.join(cli_ctx.data['safe_params'])

    # Collect the query parameters with the same JSON-or-KEY=VALUE rule used for headers.
    result = {}
    for s in uri_parameters or []:
        try:
            temp = shell_safe_json_parse(s)
            result.update(temp)
        except CLIError:
            key, value = s.split('=', 1)
            result[key] = value
    # An empty mapping is turned into None before being handed to requests.
    uri_parameters = result or None

    endpoints = cli_ctx.cloud.endpoints
    # If url is an ARM resource ID, like /subscriptions/xxx/resourcegroups/xxx?api-version=2019-07-01,
    # default to Azure Resource Manager.
    # https://management.azure.com + /subscriptions/xxx/resourcegroups/xxx?api-version=2019-07-01
    if '://' not in url:
        url = endpoints.resource_manager.rstrip('/') + url

    # Replace common tokens with real values. It is for smooth experience if users copy and paste the url from
    # Azure Rest API doc
    from azure.cli.core._profile import Profile
    profile = Profile(cli_ctx=cli_ctx)
    if '{subscriptionId}' in url:
        # Prefer the subscription recorded in cli_ctx.data; otherwise ask the profile.
        url = url.replace('{subscriptionId}', cli_ctx.data['subscription_id'] or profile.get_subscription_id())

    # Prepare the Bearer token for `Authorization` header
    if not skip_authorization_header and url.lower().startswith('https://'):
        # Prepare `resource` for `get_raw_token`
        if not resource:
            # If url starts with ARM endpoint, like `https://management.azure.com/`,
            # use `active_directory_resource_id` for resource, like `https://management.core.windows.net/`.
            # This follows the same behavior as `azure.cli.core.commands.client_factory._get_mgmt_service_client`
            if url.lower().startswith(endpoints.resource_manager.rstrip('/')):
                resource = endpoints.active_directory_resource_id
            else:
                from azure.cli.core.cloud import CloudEndpointNotSetException
                # Scan the public attributes of the endpoints object. The first string value that the
                # URL starts with (case-insensitively) is used as the resource.
                for p in [x for x in dir(endpoints) if not x.startswith('_')]:
                    try:
                        value = getattr(endpoints, p)
                    except CloudEndpointNotSetException:
                        # This endpoint is not set for the current cloud; skip it.
                        continue
                    if isinstance(value, str) and url.lower().startswith(value.lower()):
                        resource = value
                        break
        if resource:
            # Prepare `subscription` for `get_raw_token`
            # If this is an ARM request, try to extract subscription ID from the URL.
            # But there are APIs which don't require subscription ID, like /subscriptions, /tenants
            # TODO: In the future when multi-tenant subscription is supported, we won't be able to uniquely identify
            # the token from subscription anymore.
            token_subscription = None
            if url.lower().startswith(endpoints.resource_manager.rstrip('/')):
                token_subscription = _extract_subscription_id(url)
            if token_subscription:
                logger.debug('Retrieving token for resource %s, subscription %s', resource, token_subscription)
                token_info, _, _ = profile.get_raw_token(resource, subscription=token_subscription)
            else:
                logger.debug('Retrieving token for resource %s', resource)
                token_info, _, _ = profile.get_raw_token(resource)
            token_type, token, _ = token_info
            headers = headers or {}
            headers['Authorization'] = '{} {}'.format(token_type, token)
        else:
            # No resource could be derived: the request is still sent, just without a token.
            logger.warning("Can't derive appropriate Azure AD resource from --url to acquire an access token. "
                           "If access token is required, use --resource to specify the resource")

    # https://requests.readthedocs.io/en/latest/user/advanced/#prepared-requests
    s = Session()
    req = Request(method=method, url=url, headers=headers, params=uri_parameters, data=body)
    prepped = s.prepare_request(req)

    # Merge environment settings into session
    settings = s.merge_environment_settings(prepped.url, {}, None, not should_disable_connection_verify(), None)
    _log_request(prepped)
    r = s.send(prepped, **settings)
    _log_response(r)

    if not r.ok:
        # Build the error text from the reason phrase plus the response text, when there is any.
        reason = r.reason
        if r.text:
            reason += '({})'.format(r.text)
        from .azclierror import HTTPError
        raise HTTPError(reason, r)
    if output_file:
        # Write the response content to the file in binary mode, 128 bytes per chunk.
        with open(output_file, 'wb') as fd:
            for chunk in r.iter_content(chunk_size=128):
                fd.write(chunk)
    return r
def _extract_subscription_id(url):
    """Return the first subscription GUID following '/subscriptions/' in ``url``, or None.

    The search is case-insensitive and the ID is returned exactly as it appears in the URL.
    """
    found = re.search(
        '/subscriptions/([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})',
        url,
        re.IGNORECASE)
    if found is None:
        logger.debug('No subscription ID specified in the URL %s', url)
        return None

    subscription_id = found.group(1)
    logger.debug('Found subscription ID %s in the URL %s', subscription_id, url)
    return subscription_id
def _log_request(request):
    """Log the URL, method, headers and body of an outgoing request.

    Nothing is logged unless the module logger is enabled for DEBUG. Any exception raised
    while logging is caught and reported through the logger. Adapted from msrest:
    https://github.com/Azure/msrest-for-python/blob/3653d29fc44da408898b07c710290a83d196b777/msrest/http_logger.py#L39

    :param request: Object exposing ``url``, ``method``, ``headers`` and ``body``.
    """
    if not logger.isEnabledFor(logging.DEBUG):
        return

    try:
        import types

        logger.info("Request URL: %r", request.url)
        logger.info("Request method: %r", request.method)
        logger.info("Request headers:")
        for name, raw_value in request.headers.items():
            shown = raw_value
            if name.lower() == 'authorization':
                # Never log the full credential: keep at most half of it, capped at 20 characters.
                keep = min(int(len(raw_value) * 0.5), 20)
                shown = raw_value[:keep] + '...'
            logger.info(" %r: %r", name, shown)

        logger.info("Request body:")
        if not isinstance(request.body, types.GeneratorType):
            logger.info(str(request.body))
        else:
            # A generator body is a file upload; its binary data is not logged.
            logger.info("File upload")
    except Exception as err:  # pylint: disable=broad-except
        logger.info("Failed to log request: %r", err)
def _log_response(response, **kwargs):
    """Log the status, headers and (non-binary) content of a server response.

    Adapted from msrest:
    https://github.com/Azure/msrest-for-python/blob/3653d29fc44da408898b07c710290a83d196b777/msrest/http_logger.py#L68

    :param response: Object exposing ``status_code``, ``headers`` and ``content``.
    :param kwargs: Only ``stream`` is read; when truthy the body is not logged.
    :return: None when the module logger is not enabled for DEBUG, otherwise ``response``
        (also when logging itself failed).
    """
    if not logger.isEnabledFor(logging.DEBUG):
        return None

    try:
        logger.info("Response status: %r", response.status_code)
        logger.info("Response headers:")
        for name, header_value in response.headers.items():
            logger.info(" %r: %r", name, header_value)

        # Binary payloads (attachments, octet-streams, images) are only described, not dumped.
        logger.info("Response content:")
        attachment_re = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE)
        disposition = response.headers.get('content-disposition')
        content_type = response.headers.get("content-type", "")
        if disposition and attachment_re.match(disposition):
            _, _, filename = disposition.partition('=')
            logger.info("File attachments: %s", filename)
        elif content_type.endswith("octet-stream"):
            logger.info("Body contains binary data.")
        elif content_type.startswith("image"):
            logger.info("Body contains image data.")
        elif kwargs.get('stream', False):
            logger.info("Body is streamable")
        else:
            logger.info(response.content.decode("utf-8-sig"))
        return response
    except Exception as err:  # pylint: disable=broad-except
        logger.info("Failed to log response: %s", repr(err))
        return response
class ScopedConfig:
    """Context manager that temporarily overrides ``use_local_config`` on a config object.

    On entry the config's ``use_local_config`` is set to the requested value (False when the
    argument is None). On exit it is set back to the value captured at construction time,
    or to None if the config had no such attribute.
    """

    def __init__(self, cli_config, use_local_config=None):
        self.use_local_config = False if use_local_config is None else use_local_config
        self.cli_config = cli_config
        # "use_local_config" might not exist on the config object, hence getattr with a default.
        self.original_use_local_config = getattr(cli_config, 'use_local_config', None)

    def __enter__(self):
        setattr(self.cli_config, 'use_local_config', self.use_local_config)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.cli_config.use_local_config = self.original_use_local_config


# Alias under which the same class is also exposed.
ConfiguredDefaultSetter = ScopedConfig
def _ssl_context():
    """Return the ``ssl.SSLContext`` to use for ``urlopen``.

    ``ssl.create_default_context()`` is used unless the interpreter is older than 3.4 or the CLI
    runs in Cloud Console on Windows; in those cases a bare ``SSLContext`` is constructed instead.
    """
    needs_bare_context = sys.version_info < (3, 4) or (in_cloud_console() and platform.system() == 'Windows')
    if not needs_bare_context:
        return ssl.create_default_context()

    try:
        return ssl.SSLContext(ssl.PROTOCOL_TLS)  # added in python 2.7.13 and 3.6
    except AttributeError:
        # ssl.PROTOCOL_TLS is unavailable on this interpreter.
        return ssl.SSLContext(ssl.PROTOCOL_TLSv1)
def urlretrieve(url):
    """Download ``url`` and return the whole response body as bytes.

    :param url: URL (or ``urllib.request.Request``) handed to ``urlopen``.
    :return: The bytes returned by reading the response.
    """
    # Use the response as a context manager so the underlying connection is closed
    # deterministically, also when read() raises, instead of being left to garbage collection.
    with urlopen(url, context=_ssl_context()) as response:
        return response.read()
def parse_proxy_resource_id(rid):
    """Split a proxy resource ID into its named parts.

    :param rid: The resource id being parsed
    :type rid: str
    :returns: An empty dictionary if ``rid`` is falsy, None if ``rid`` does not match the
        resource ID pattern, otherwise a dictionary with the following key/value pairs (if found):
        - subscription: Subscription id
        - resource_group: Name of resource group
        - namespace: Namespace for the resource provider (i.e. Microsoft.Compute)
        - type: Type of the root resource (i.e. virtualMachines)
        - name: Name of the root resource
        - child_type_{level}: Type of the child resource of that level
        - child_name_{level}: Name of the child resource of that level
        - last_child_num: Level of the last child
    :rtype: dict[str,str]
    """
    if not rid:
        return {}

    matched = _PROXYID_RE.match(rid)
    if not matched:
        return None

    parts = matched.groupdict()
    # Child levels are numbered from 1; last_level stays None when there are no children.
    last_level = None
    for last_level, child in enumerate(_CHILDREN_RE.finditer(parts['children'] or ''), start=1):
        for key, group in child.groupdict().items():
            parts[key + '_%d' % last_level] = group
    parts['last_child_num'] = last_level
    parts.pop('children', None)

    # Groups that did not participate in the match are None and are left out of the result.
    return {key: value for key, value in parts.items() if value is not None}
def get_az_user_agent():
    """Return the Azure CLI User-Agent token.

    The result is ``AZURECLI/<core version>``, followed by the installer name in parentheses
    when the installer environment variable is set.
    """
    # Dynamically load the core version
    from azure.cli.core import __version__ as core_version
    from azure.cli.core._environment import _ENV_AZ_INSTALLER

    tokens = ["AZURECLI/{}".format(core_version)]
    installer = os.environ.get(_ENV_AZ_INSTALLER)
    if installer is not None:
        tokens.append('({})'.format(installer))

    # AZURE_HTTP_USER_AGENT is intentionally not appended here, msrest already has this:
    # https://github.com/Azure/msrest-for-python/blob/4cc8bc84e96036f03b34716466230fb257e27b36/msrest/pipeline/universal.py#L70
    return ' '.join(tokens)
def get_az_rest_user_agent():
    """Get User-Agent for az rest calls.

    The value is the Python version, the platform string in parentheses and the Azure CLI
    User-Agent token, separated by single spaces.
    """
    python_token = 'python/{}'.format(platform.python_version())
    platform_token = '({})'.format(platform.platform())
    return ' '.join((python_token, platform_token, get_az_user_agent()))
def user_confirmation(message, yes=False):
    """Ask the user to confirm ``message`` unless ``yes`` is truthy.

    :param message: Prompt text passed to ``knack.prompting.prompt_y_n``.
    :param yes: When truthy no prompt is shown and the function returns immediately.
    :raises CLIError: If the user declines, or if no tty is available for prompting.
    """
    if not yes:
        from knack.prompting import prompt_y_n, NoTTYException
        try:
            confirmed = prompt_y_n(message)
        except NoTTYException:
            raise CLIError(
                'Unable to prompt for confirmation as no tty available. Use --yes.')
        if not confirmed:
            raise CLIError('Operation cancelled.')
def get_linux_distro():
    """Return ``(name, version_id)`` read from /etc/os-release.

    :return: A 2-tuple. Both items are None when the platform is not Linux or the file cannot
        be read. An item is also None when its key is absent from the file.
    """
    if platform.system() != 'Linux':
        return None, None

    try:
        with open('/etc/os-release') as release_file:
            entries = [raw.strip() for raw in release_file]
    except Exception:  # pylint: disable=broad-except
        return None, None

    # Lines look like KEY=value or KEY="value"; keys are lower-cased, lines without '=' are ignored.
    fields = {}
    for entry in entries:
        key, separator, raw_value = entry.partition('=')
        if separator:
            fields[key.lower()] = raw_value.strip('"')
    return fields.get('name'), fields.get('version_id')
def roughly_parse_command(args):
    """Return the leading command words of ``args``, joined by spaces and lower-cased.

    Roughly parse the command part: <az vm create> --name vm1
    Similar to knack.invocation.CommandInvoker._rudimentary_get_command, but positional
    args are not considered. Collection stops at the first empty argument or the first
    argument starting with '-'.
    """
    command_words = []
    for token in args:
        if not token or token[0] == '-':
            break
        command_words.append(token)
    return ' '.join(command_words).lower()
def is_guid(guid):
    """Return True if ``guid`` is a ``uuid.UUID`` or a string that ``uuid.UUID`` can parse.

    :param guid: The value to test.
    :return: True for a UUID instance or a parseable UUID string, otherwise False. Values that
        are not strings (None, numbers, bytes, ...) yield False rather than raising.
    """
    import uuid
    if isinstance(guid, uuid.UUID):
        return True
    try:
        uuid.UUID(guid)
        return True
    except (ValueError, TypeError, AttributeError):
        # ValueError: malformed string. TypeError/AttributeError: uuid.UUID was given
        # something that is not a string (e.g. None, bytes or an int).
        return False
def handle_version_update():
    """Clean up information in local files that may be invalidated
    because of a version update of Azure CLI.

    When no versions are cached they are fetched. When the cached local core version differs
    from the running one, the cached versions are reset and the known clouds are refreshed.
    Any exception is logged as a warning instead of being raised.
    """
    try:
        from azure.cli.core._session import VERSIONS
        from packaging.version import parse  # pylint: disable=import-error,no-name-in-module
        from azure.cli.core import __version__

        cached_versions = VERSIONS['versions']
        if not cached_versions:
            get_cached_latest_versions()
        else:
            recorded_core = parse(cached_versions['core']['local'])
            if recorded_core != parse(__version__):
                logger.debug("Azure CLI has been updated.")
                logger.debug("Clean up versions and refresh cloud endpoints information in local files.")
                VERSIONS['versions'] = {}
                VERSIONS['update_time'] = ''
                from azure.cli.core.cloud import refresh_known_clouds
                refresh_known_clouds()
    except Exception as ex:  # pylint: disable=broad-except
        logger.warning(ex)
def _get_parent_proc_name():
    """Un-cached lookup of the name of the process that launched the CLI.

    :return: The process name, or None when psutil is not importable, no parent process is
        found, or psutil raises ``AccessDenied``.
    """
    try:
        import psutil
    except ImportError as ex:
        logger.debug(ex)
        return None

    try:
        launcher = psutil.Process(os.getpid()).parent()

        # On Windows, when CLI is run inside a virtual env, there will be 2 python.exe.
        if launcher and launcher.name().lower() == 'python.exe':
            launcher = launcher.parent()

        if not launcher:
            return None

        # On Windows, powershell.exe launches cmd.exe to launch python.exe.
        ancestor = launcher.parent()
        if ancestor and ancestor.name().lower() in ("powershell.exe", "pwsh.exe"):
            return ancestor.name()

        # if powershell.exe or pwsh.exe is not the grandparent, simply return the parent's name.
        return launcher.name()
    except psutil.AccessDenied as ex:
        # Ignore due to https://github.com/giampaolo/psutil/issues/1980
        logger.debug(ex)
    return None
def get_parent_proc_name():
    """Return the parent process name, computing it only on the first call.

    psutil calls are time-consuming, so the result of ``_get_parent_proc_name`` is stored as
    the ``return_value`` attribute of this function and reused afterwards.

    NOTE: The return value may be None if getting parent proc name fails, so always remember
    to check it first before calling string methods like lower().
    """
    try:
        return get_parent_proc_name.return_value
    except AttributeError:
        # First call: nothing cached yet (a cached None is a valid value and is not recomputed).
        proc_name = _get_parent_proc_name()
        get_parent_proc_name.return_value = proc_name
        return proc_name
def is_modern_terminal():
    """In addition to knack.util.is_modern_terminal, detect Cloud Shell."""
    import knack.util

    knack_verdict = knack.util.is_modern_terminal()
    # Cloud Shell is only consulted when knack does not already report a modern terminal.
    return knack_verdict if knack_verdict else in_cloud_console()
def rmtree_with_retry(path):
    """Delete the directory tree at ``path``, retrying on ``OSError``.

    A workaround for https://bugs.python.org/issue33240
    ``shutil.rmtree`` is attempted up to four times (one try plus three retries, one second
    apart). If it still fails, a warning is logged and the function returns without raising,
    so the command execution is not blocked.
    """
    import shutil
    import time

    max_retries = 3
    for retries_left in range(max_retries, -1, -1):
        try:
            shutil.rmtree(path)
            return
        except FileNotFoundError:
            # The folder has already been deleted. No further retry is needed.
            # errno: 2, winerror: 3, strerror: 'The system cannot find the path specified'
            return
        except OSError as err:
            if retries_left > 0:
                logger.warning("Failed to delete '%s': %s. Retrying ...", path, err)
                time.sleep(1)
            else:
                logger.warning("Failed to delete '%s': %s. You may try to delete it manually.", path, err)
def get_secret_store(cli_ctx, name):
    """Load the secret store named ``name`` from the CLI config directory.

    The store is obtained from ``azure.cli.core.auth.persistence.load_secret_store``; the
    original documentation describes the resulting SecretStore as process-concurrency-safe.

    :param cli_ctx: CLI context, used to decide whether the store is encrypted.
    :param name: File name of the store inside the config directory (by default ~/.azure).
    :return: Whatever ``load_secret_store(location, encrypt)`` returns.
    """
    from azure.cli.core._environment import get_config_dir
    from azure.cli.core.auth.persistence import load_secret_store

    store_path = os.path.join(get_config_dir(), name)
    # We honor the system type (Windows, Linux, or MacOS) and global config
    return load_secret_store(store_path, should_encrypt_token_cache(cli_ctx))
def should_encrypt_token_cache(cli_ctx):
    """Return whether the token cache should be encrypted.

    The ``core.encrypt_token_cache`` config value decides. When it is not set, encryption is
    only enabled on Windows (for now).

    EXPERIMENTAL: Use core.encrypt_token_cache=False to turn off token cache encryption.
    encrypt_token_cache affects both MSAL token cache and service principal entries.
    """
    return cli_ctx.config.getboolean(
        'core',
        'encrypt_token_cache',
        fallback=sys.platform.startswith('win32'))
def run_cmd(args, *, capture_output=False, timeout=None, check=False, encoding=None, env=None):
    """Run a command in a subprocess via ``subprocess.run``.

    It reduces (not eliminates) shell injection by forcing args to be a list and never
    passing ``shell``. Other arguments are keyword-only. For their documentation, see
    https://docs.python.org/3/library/subprocess.html#subprocess.run

    :param args: The command and its arguments, as a list.
    :return: The ``subprocess.CompletedProcess`` returned by ``subprocess.run``.
    :raises ArgumentUsageError: If ``args`` is not a list.
    """
    import subprocess

    if isinstance(args, list):
        options = {
            'capture_output': capture_output,
            'timeout': timeout,
            'check': check,
            'encoding': encoding,
            'env': env,
        }
        return subprocess.run(args, **options)

    from azure.cli.core.azclierror import ArgumentUsageError
    raise ArgumentUsageError("Invalid args. run_cmd args must be a list")
def run_az_cmd(args, out_file=None):
    """
    run_az_cmd would run az related cmds during command execution

    :param args: cmd to be executed, array of string, like `["az", "version"]`, "az" is optional.
        The caller's list is not modified.
    :param out_file: The file to send output to. file-like object
    :return: cmd execution result object, containing `result`, `error`, `exit_code`
    :raises ArgumentUsageError: If ``args`` is not a list.
    """
    from azure.cli.core.azclierror import ArgumentUsageError
    if not isinstance(args, list):
        raise ArgumentUsageError("Invalid args. run_az_cmd args must be a list")

    # Drop the optional leading "az". The emptiness check keeps an empty list from
    # failing with IndexError before the CLI is even invoked.
    if args and args[0] == "az":
        args = args[1:]

    from azure.cli.core import get_default_cli
    cli = get_default_cli()
    cli.invoke(args, out_file=out_file)
    return cli.result
def getprop(o, name, *default):
    """Read the public, non-callable attribute ``name`` from ``o``.

    :param o: The object to read from.
    :param name: Attribute name. Names starting with '_' are rejected.
    :param default: Optional single fallback value, forwarded to ``getattr``.
    :return: The attribute value (or the fallback when the attribute is missing).
    :raises AttributeError: If ``name`` is private, the attribute is missing and no fallback
        was given, or the resolved value is callable (i.e. a method).
    """
    # avoid to access the private properties or methods
    if name.startswith('_'):
        raise AttributeError(name)

    resolved = getattr(o, name, *default)
    if not callable(resolved):
        return resolved

    # avoid to access the methods
    raise AttributeError(name)