azdev/operations/linter/linter.py (539 lines of code) (raw):
# -----------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# -----------------------------------------------------------------------------
# pylint: disable=line-too-long
from difflib import context_diff
from enum import Enum
from importlib import import_module
import inspect
import os
import re
from pkgutil import iter_modules
from typing import List, Tuple
import yaml
from knack.log import get_logger
from azdev.operations.regex import (
get_all_tested_commands_from_regex,
search_aaz_raw_command, search_aaz_custom_command,
search_argument,
search_argument_context,
search_command,
search_deleted_command,
search_command_group)
from azdev.utilities import diff_branches_detail, diff_branch_file_patch
from azdev.utilities.path import get_cli_repo_path, get_ext_repo_paths
from .util import (share_element, exclude_commands, LinterError, get_cmd_example_configurations,
get_cmd_example_threshold)
# Dotted import path of this package; used below to import the ``rules`` sub-package by name.
PACKAGE_NAME = 'azdev.operations.linter'
# Module-level logger obtained from knack for this module's name.
_logger = get_logger(__name__)
class LinterSeverity(Enum):
    """Severity level attached to a linter rule; a larger value means more severe."""

    HIGH = 2
    MEDIUM = 1
    LOW = 0

    @staticmethod
    def get_linter_severity(severity_name):
        """Resolve a user supplied severity into a ``LinterSeverity`` member.

        :param severity_name: a member of this enum, a member name (case-insensitive),
            a member value as ``int``, or a string holding that value (e.g. ``"2"``).
        :return: the matching ``LinterSeverity`` member.
        :raises ValueError: if nothing matches.
        """
        if isinstance(severity_name, LinterSeverity):
            return severity_name
        if isinstance(severity_name, str):
            text = severity_name.strip().lower()
            for severity in LinterSeverity:
                # Accept both the symbolic name and the numeric value, as the error message promises.
                if text in (severity.name.lower(), str(severity.value)):
                    return severity
        elif isinstance(severity_name, int) and not isinstance(severity_name, bool):
            for severity in LinterSeverity:
                if severity_name == severity.value:
                    return severity
        raise ValueError("Severity must be a valid linter severity name or value.")

    @staticmethod
    def get_ordered_members():
        """Return all members sorted from least severe to most severe."""
        return sorted(LinterSeverity, key=lambda sev: sev.value)
class Linter:  # pylint: disable=too-many-public-methods, too-many-instance-attributes
    """Query helper handed to linter rules.

    Wraps a command loader (command table, command group table, parser), the YAML help
    entries and the loaded help objects, and offers git-diff based helpers that detect
    new/modified commands and parameters and whether they are exercised by tests.
    """

    def __init__(self, command_loader=None, help_file_entries=None, loaded_help=None, git_source=None, git_target=None,
                 git_repo=None, exclusions=None):
        """Index the command table and, when git information is complete, collect the added diff lines.

        :param command_loader: loader exposing ``command_table``, ``command_group_table``,
            ``cmd_to_loader_map`` and ``cli_ctx.invocation.parser``. Although it defaults to
            ``None`` it is dereferenced here, so a real loader must be supplied.
        :param help_file_entries: mapping of help entry name -> raw YAML help dict (``None`` -> empty).
        :param loaded_help: mapping of entry name -> loaded help object (``None`` -> empty).
        :param git_source: source branch used for diffing.
        :param git_target: target branch used for diffing.
        :param git_repo: path of the git repository used for diffing.
        :param exclusions: mapping of command name -> exclusion details (``None`` -> empty).
        """
        # Normalise optional mappings so that later ``.get()`` / ``.items()`` calls never hit ``None``.
        self._all_yaml_help = help_file_entries if help_file_entries is not None else {}
        self._loaded_help = loaded_help if loaded_help is not None else {}
        self._command_loader = command_loader
        self._help_file_entries = set(self._all_yaml_help)
        self._command_parser = command_loader.cli_ctx.invocation.parser
        self._command_groups = []
        # command name -> set of argument names
        self._parameters = {
            command_name: set(command.arguments)
            for command_name, command in command_loader.command_table.items()
        }
        self.git_source = git_source
        self.git_target = git_target
        self.git_repo = git_repo
        self.exclusions = exclusions if exclusions is not None else {}
        self.diffed_lines = set()
        self._get_diffed_patches()

    @property
    def commands(self):
        """Names of all commands in the command table."""
        return self._command_loader.command_table.keys()

    @property
    def command_groups(self):
        """All command group names, including every parent prefix, computed lazily and cached."""
        if not self._command_groups:
            added_command_groups = set()
            for command_group in self._command_loader.command_group_table.keys():
                prefix_name = ""
                for word in command_group.split():
                    prefix_name = "{} {}".format(prefix_name, word).strip()
                    if prefix_name in added_command_groups:
                        # the parent command group was already recorded
                        continue
                    added_command_groups.add(prefix_name)
                    self._command_groups.append(prefix_name)
        return self._command_groups

    @property
    def help_file_entries(self):
        """Set of names of the YAML help entries."""
        return self._help_file_entries

    @property
    def command_parser(self):
        """Parser taken from the loader's CLI context invocation."""
        return self._command_parser

    @property
    def command_loader_map(self):
        """The loader's ``cmd_to_loader_map``."""
        return self._command_loader.cmd_to_loader_map

    def get_command_metadata(self, command_name):
        """Return the command table entry for ``command_name`` or ``None`` if absent."""
        try:
            return self._command_loader.command_table[command_name]
        except KeyError:
            return None

    def get_command_parameters(self, command_name):
        """Return the set of argument names of ``command_name`` or ``None`` if unknown."""
        return self._parameters.get(command_name)

    def get_command_group_metadata(self, command_group_name):
        """Return the command group table entry for ``command_group_name`` or ``None`` if absent."""
        try:
            return self._command_loader.command_group_table[command_group_name]
        except KeyError:
            return None

    def get_help_entry_type(self, entry_name):
        """Return the ``type`` field of the YAML help entry ``entry_name``."""
        return self._all_yaml_help.get(entry_name).get('type')

    def get_help_entry_examples(self, entry_name):
        """Return the ``examples`` list of the YAML help entry (empty list if missing)."""
        return self._all_yaml_help.get(entry_name).get('examples', [])

    def get_help_entry_parameter_names(self, entry_name):
        """Return the ``name`` of every parameter declared in the YAML help entry."""
        return [param_help.get('name', None) for param_help in
                self._all_yaml_help.get(entry_name).get('parameters', [])]

    def is_valid_parameter_help_name(self, entry_name, param_name):
        """Tell whether ``param_name`` equals the name of a parameter of the loaded help entry."""
        return param_name in [param.name for param in getattr(self._loaded_help.get(entry_name), 'parameters', [])]

    def get_command_help(self, command_name):
        """Return the short (or else long) summary of a command's loaded help."""
        return self._get_loaded_help_description(command_name)

    def get_command_group_help(self, command_group_name):
        """Return the short (or else long) summary of a command group's loaded help."""
        return self._get_loaded_help_description(command_group_name)

    def get_parameter_options(self, command_name, parameter_name):
        """Return the ``options_list`` setting of the given parameter."""
        return self.get_command_metadata(command_name).arguments.get(parameter_name).type.settings.get('options_list')

    def _lookup_parameter_help(self, command_name, parameter_name):
        """Return ``(command_help, param_help)``; either may be ``None`` when nothing is found."""
        options = self.get_parameter_options(command_name, parameter_name)
        command_help = self._loaded_help.get(command_name, None)
        if not command_help:
            return None, None
        # a help parameter matches when its space separated names share an element with the options
        param_help = next((param for param in command_help.parameters
                           if share_element(options, param.name.split())), None)
        return command_help, param_help

    def get_parameter_help(self, command_name, parameter_name):
        """Return the help text of a parameter, or ``None`` when the command has no loaded help."""
        command_help, param_help = self._lookup_parameter_help(command_name, parameter_name)
        if not command_help:
            return None
        # workaround for --ids, which does not generate doc help (BUG)
        if not param_help:
            return self.get_parameter_settings(command_name, parameter_name).get('help')
        return param_help.short_summary or param_help.long_summary

    def get_parameter_settings(self, command_name, parameter_name):
        """Return the ``type.settings`` of the given parameter."""
        return self.get_command_metadata(command_name).arguments.get(parameter_name).type.settings

    def get_parameter_help_info(self, command_name, parameter_name):
        """Return the loaded help object of a parameter, or ``None`` if not found."""
        _, param_help = self._lookup_parameter_help(command_name, parameter_name)
        return param_help

    def command_expired(self, command_name):
        """Tell whether the command carries deprecation info that has expired."""
        deprecate_info = self._command_loader.command_table[command_name].deprecate_info
        if deprecate_info:
            return deprecate_info.expired()
        return False

    def command_group_expired(self, command_group_name):
        """Tell whether the command group carries deprecation info that has expired."""
        try:
            group_kwargs = self._command_loader.command_group_table[command_group_name].group_kwargs
            deprecate_info = group_kwargs.get('deprecate_info', None)
            if deprecate_info:
                return deprecate_info.expired()
        except KeyError:
            # ignore command_group_name which is not in command_group_table.
            pass
        except AttributeError:
            # Items with only token presence in the command table will not have any data. They can't be expired.
            pass
        return False

    def parameter_expired(self, command_name, parameter_name):
        """Tell whether the parameter carries deprecation info that has expired."""
        parameter = self._command_loader.command_table[command_name].arguments[parameter_name].type.settings
        deprecate_info = parameter.get('deprecate_info', None)
        if deprecate_info:
            return deprecate_info.expired()
        return False

    def option_expired(self, command_name, parameter_name):
        """Return the targets of every expired ``Deprecated`` option of the parameter."""
        from knack.deprecation import Deprecated
        parameter = self._command_loader.command_table[command_name].arguments[parameter_name].type.settings
        return [opt.target for opt in parameter.get('options_list', [])
                if isinstance(opt, Deprecated) and opt.expired()]

    def _get_loaded_help_description(self, entry):
        """Return the short (or else long) summary of a loaded help entry; falsy entries are returned as is."""
        help_entry = self._loaded_help.get(entry, None)
        if help_entry:
            return help_entry.short_summary or help_entry.long_summary
        return help_entry

    def get_loaded_help_entry(self, entry):
        """Return the loaded help object for ``entry`` or ``None``."""
        return self._loaded_help.get(entry, None)

    def get_command_test_coverage(self):
        """Check that every newly added command shows up in the tested commands.

        :return: ``(exec_state, violations)`` as produced by ``_run_command_test_coverage``.
        """
        diff_index = diff_branches_detail(repo=self.git_repo, target=self.git_target, source=self.git_source)
        commands, _ = self._detect_new_command(diff_index)
        all_tested_command = self._detect_tested_command(diff_index)
        return self._run_command_test_coverage(commands, all_tested_command)

    def get_parameter_test_coverage(self):
        """Check that every newly detected parameter shows up in the tested commands.

        :return: ``(exec_state, violations)`` as produced by ``_run_parameter_test_coverage``.
        """
        diff_index = diff_branches_detail(repo=self.git_repo, target=self.git_target, source=self.git_source)
        _, parameters = self._detect_new_command(diff_index)
        all_tested_command = self._detect_tested_command(diff_index)
        return self._run_parameter_test_coverage(parameters, all_tested_command)

    def check_missing_command_example(self):
        """Return violation messages for modified commands that have fewer examples than required."""
        _exclude_commands = self._get_cmd_exclusions(rule_name="missing_command_example")
        cmd_example_config = get_cmd_example_configurations()
        violations = []
        for cmd in self._detect_modified_command():
            if cmd in _exclude_commands:
                continue
            cmd_help = self._loaded_help.get(cmd, None)
            if not cmd_help:
                continue
            # the threshold is looked up by the last word of the command
            cmd_example_threshold = get_cmd_example_threshold(cmd.split()[-1], cmd_example_config)
            if cmd_example_threshold == 0:
                continue
            if not hasattr(cmd_help, "parameters") or len(cmd_help.parameters) == 0:
                # skip cmd without parameters
                continue
            if not hasattr(cmd_help, "examples") or len(cmd_help.examples) < cmd_example_threshold:
                violations.append(f'Command `{cmd}` should have at least {cmd_example_threshold} example(s)')
        if violations:
            violations.insert(0, 'Check command example failed.')
            violations.append('Please add examples for the modified command or add the command in rule_exclusions: missing_command_example in linter_exclusions.yml')
        return violations

    def _get_exclusions(self):
        """Return ``(commands, (command, parameter) pairs)`` excluded from the test coverage rules."""
        _exclude_commands = set()
        _exclude_parameters = set()
        for command, details in self.exclusions.items():
            for param, rules in details.get('parameters', {}).items():
                # tolerate parameter entries without a ``rule_exclusions`` key
                if 'missing_parameter_test_coverage' in rules.get('rule_exclusions', ()):
                    _exclude_parameters.add((command, param))
            if 'missing_command_test_coverage' in details.get('rule_exclusions', ()):
                _exclude_commands.add(command)
        _logger.debug('exclude_parameters: %s', _exclude_parameters)
        _logger.debug('exclude_comands: %s', _exclude_commands)
        return _exclude_commands, _exclude_parameters

    def _get_cmd_exclusions(self, rule_name=None):
        """Return the set of commands whose ``rule_exclusions`` contain ``rule_name``."""
        _exclude_commands = set()
        if not rule_name:
            return _exclude_commands
        for command, details in self.exclusions.items():
            if 'rule_exclusions' in details and rule_name in details['rule_exclusions']:
                _exclude_commands.add(command)
        _logger.debug('exclude_commands: %s', _exclude_commands)
        return _exclude_commands

    def _split_path(self, path: str):
        """Split ``path`` at its last ``/`` into ``(directory, filename)``; directory is ``''`` if there is none."""
        parts = path.rsplit('/', maxsplit=1)
        return parts if len(parts) == 2 else ('', parts[0])

    def _read_blob_lines(self, blob):
        """Return the UTF-8 decoded lines (line endings kept) of a git blob, or ``[]`` for a falsy blob."""
        return blob.data_stream.read().decode("utf-8").splitlines(True) if blob else []

    def _get_line_number(self, lines: List[str], row_num: int, pattern: str):
        """Walk upwards from ``row_num`` to the nearest line matching ``pattern``.

        :return: the first captured number of that line plus the count of lines skipped on
            the way up, or ``-1`` if no line above ``row_num`` matches.
        """
        offset = -1
        while row_num > 0:
            row_num -= 1
            match = re.findall(pattern, lines[row_num])
            offset += 1
            if match:
                return int(match[0]) + offset
        return -1

    def _extract_parameters(self, lines, current_lines, _exclude_commands, _exclude_parameters, parameters):
        """Append ``(command, options)`` for every non-excluded argument found in the diff ``lines``."""
        for row_num, line in enumerate(lines):
            params, param_name = search_argument(line)
            if not params:
                continue
            idx = self._get_line_number(lines, row_num, r'--- (\d+),(?:\d+) ----')
            for cmd in search_argument_context(idx, current_lines):
                if cmd not in _exclude_commands and (cmd, param_name) not in _exclude_parameters:
                    parameters.append((cmd, params))
                    _logger.debug('Detected parameter: [%s, %s]', cmd, params)
        return parameters

    def _extract_commands(self, lines, original_lines, current_lines, added_commands,
                          deleted_commands, _exclude_commands, yellow_color):
        """Collect added and deleted commands found in the context-diff ``lines``."""
        for row_num, line in enumerate(lines):
            added_command = search_command(line)
            deleted_command = search_deleted_command(line)
            if added_command:
                idx = self._get_line_number(lines, row_num, r'--- (\d+),(?:\d+) ----')
                cmd = search_command_group(idx, current_lines, added_command)
                if cmd:
                    if cmd in _exclude_commands:
                        _logger.warning('%sCommand %s not tested and excluded in linter_exclusions.yml',
                                        yellow_color, cmd)
                    else:
                        added_commands.add(cmd)
            elif deleted_command:
                idx = self._get_line_number(lines, row_num, r'\*\*\* (\d+),(?:\d+) \*\*\*\*')
                cmd = search_command_group(idx, original_lines, deleted_command)
                if cmd:
                    deleted_commands.add(cmd)
        return added_commands, deleted_commands

    def _detect_new_command(self, diff_index: List) -> Tuple[List[str], List[Tuple[str, str]]]:
        """Scan changed ``*params*`` / ``*commands*`` files for new commands and parameters.

        :return: ``(commands, parameters)`` where ``commands`` are added-but-not-deleted commands.
        """
        YELLOW = '\x1b[33m'
        _exclude_commands, _exclude_parameters = self._get_exclusions()
        added_commands, deleted_commands, parameters = set(), set(), []
        for diff in diff_index:
            _, filename = self._split_path(diff.a_path)
            if 'params' not in filename and 'commands' not in filename:
                continue
            original_lines = self._read_blob_lines(diff.a_blob)
            current_lines = self._read_blob_lines(diff.b_blob)
            lines = list(context_diff(original_lines, current_lines, 'Original', 'Current'))
            if 'params.py' in filename:
                parameters = self._extract_parameters(lines, current_lines, _exclude_commands,
                                                      _exclude_parameters, parameters)
            if 'commands.py' in filename:
                added_commands, deleted_commands = self._extract_commands(lines, original_lines, current_lines,
                                                                          added_commands, deleted_commands,
                                                                          _exclude_commands, YELLOW)
        commands = list(added_commands - deleted_commands)
        _logger.debug('New parameters: %s', parameters)
        _logger.debug('Added commands: %s', added_commands)
        _logger.debug('Deleted commands: %s', deleted_commands)
        _logger.debug('Final commands: %s', commands)
        return commands, parameters

    def _detect_tested_command(self, diff_index):
        """Collect tested command strings from changed ``test_*.py`` and ``test_*.yaml`` files."""
        all_tested_command = []
        for diff in diff_index:
            filename = diff.a_path.split('/')[-1]
            is_test_script = bool(re.findall(r'^test_.*\.py$', filename))
            is_recording = bool(re.findall(r'^test_.*\.yaml$', filename))
            if not (is_test_script or is_recording):
                continue
            full_path = os.path.join(self.git_repo, diff.a_path)
            if not os.path.exists(full_path):
                continue
            # get tested command by regex
            if is_test_script:
                with open(full_path, encoding='utf-8') as f:
                    lines = f.readlines()
                all_tested_command += get_all_tested_commands_from_regex(lines)
            # get tested command by recording file
            if is_recording:
                with open(full_path, encoding='utf-8') as f:
                    # NOTE(review): yaml.Loader can construct arbitrary Python objects; consider
                    # yaml.safe_load if recordings can ever come from an untrusted source.
                    records = yaml.load(f, Loader=yaml.Loader) or {}
                # an empty recording has no 'interactions' key
                for record in records.get('interactions') or []:
                    headers = record['request']['headers']
                    # parse command ['acr agentpool create']
                    command = headers.get('CommandName', [''])[0]
                    # parse argument ['-n -r']
                    argument = headers.get('ParameterSetName', [''])[0]
                    if command or argument:
                        all_tested_command.append(command + ' ' + argument)
        _logger.debug('All tested command: %s', all_tested_command)
        return all_tested_command

    @staticmethod
    def _run_command_test_coverage(commands, all_tested_command):
        """Report every command that is not a substring of any tested command string.

        :return: ``(exec_state, violations)``; ``exec_state`` is ``False`` when something is missing.
        """
        violations = [f'Missing command test coverage: `{command}`'
                      for command in commands
                      if not any(command in code for code in all_tested_command)]
        exec_state = not violations
        if violations:
            violations.insert(0, 'Failed.')
            violations.extend([
                'Please add some scenario tests for the new command',
                'Or add the command with missing_command_test_coverage rule in linter_exclusions.yml'])
        return exec_state, violations

    @staticmethod
    def _run_parameter_test_coverage(parameters, all_tested_command):
        """Report every parameter for which none of its options is tested together with its command.

        A parameter counts as covered when at least one of its option names appears in a
        tested command string that also contains the command. Coverage is evaluated
        independently for each parameter.

        :param parameters: iterable of ``(command, option_names)``.
        :param all_tested_command: iterable of tested command strings.
        :return: ``(exec_state, violations)``; ``exec_state`` is ``False`` when something is missing.
        """
        violations = []
        for command, opt_list in parameters:
            options = list(opt_list)
            covered = any(command in code and opt in code
                          for opt in options
                          for code in all_tested_command)
            if not covered:
                # keep the historic message format, which names the last option of the parameter
                reported = options[-1] if options else ''
                violations.append(f'Missing parameter test coverage: `{command} {reported}`')
        exec_state = not violations
        if violations:
            violations.insert(0, 'Failed.')
            violations.extend([
                'Please add some scenario tests for the new parameter',
                'Or add the parameter with missing_parameter_test_coverage rule in linter_exclusions.yml'])
        return exec_state, violations

    def _detect_modified_command(self):
        """Return the commands touched by added lines in ``commands.py`` files or files under an ``aaz`` path."""
        modified_commands = set()
        diff_patches = diff_branch_file_patch(repo=self.git_repo, target=self.git_target, source=self.git_source)
        for change in diff_patches:
            if not change.b_path or not change.diff:
                continue
            file_path, filename = self._split_path(change.b_path)
            if "commands.py" not in filename and "aaz" not in file_path:
                continue
            current_lines = self._read_blob_lines(change.b_blob)
            patch = change.diff.decode("utf-8")
            patch_lines = patch.splitlines()
            if 'commands.py' in filename:
                for row_num, line in enumerate(patch_lines):
                    # only added lines, not the '+++' file header
                    if not line.startswith('+') or line.startswith('+++'):
                        continue
                    if aaz_custom_command := search_aaz_custom_command(line):
                        modified_commands.add(aaz_custom_command)
                    manual_command_suffix = search_command(line)
                    if manual_command_suffix:
                        idx = self._get_line_number(patch_lines, row_num, r'@@ -(\d+),(?:\d+) \+(?:\d+),(?:\d+) @@')
                        manual_command = search_command_group(idx, current_lines, manual_command_suffix)
                        if manual_command:
                            modified_commands.add(manual_command)
            if "aaz" in file_path:
                if aaz_raw_command := search_aaz_raw_command(patch):
                    modified_commands.add(aaz_raw_command)
        commands = list(modified_commands)
        _logger.debug('Modified commands: %s', modified_commands)
        return commands

    def _get_diffed_patches(self):
        """Store every added patch line in ``self.diffed_lines``; no-op unless source, target and repo are all set."""
        if not self.git_source or not self.git_target or not self.git_repo:
            return
        diff_patches = diff_branch_file_patch(repo=self.git_repo, target=self.git_target, source=self.git_source)
        for change in diff_patches:
            if not change.diff:
                continue
            patch = change.diff.decode("utf-8")
            added_lines = [line for line in patch.splitlines() if line.startswith('+') and not line.startswith('+++')]
            self.diffed_lines |= set(added_lines)
            if added_lines:
                _logger.info("Changes in file '%s':", change.a_path)
                for line in added_lines:
                    # pass the line as an argument so it is never treated as a format string
                    _logger.info('%s', line)
# pylint: disable=too-many-instance-attributes
class LinterManager:
    """Registers linter rules, runs them by rule type and records the violations they report."""

    _RULE_TYPES = {'help_file_entries', 'command_groups', 'commands', 'params', 'command_test_coverage'}

    def __init__(self, command_loader=None, help_file_entries=None, loaded_help=None, exclusions=None,
                 rule_inclusions=None, use_ci_exclusions=None, min_severity=None, update_global_exclusion=None,
                 git_source=None, git_target=None, git_repo=None):
        """Create the main ``Linter`` and the empty rule registry.

        :param command_loader: forwarded to ``Linter``.
        :param help_file_entries: forwarded to ``Linter``.
        :param loaded_help: forwarded to ``Linter``.
        :param exclusions: mapping of command name -> exclusion details (``None`` -> empty).
        :param rule_inclusions: when truthy, only rules whose name is contained are registered.
        :param use_ci_exclusions: whether CI exclusions apply; ``None`` falls back to the ``CI`` env variable.
        :param min_severity: rules below this severity are skipped; defaults to the highest severity.
        :param update_global_exclusion: ``'CLI'`` or another non-``None`` value to write the
            violations into the CLI / extension repos' ``linter_exclusions.yml`` after a run.
        :param git_source: forwarded to ``Linter``.
        :param git_target: forwarded to ``Linter``.
        :param git_repo: forwarded to ``Linter``.
        """
        # default to running only rules of the highest severity
        self.min_severity = min_severity or LinterSeverity.get_ordered_members()[-1]
        self._exclusions = exclusions or {}
        self.linter = Linter(command_loader=command_loader, help_file_entries=help_file_entries,
                             loaded_help=loaded_help, git_source=git_source, git_target=git_target, git_repo=git_repo,
                             exclusions=self._exclusions)
        self._rules = {rule_type: {} for rule_type in LinterManager._RULE_TYPES}  # initialize empty rules
        self._ci_exclusions = {}
        self._rule_inclusions = rule_inclusions
        self._loaded_help = loaded_help
        self._command_loader = command_loader
        self._help_file_entries = help_file_entries
        self._exit_code = 0
        self._ci = use_ci_exclusions if use_ci_exclusions is not None else os.environ.get('CI', False)
        self._violiations = {}
        self._update_global_exclusion = update_global_exclusion

    def add_rule(self, rule_type, rule_name, rule_callable, rule_severity):
        """Register a rule under ``rule_type`` unless it is filtered out by ``rule_inclusions``.

        The rule is stored as ``(rule_callable, get_linter, rule_severity)``; ``get_linter`` is
        evaluated when the rule runs.
        """
        include_rule = not self._rule_inclusions or rule_name in self._rule_inclusions
        if rule_type in self._rules and include_rule:
            def get_linter():
                # if a rule has exclusions return a linter that factors in those exclusions
                # otherwise return the main linter.
                if rule_name in self._ci_exclusions and self._ci:
                    mod_exclusions = self._ci_exclusions[rule_name]
                    command_loader, help_file_entries = exclude_commands(
                        self._command_loader,
                        self._help_file_entries,
                        mod_exclusions)
                    return Linter(command_loader=command_loader, help_file_entries=help_file_entries,
                                  loaded_help=self._loaded_help)
                return self.linter

            self._rules[rule_type][rule_name] = rule_callable, get_linter, rule_severity

    def mark_rule_failure(self, rule_severity):
        """Set the exit code to 1 when a rule of HIGH severity failed."""
        if rule_severity is LinterSeverity.HIGH:
            self._exit_code = 1

    @property
    def exclusions(self):
        """The exclusions mapping given at construction (empty dict if none)."""
        return self._exclusions

    @property
    def exit_code(self):
        """0 unless a HIGH severity failure was marked."""
        return self._exit_code

    def run(self, run_params=None, run_commands=None, run_command_groups=None,
            run_help_files_entries=None, run_command_test_coverage=None):
        """Discover the rules, run the requested rule types and optionally persist the violations.

        :return: the exit code (see ``exit_code``).
        :raises LinterError: if two rule functions share the same name.
        """
        paths = import_module('{}.rules'.format(PACKAGE_NAME)).__path__
        if paths:
            ci_exclusions_path = os.path.join(paths[0], 'ci_exclusions.yml')
            with open(ci_exclusions_path, encoding='utf-8') as f:
                self._ci_exclusions = yaml.safe_load(f) or {}

        # find all defined rules and check for name conflicts
        found_rules = set()
        for _, name, _ in iter_modules(paths):
            rule_module = import_module('{}.rules.{}'.format(PACKAGE_NAME, name))
            for rule_name, add_to_linter_func in inspect.getmembers(rule_module, inspect.isfunction):
                if not hasattr(add_to_linter_func, 'linter_rule'):
                    continue
                if rule_name in found_rules:
                    raise LinterError('Multiple rules found with the same name: %s' % rule_name)
                found_rules.add(rule_name)
                add_to_linter_func(self)

        # run all rule-checks, in this fixed order
        requested = (
            (run_help_files_entries, 'help_file_entries'),
            (run_command_groups, 'command_groups'),
            (run_commands, 'commands'),
            (run_params, 'params'),
            (run_command_test_coverage, 'command_test_coverage'),
        )
        for enabled, rule_group in requested:
            if enabled and self._rules.get(rule_group):
                self._run_rules(rule_group)

        if not self.exit_code:
            print(os.linesep + 'No violations found for linter rules.')

        if self._update_global_exclusion is not None:
            if self._update_global_exclusion == 'CLI':
                repo_paths = [get_cli_repo_path()]
            else:
                repo_paths = get_ext_repo_paths()
            for repo_path in repo_paths:
                exclusion_path = os.path.join(repo_path, 'linter_exclusions.yml')
                exclusions = {}
                # a missing file is treated as empty; it is created by the write below
                if os.path.isfile(exclusion_path):
                    with open(exclusion_path, encoding='utf-8') as f:
                        exclusions = yaml.safe_load(f) or {}
                exclusions.update(self._violiations)
                with open(exclusion_path, 'w', encoding='utf-8') as f:
                    yaml.safe_dump(exclusions, f)
        return self.exit_code

    def _run_rules(self, rule_group):
        """Run every registered rule of ``rule_group``, print the outcome and record the violations."""
        # https://docs.microsoft.com/en-us/windows/console/console-virtual-terminal-sequences#text-formatting
        RED = '\x1b[31m'
        GREEN = '\x1b[32m'
        YELLOW = '\x1b[33m'
        CYAN = '\x1b[36m'
        RESET = '\x1b[39m'
        severity_colors = {LinterSeverity.HIGH: RED, LinterSeverity.MEDIUM: YELLOW}
        for rule_name, (rule_func, linter_callable, rule_severity) in self._rules.get(rule_group).items():
            severity_str = rule_severity.name
            # use new linter if needed
            with LinterScope(self, linter_callable):
                # if the rule's severity is lower than the linter's severity skip it.
                if not self._linter_severity_is_applicable(rule_severity, rule_name):
                    continue
                # a rule yielding nothing (None) is treated as having no violations
                violations = sorted(rule_func() or [])
                if not violations:
                    print('- {} pass{}: {} '.format(GREEN, RESET, rule_name))
                    continue
                sev_color = severity_colors.get(rule_severity, CYAN)
                # pylint: disable=duplicate-string-formatting-argument
                print('- {} FAIL{} - {}{}{} severity: {}'.format(RED, RESET, sev_color,
                                                                 severity_str, RESET, rule_name, ))
                for violation_msg, entity_name, name in violations:
                    print(violation_msg)
                    self._save_violations(entity_name, name)
                print()

    def _linter_severity_is_applicable(self, rule_severity, rule_name):
        """Return ``False`` (and log why) when ``rule_severity`` is below ``self.min_severity``."""
        if self.min_severity.value > rule_severity.value:
            # report both severities by name so the message is consistent
            _logger.info("Skipping rule %s, because its severity '%s' is lower than the linter's min severity of '%s'.",
                         rule_name, rule_severity.name, self.min_severity.name)
            return False
        return True

    # pylint: disable=line-too-long
    def _save_violations(self, entity_name, rule_name):
        """Record ``rule_name`` as a rule exclusion for a command, or for a ``(command, parameter)`` pair.

        :param entity_name: a command name string, or a ``(command_name, param_name)`` pair.
        :param rule_name: name of the violated rule.
        """
        if isinstance(entity_name, str):
            target = self._violiations.setdefault(entity_name, {})
        else:
            command_name, param_name = entity_name
            command_entry = self._violiations.setdefault(command_name, {})
            target = command_entry.setdefault('parameters', {}).setdefault(param_name, {})
        target.setdefault('rule_exclusions', []).append(rule_name)
class RuleError(Exception):
    """Raised to signal a linter rule violation."""
class LinterScope:
    """
    Context manager wrapped around a rule function call. While active, the manager's ``linter``
    is the one produced by ``linter_callable``; on exit the manager's previous linter is restored.
    """

    def __init__(self, linter_manager, linter_callable):
        self.linter_manager = linter_manager
        # build the scoped linter first, then remember the one that is currently installed
        self.linter = linter_callable()
        self.main_linter = linter_manager.linter

    def _install(self, linter):
        """Make ``linter`` the manager's active linter."""
        self.linter_manager.linter = linter

    def __enter__(self):
        self._install(self.linter)

    def __exit__(self, exc_type, value, traceback):
        self._install(self.main_linter)