azdev/operations/code_gen.py (226 lines of code) (raw):
# -----------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# -----------------------------------------------------------------------------
import json
import os
import re
from knack.log import get_logger
from knack.prompting import prompt_y_n, prompt
from knack.util import CLIError
from azdev.utilities import (
pip_cmd, display, heading, COMMAND_MODULE_PREFIX, EXTENSION_PREFIX, get_cli_repo_path, get_ext_repo_paths,
find_files)
# module-level logger for this file
logger = get_logger(__name__)
# location of the CLI command modules, relative to the root of the azure-cli repo
# (joined with get_cli_repo_path() when creating modules and used for CODEOWNERS / doc-map entries)
_MODULE_ROOT_PATH = os.path.join('src', 'azure-cli', 'azure', 'cli', 'command_modules')
def _ensure_dir(path):
    """Create ``path`` (including missing parents) unless it is already a directory.

    Creation is attempted first and the error is only propagated if ``path`` still is not a
    directory afterwards, which avoids the check-then-create race and also rejects a path that
    exists as a regular file (instead of silently accepting it and failing later).

    :param path: Directory to create.
    :raises OSError: if the directory cannot be created or ``path`` exists but is not a directory.
    """
    try:
        os.makedirs(path)
    except OSError:
        # "already exists" is fine, but only when it really is a directory
        if not os.path.isdir(path):
            raise
def _generate_files(env, generation_kwargs, file_list, dest_path):
    """Render templates from ``env`` and write each result to a file in ``dest_path``.

    :param env: Template environment providing ``get_template(name).render(**kwargs)``.
    :param generation_kwargs: Keyword arguments passed to every ``render`` call.
    :param file_list: A single entry or a list of entries. An entry is either a filename (used as
        both the template name and the output filename) or a dict with ``name`` (output filename)
        and ``template`` (template name) keys.
    :param dest_path: Existing directory that receives the generated files.
    """
    entries = file_list if isinstance(file_list, list) else [file_list]
    for entry in entries:
        # a bare filename means the template and the output share that name
        spec = {'name': entry, 'template': entry} if isinstance(entry, str) else entry
        output_path = os.path.join(dest_path, spec['name'])
        with open(output_path, 'w') as handle:
            template = env.get_template(spec['template'])
            handle.write(template.render(**generation_kwargs))
def create_module(mod_name='test', display_name=None, display_name_plural=None, required_sdk=None,
                  client_name=None, operation_name=None, sdk_property=None, not_preview=False, github_alias=None,
                  local_sdk=None):
    """Scaffold a new command module inside the configured azure-cli repo.

    Generates the module under the repo's command_modules directory, appends a CODEOWNERS entry,
    registers the module's ``_help.py`` in doc_source_map.json and prints getting-started hints.
    Arguments other than ``mod_name`` and ``github_alias`` are forwarded to the package generator.
    """
    cli_repo = get_cli_repo_path()
    _create_package(
        prefix='',
        repo_path=os.path.join(cli_repo, _MODULE_ROOT_PATH),
        is_ext=False,
        name=mod_name,
        display_name=display_name,
        display_name_plural=display_name_plural,
        required_sdk=required_sdk,
        client_name=client_name,
        operation_name=operation_name,
        sdk_property=sdk_property,
        not_preview=not_preview,
        local_sdk=local_sdk,
    )
    _add_to_codeowners(cli_repo, '', mod_name, github_alias)
    _add_to_doc_map(cli_repo, mod_name)
    _display_success_message(COMMAND_MODULE_PREFIX + mod_name, mod_name)
def create_extension(ext_name='test', repo_name='azure-cli-extensions',
                     display_name=None, display_name_plural=None,
                     required_sdk=None, client_name=None, operation_name=None, sdk_property=None,
                     not_preview=False, github_alias=None, local_sdk=None):
    """Scaffold a new CLI extension inside a registered extension repo.

    :param ext_name: Extension name, with or without the ``EXTENSION_PREFIX``.
    :param repo_name: Suffix identifying the target repo among ``get_ext_repo_paths()``.
    :param github_alias: Maintainer alias for CODEOWNERS; prompted for when omitted.
    Remaining arguments are forwarded to the package generator.
    :raises CLIError: if no registered extension repo matches ``repo_name``.
    """
    # tolerate a trailing path separator on the registered repo path
    repo_path = next((path for path in get_ext_repo_paths()
                      if path.rstrip('/\\').endswith(repo_name)), None)
    if not repo_path:
        raise CLIError('Unable to find `{}` repo. Have you cloned it and added '
                       'with `azdev extension repo add`?'.format(repo_name))

    # _create_package strips a leading prefix from the name; do the same here so the CODEOWNERS
    # entry and the success message refer to the package that was actually generated
    # (otherwise "azext_foo" would become "azext_azext_foo").
    if ext_name.startswith(EXTENSION_PREFIX):
        ext_name = ext_name[len(EXTENSION_PREFIX):]

    _create_package(EXTENSION_PREFIX, os.path.join(repo_path, 'src'), True, ext_name, display_name,
                    display_name_plural, required_sdk, client_name, operation_name, sdk_property, not_preview,
                    local_sdk)
    _add_to_codeowners(repo_path, EXTENSION_PREFIX, ext_name, github_alias)
    _display_success_message(EXTENSION_PREFIX + ext_name, ext_name)
def _display_success_message(package_name, group_name):
    """Announce a successful creation and list the commands to try next.

    :param package_name: Name shown in the success heading.
    :param group_name: Name substituted into each suggested command.
    """
    heading('Creation of {} successful!'.format(package_name))
    display('Getting started:')
    next_steps = (
        ('\n To see your new commands:', ' `az {} -h`'),
        ('\n To discover and run your tests:', ' `azdev test {} --discover`'),
        ('\n To identify code style issues (there will be some left over from code generation):',
         ' `azdev style {}`'),
        ('\n To identify CLI-specific linter violations:', ' `azdev linter {}`'),
    )
    for description, command_template in next_steps:
        display(description)
        display(command_template.format(group_name))
def _download_vendored_sdk(required_sdk, path):
    """Download ``required_sdk`` with pip and vendor its client package into ``path``.

    The wheel is downloaded (without dependencies) into a scratch directory, extracted there and
    handed to ``_copy_vendored_sdk``, which replaces ``path`` with the client package. The scratch
    directory is always removed afterwards.

    :param required_sdk: pip requirement to download, e.g. ``NAME==VER``.
    :param path: Destination directory for the vendored SDK.
    :raises CLIError: if ``required_sdk`` is empty or no downloaded ``.whl`` path can be found
        in pip's output.
    """
    import shutil
    import tempfile
    import zipfile

    if not required_sdk:
        display('Unable to download')
        raise CLIError('Unable to download: {}'.format(required_sdk))

    # matches pip output lines that mention "downloaded" or "saved" followed by a .whl path
    path_regex = re.compile(r'.*((\s*.*downloaded\s)|(\s*.*saved\s))(?P<path>.*\.whl)', re.IGNORECASE | re.S)

    display('Downloading {}...'.format(required_sdk))
    temp_path = tempfile.mkdtemp()
    try:
        output = pip_cmd('download {} --no-deps -d {}'.format(required_sdk, temp_path)).result or ''
        try:
            output = output.decode('utf-8')
        except AttributeError:
            pass  # already text

        downloaded_path = None
        for line in output.splitlines():
            match = path_regex.match(line)
            if match:
                downloaded_path = match.group('path')
                break

        if not downloaded_path:
            display('Unable to download')
            raise CLIError('Unable to download: {}'.format(required_sdk))

        # extract the WHL file next to it, then copy the client package out
        with zipfile.ZipFile(str(downloaded_path), 'r') as z:
            z.extractall(temp_path)
        _copy_vendored_sdk(temp_path, path)
    finally:
        # the SDK has been copied into `path` (or we failed); the scratch dir is no longer needed
        shutil.rmtree(temp_path, ignore_errors=True)
def _copy_vendored_sdk(src_path, dest_path):
    """Copy the SDK client package found under ``src_path`` to ``dest_path``.

    The client package is the directory containing the first ``version.py`` reported by
    ``find_files`` (falling back to ``_version.py``). Any existing ``dest_path`` directory is
    replaced so the client files sit directly at its root.

    :param src_path: Directory to search for the SDK client files.
    :param dest_path: Destination directory (e.g. an extension's ``vendored_sdks`` folder).
    :raises CLIError: if no version file is found under ``src_path``.
    """
    import shutil

    version_files = find_files(src_path, 'version.py') or find_files(src_path, '_version.py')
    if not version_files:
        raise CLIError('Unable to find client files.')

    client_dir = os.path.dirname(version_files[0])
    # copytree needs a non-existent destination; only clear it when it is actually there
    if os.path.isdir(dest_path):
        shutil.rmtree(dest_path)
    shutil.copytree(client_dir, dest_path)
def _add_to_codeowners(repo_path, prefix, name, github_alias):
    """Append an ownership entry for a newly created package to the repo's CODEOWNERS file.

    :param repo_path: Repository root searched for a ``CODEOWNERS`` file.
    :param prefix: ``EXTENSION_PREFIX`` for extensions; any other value means a command module.
    :param name: Package name; a leading ``prefix`` is ignored.
    :param github_alias: Maintainer's GitHub alias, with or without ``@``; prompted for when empty.
    :raises CLIError: if no CODEOWNERS file is found.
    """
    # add the user Github alias to the CODEOWNERS file for new packages
    if not github_alias:
        display('\nWhat is the Github alias of the person responsible for maintaining this package?')
        while not github_alias:
            # whitespace-only input is treated as empty and asked again
            github_alias = prompt('Alias: ').strip()
    # accept a raw alias or @alias
    github_alias = github_alias if github_alias.startswith('@') else '@{}'.format(github_alias)

    try:
        codeowners = find_files(repo_path, 'CODEOWNERS')[0]
    except IndexError:
        raise CLIError('unexpected error: unable to find CODEOWNERS file.')

    if name.startswith(prefix):
        name = name[len(prefix):]

    if prefix == EXTENSION_PREFIX:
        # _create_package generates extensions in src/<name>/ (the prefixed package lives inside it)
        new_line = '/src/{}/ {}'.format(name, github_alias)
    else:
        # ensure Linux-style separators when run on Windows
        new_line = '/{} {}'.format(os.path.join('', _MODULE_ROOT_PATH, name, ''), github_alias).replace('\\', '/')

    # read as bytes so an unexpected file encoding cannot break the trailing-newline check
    with open(codeowners, 'rb') as f:
        existing = f.read()
    with open(codeowners, 'a') as f:
        # keep the new entry on its own line even if the file lacks a trailing newline
        if existing and not existing.endswith(b'\n'):
            f.write('\n')
        f.write(new_line)
        f.write('\n')
def _add_to_doc_map(repo_path, name):
    """Register a command module's ``_help.py`` in the repo's ``doc_source_map.json``.

    The existing entries keep their order, and the file's trailing newline (if any) is preserved,
    so the rewrite only adds or updates the single ``name`` entry.

    :param repo_path: Repository root searched for ``doc_source_map.json``.
    :param name: Command module name; its value becomes ``<_MODULE_ROOT_PATH>/<name>/_help.py``.
    :raises CLIError: if no doc_source_map.json file is found.
    """
    from collections import OrderedDict

    try:
        doc_source_file = find_files(repo_path, 'doc_source_map.json')[0]
    except IndexError:
        raise CLIError('unexpected error: unable to find doc_source_map.json file.')

    with open(doc_source_file, 'r') as f:
        raw = f.read()
    # OrderedDict keeps file order on every Python version
    doc_source = json.loads(raw, object_pairs_hook=OrderedDict)

    # ensure Linux-style separators when run on Windows
    doc_source[name] = str(os.path.join(_MODULE_ROOT_PATH, name, '_help.py')).replace('\\', '/')

    # explicit separators avoid trailing whitespace after ',' on Python 2 when indent is used
    content = json.dumps(doc_source, indent=4, separators=(',', ': '))
    if raw.endswith('\n'):
        content += '\n'
    with open(doc_source_file, 'w') as f:
        f.write(content)
# pylint: disable=too-many-locals, too-many-statements, too-many-branches
def _create_package(prefix, repo_path, is_ext, name='test', display_name=None, display_name_plural=None,
                    required_sdk=None, client_name=None, operation_name=None, sdk_property=None,
                    not_preview=False, local_sdk=None):
    """Generate the source tree for a new command module or extension from the azdev templates.

    :param prefix: Name prefix (``EXTENSION_PREFIX`` for extensions, ``''`` for command modules);
        a leading ``prefix`` on ``name`` is stripped.
    :param repo_path: Directory in which the new package folder is created.
    :param is_ext: True to generate an extension, False for a command module.
    :param name: Module/extension name.
    :param display_name: Singular display name; defaults to ``name.capitalize()``.
    :param display_name_plural: Plural display name; defaults to ``display_name + 's'``.
    :param required_sdk: pip requirement such as ``NAME==VER``. Extensions vendor it into
        ``vendored_sdks``; command modules list it as a dependency.
    :param client_name: SDK client name passed to the templates.
    :param operation_name: SDK operations name passed to the templates.
    :param sdk_property: Template value; defaults to ``'<name>_name'``.
    :param not_preview: Renders the templates with ``is_preview=False`` when True.
    :param local_sdk: Path of an SDK to vendor into an extension; mutually exclusive with
        ``required_sdk`` and ignored (with a warning) for command modules.
    :raises CLIError: on conflicting SDK options, an unparsable ``required_sdk``, or when the user
        declines to overwrite an existing package. Errors from installing a generated extension
        are re-raised.
    """
    from jinja2 import Environment, PackageLoader

    if local_sdk and required_sdk:
        raise CLIError('usage error: --local-sdk PATH | --required-sdk NAME==VER')

    if name.startswith(prefix):
        name = name[len(prefix):]

    heading('Create CLI {}: {}{}'.format('Extension' if is_ext else 'Module', prefix, name))

    # value exposed to the templates as `pkg_name`
    package_name = '{}{}'.format(prefix, name.replace('_', '-')) if not is_ext else name
    display_name = display_name or name.capitalize()

    kwargs = {
        'name': name,
        'mod_path': '{}{}'.format(prefix, name) if is_ext else 'azure.cli.command_modules.{}'.format(name),
        'display_name': display_name,
        'display_name_plural': display_name_plural or '{}s'.format(display_name),
        'loader_name': '{}CommandsLoader'.format(name.capitalize()),
        'pkg_name': package_name,
        'ext_long_name': '{}{}'.format(prefix, name) if is_ext else None,
        'is_ext': is_ext,
        'is_preview': not not_preview
    }

    # The folder is always named `name`: a command module is imported as
    # azure.cli.command_modules.<name> ('mod_path' above) and its CODEOWNERS / doc-map entries are
    # written with `name`, so a hyphenated folder would be unimportable and mismatched.
    # For extensions `package_name` already equals `name`.
    new_package_path = os.path.join(repo_path, name)
    if os.path.isdir(new_package_path):
        if not prompt_y_n(
                "{} '{}' already exists. Overwrite?".format('Extension' if is_ext else 'Module', name),
                default='n'):
            raise CLIError('aborted by user')

    ext_folder = '{}{}'.format(prefix, name) if is_ext else None

    # create folder tree
    if is_ext:
        vendored_sdks_path = os.path.join(new_package_path, ext_folder, 'vendored_sdks')
        _ensure_dir(os.path.join(new_package_path, ext_folder, 'tests', 'latest'))
        _ensure_dir(vendored_sdks_path)
    else:
        _ensure_dir(os.path.join(new_package_path, 'tests', 'latest'))

    env = Environment(loader=PackageLoader('azdev', 'mod_templates'))

    # determine dependencies and SDK-related template values
    dependencies = []
    if is_ext:
        # extensions vendor their SDK rather than declaring a dependency on it
        if required_sdk:
            _download_vendored_sdk(required_sdk, path=vendored_sdks_path)
        elif local_sdk:
            _copy_vendored_sdk(local_sdk, vendored_sdks_path)
        kwargs.update({
            'sdk_path': '{}.vendored_sdks'.format(ext_folder) if (local_sdk or required_sdk) else None,
            'client_name': client_name,
            'operation_name': operation_name,
        })
    else:
        if local_sdk:
            logger.warning('--local-sdk is only supported for extensions and will be ignored.')
        if required_sdk:
            # split e.g. "azure-mgmt-foo==1.0" into distribution name / operator / version
            sdk_match = re.match(r'(?P<name>[a-zA-Z0-9_.-]+)(?P<op>[~<>=!]*)(?P<version>[\d.]*)', required_sdk)
            if not sdk_match:
                raise CLIError('usage error: unable to parse --required-sdk {}; expected NAME==VER'
                               .format(required_sdk))
            kwargs.update({
                # derive the import path from the distribution name: azure-mgmt-foo -> azure.mgmt.foo
                'sdk_path': re.sub(r'[-_]', '.', sdk_match.group('name')),
                'client_name': client_name,
                'operation_name': operation_name,
            })
            dependencies.append("'{}'".format(required_sdk))
        else:
            dependencies.append('# TODO: azure-mgmt-<NAME>==<VERSION>')

    kwargs['sdk_property'] = sdk_property or '{}_name'.format(name)
    kwargs['dependencies'] = dependencies

    # generate code for root level
    dest_path = new_package_path
    if is_ext:
        root_files = [
            'HISTORY.rst',
            'README.rst',
            'setup.cfg',
            'setup.py'
        ]
        _generate_files(env, kwargs, root_files, dest_path)

    # package sources; extensions nest them inside the `<prefix><name>` folder
    dest_path = dest_path if not is_ext else os.path.join(dest_path, ext_folder)
    module_files = [
        {'name': '__init__.py', 'template': 'module__init__.py'},
        '_client_factory.py',
        '_help.py',
        '_params.py',
        '_validators.py',
        'commands.py',
        'custom.py'
    ]
    if is_ext:
        module_files.append('azext_metadata.json')
    _generate_files(env, kwargs, module_files, dest_path)

    dest_path = os.path.join(dest_path, 'tests')
    blank_init = {'name': '__init__.py', 'template': 'blank__init__.py'}
    _generate_files(env, kwargs, blank_init, dest_path)

    dest_path = os.path.join(dest_path, 'latest')
    test_files = [
        blank_init,
        {'name': 'test_{}_scenario.py'.format(name), 'template': 'test_service_scenario.py'}
    ]
    _generate_files(env, kwargs, test_files, dest_path)

    if is_ext:
        # install the generated extension in editable (-e) mode
        result = pip_cmd('install -e {}'.format(new_package_path), "Installing `{}{}`...".format(prefix, name))
        if result.error:
            raise result.error  # pylint: disable=raising-bad-type