detection_rules/config.py (223 lines of code) (raw):

# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.

"""Configuration support for custom components."""
import fnmatch
import os
from dataclasses import dataclass, field
from pathlib import Path
from functools import cached_property
from typing import Dict, List, Optional

import yaml
from eql.utils import load_dump

from .misc import discover_tests
from .utils import cached, load_etc_dump, get_etc_path, set_all_validation_bypass

# Repository root: two levels above this module's file.
ROOT_DIR = Path(__file__).parent.parent
# Read once at import time; when set, the rules config is loaded from this directory.
CUSTOM_RULES_DIR = os.getenv('CUSTOM_RULES_DIR', None)


@dataclass
class UnitTest:
    """Base object for unit tests configuration.

    Exactly one of ``bypass`` (tests to skip) or ``test_only`` (the only tests to run)
    may be provided; leaving both as ``None`` means no selection is applied.
    """

    bypass: Optional[List[str]] = None
    test_only: Optional[List[str]] = None

    def __post_init__(self):
        # The check is on ``None`` rather than truthiness: an explicit empty list still counts as "set".
        both_defined = self.bypass is not None and self.test_only is not None
        assert not both_defined, 'Cannot set both `test_only` and `bypass` in test_config!'
@dataclass
class RuleValidation:
    """Base object for rule validation configuration."""

    # rule IDs whose validation is skipped (see TestConfig.check_skip_by_rule_id)
    bypass: Optional[List[str]] = None
    # when non-empty, only these rule IDs are validated
    test_only: Optional[List[str]] = None

    def __post_init__(self):
        # Unlike UnitTest, this uses truthiness: empty lists do not conflict.
        assert not (self.bypass and self.test_only), 'Cannot use both test_only and bypass'


@dataclass
class ConfigFile:
    """Base object for configuration files.

    Used by ``parse_rules_config`` to validate the shape of a loaded ``_config.yaml``.
    """

    @dataclass
    class FilePaths:
        packages_file: str
        stack_schema_map_file: str
        deprecated_rules_file: Optional[str] = None
        version_lock_file: Optional[str] = None

    @dataclass
    class TestConfigPath:
        config: str

    files: FilePaths
    rule_dir: List[str]
    testing: Optional[TestConfigPath] = None

    @classmethod
    def from_dict(cls, obj: dict) -> 'ConfigFile':
        """Build a ConfigFile from the raw ``_config.yaml`` mapping.

        :param obj: Parsed YAML mapping.
        :return: The validated ConfigFile.
        :raises KeyError: If ``files.packages``, ``files.stack_schema_map``, ``rule_dirs``
            or ``testing.config`` (when ``testing`` is given) is missing.
        """
        files_data = obj.get('files', {})
        files = cls.FilePaths(
            deprecated_rules_file=files_data.get('deprecated_rules'),
            packages_file=files_data['packages'],
            stack_schema_map_file=files_data['stack_schema_map'],
            version_lock_file=files_data.get('version_lock')
        )
        rule_dir = obj['rule_dirs']

        testing_data = obj.get('testing')
        testing = cls.TestConfigPath(
            config=testing_data['config']
        ) if testing_data else None

        return cls(files=files, rule_dir=rule_dir, testing=testing)


@dataclass
class TestConfig:
    """Detection rules test config file.

    Describes which unit tests and which rules (by rule ID) should be run or skipped.
    """

    # path of the test config file this object was loaded from, if any
    test_file: Optional[Path] = None
    # unit-test selection: ``bypass`` / ``test_only`` lists of test names or ``pattern:`` globs
    unit_tests: Optional[UnitTest] = None
    # rule-validation selection: ``bypass`` / ``test_only`` lists of rule IDs
    rule_validation: Optional[RuleValidation] = None

    @classmethod
    def from_dict(cls, test_file: Optional[Path] = None, unit_tests: Optional[dict] = None,
                  rule_validation: Optional[dict] = None) -> 'TestConfig':
        """Build a TestConfig; missing sections become empty (no-selection) objects."""
        return cls(test_file=test_file or None,
                   unit_tests=UnitTest(**unit_tests or {}),
                   rule_validation=RuleValidation(**rule_validation or {}))

    @cached_property
    def all_tests(self):
        """Get the list of all test names (computed once per instance via ``discover_tests``)."""
        return discover_tests()

    def tests_by_patterns(self, *patterns: str) -> List[str]:
        """Get the sorted, de-duplicated test names matching any of the fnmatch ``patterns``."""
        tests = set()
        for pattern in patterns:
            tests.update(fnmatch.filter(self.all_tests, pattern))
        return sorted(tests)

    @staticmethod
    def parse_out_patterns(names: List[str]) -> (List[str], List[str]):
        """Split names into (patterns, literal test names).

        A name of the form ``pattern:<glob>`` that contains ``*`` is treated as a pattern
        (the ``pattern:`` prefix is stripped); every other entry is a literal test name.
        """
        patterns = []
        tests = []
        for name in names:
            if name.startswith('pattern:') and '*' in name:
                patterns.append(name[len('pattern:'):])
            else:
                tests.append(name)
        return patterns, tests

    @staticmethod
    def format_tests(tests: List[str]) -> List[str]:
        """Format unit test names into expected format for direct calling.

        ``a.b.module.Class.method`` becomes ``a<sep>b<sep>module.py::Class::method``,
        where ``<sep>`` is ``os.path.sep``.
        """
        raw = [t.rsplit('.', maxsplit=2) for t in tests]
        formatted = []
        for test in raw:
            path, clazz, method = test
            path = f'{path.replace(".", os.path.sep)}.py'
            formatted.append('::'.join([path, clazz, method]))
        return formatted

    def get_test_names(self, formatted: bool = False) -> (List[str], List[str]):
        """Get the list of test names to run.

        :param formatted: If True, return names in ``path.py::Class::method`` form.
        :return: A ``(tests_to_run, skipped_tests)`` pair.
        :raises AssertionError: If the config names a literal test that does not exist.
        """
        patterns_t, tests_t = self.parse_out_patterns(self.unit_tests.test_only or [])
        patterns_b, tests_b = self.parse_out_patterns(self.unit_tests.bypass or [])

        defined_tests = tests_t + tests_b
        patterns = patterns_t + patterns_b
        unknowns = sorted(set(defined_tests) - set(self.all_tests))
        assert not unknowns, f'Unrecognized test names in config ({self.test_file}): {unknowns}'

        combined_tests = sorted(set(defined_tests + self.tests_by_patterns(*patterns)))
        # Set built once so the membership checks below are O(1) instead of scanning a list.
        selected = set(combined_tests)

        if self.unit_tests.test_only is not None:
            tests = combined_tests
            skipped = [t for t in self.all_tests if t not in selected]
        elif self.unit_tests.bypass:
            tests = [t for t in self.all_tests if t not in selected]
            skipped = [t for t in self.all_tests if t in selected]
        else:
            tests = self.all_tests
            skipped = []

        if formatted:
            return self.format_tests(tests), self.format_tests(skipped)
        return tests, skipped

    def check_skip_by_rule_id(self, rule_id: str) -> bool:
        """Check if a rule_id should be skipped.

        A rule is skipped when it is listed in ``bypass``, or when ``test_only`` is non-empty
        and the rule is not listed in it. Always returns a real ``bool``.
        """
        bypass = self.rule_validation.bypass
        test_only = self.rule_validation.test_only

        # neither bypass nor test_only are defined, so no rules are skipped
        if not (bypass or test_only):
            return False

        # explicitly bypassed
        if bypass and rule_id in bypass:
            return True

        # restricted to test_only and this rule is not in it
        return bool(test_only) and rule_id not in test_only


@dataclass
class RulesConfig:
    """Detection rules config file."""

    deprecated_rules_file: Path
    deprecated_rules: Dict[str, dict]
    packages_file: Path
    packages: Dict[str, dict]
    rule_dirs: List[Path]
    stack_schema_map_file: Path
    stack_schema_map: Dict[str, dict]
    test_config: TestConfig
    version_lock_file: Path
    version_lock: Dict[str, dict]

    action_dir: Optional[Path] = None
    action_connector_dir: Optional[Path] = None
    auto_gen_schema_file: Optional[Path] = None
    bbr_rules_dirs: Optional[List[Path]] = field(default_factory=list)
    bypass_version_lock: bool = False
    exception_dir: Optional[Path] = None
    normalize_kql_keywords: bool = True
    bypass_optional_elastic_validation: bool = False
    no_tactic_filename: bool = False

    def __post_init__(self):
        """Perform post validation on packages.yaml file."""
        if 'package' not in self.packages:
            raise ValueError('Missing the `package` field defined in packages.yaml.')

        if 'name' not in self.packages['package']:
            raise ValueError('Missing the `name` field defined in packages.yaml.')


def _load_test_config(loaded: dict, base_dir: Path) -> TestConfig:
    """Resolve and load the test config referenced by the environment or ``_config.yaml``.

    The ``DETECTION_RULES_TEST_CONFIG`` environment variable takes precedence and is used
    as-is; otherwise ``testing.config`` from ``_config.yaml`` is resolved relative to ``base_dir``.
    Returns an empty (no-selection) TestConfig when neither is set.
    """
    test_config_ev = os.getenv('DETECTION_RULES_TEST_CONFIG', None)
    if test_config_ev:
        test_config_path = Path(test_config_ev)
    else:
        # `testing:` may be present but null in YAML, so fall back to {} on any falsy value
        test_config_file = (loaded.get('testing') or {}).get('config')
        test_config_path = base_dir.joinpath(test_config_file) if test_config_file else None

    if not test_config_path:
        return TestConfig.from_dict()

    # an empty YAML document loads as None; treat it as an empty mapping
    test_config_data = yaml.safe_load(test_config_path.read_text()) or {}

    # overwrite None with empty list to allow implicit exemption of all tests with `test_only` defined to None in
    # test config
    unit_tests = test_config_data.get('unit_tests')
    if unit_tests is not None:
        test_config_data['unit_tests'] = {k: v or [] for k, v in unit_tests.items()}

    return TestConfig.from_dict(test_file=test_config_path, **test_config_data)


@cached
def parse_rules_config(path: Optional[Path] = None) -> RulesConfig:
    """Parse the _config.yaml file for default or custom rules.

    The config file is chosen in this order: ``path`` if given; otherwise
    ``<CUSTOM_RULES_DIR>/_config.yaml`` when ``CUSTOM_RULES_DIR`` was set at import time;
    otherwise the bundled ``_config.yaml`` located via ``get_etc_path``. Relative paths
    inside the config are resolved against the config file's directory.

    :param path: Optional explicit path to a ``_config.yaml`` file.
    :return: The populated RulesConfig.
    :raises FileNotFoundError: If ``CUSTOM_RULES_DIR`` is set but has no ``_config.yaml``.
    :raises SystemExit: If required keys are missing or the data cannot build a RulesConfig.
    """
    if path:
        assert path.exists(), f'rules config file does not exist: {path}'
        loaded = yaml.safe_load(path.read_text())
    elif CUSTOM_RULES_DIR:
        path = Path(CUSTOM_RULES_DIR) / '_config.yaml'
        if not path.exists():
            raise FileNotFoundError(
                """
                Configuration file not found.
                Please create a configuration file. You can use the 'custom-rules setup-config' command
                and update the 'CUSTOM_RULES_DIR' environment variable as needed.
                """
            )
        loaded = yaml.safe_load(path.read_text())
    else:
        path = Path(get_etc_path('_config.yaml'))
        loaded = load_etc_dump('_config.yaml')

    # validate the overall shape before reading individual keys below
    try:
        ConfigFile.from_dict(loaded)
    except KeyError as e:
        raise SystemExit(f'Missing key `{str(e)}` in _config.yaml file.')
    except (AttributeError, TypeError):
        raise SystemExit(f'No data properly loaded from {path}')
    except ValueError as e:
        raise SystemExit(e)

    base_dir = path.resolve().parent

    # testing
    test_config = _load_test_config(loaded, base_dir)

    # files
    # paths are relative; each file is also loaded and stored under its bare key
    files = {f'{k}_file': base_dir.joinpath(v) for k, v in loaded['files'].items()}
    contents = {k: load_dump(str(base_dir.joinpath(v).resolve())) for k, v in loaded['files'].items()}
    contents.update(**files)

    # directories (exception_dir, action_dir, action_connector_dir, ...)
    # paths are relative; null entries are skipped so the field keeps its default
    directories = loaded.get('directories')
    if directories:
        contents.update({k: base_dir.joinpath(v).resolve() for k, v in directories.items() if v is not None})

    # rule_dirs (presence already validated by ConfigFile.from_dict)
    # paths are relative
    contents['rule_dirs'] = [base_dir.joinpath(d).resolve() for d in loaded['rule_dirs']]

    # version strategy
    contents['bypass_version_lock'] = loaded.get('bypass_version_lock', False)

    # bbr_rules_dirs
    # paths are relative
    bbr_rules_dirs = loaded.get('bbr_rules_dirs')
    if bbr_rules_dirs:
        contents['bbr_rules_dirs'] = [base_dir.joinpath(d).resolve() for d in bbr_rules_dirs]

    # kql keyword normalization
    contents['normalize_kql_keywords'] = loaded.get('normalize_kql_keywords', True)

    if loaded.get('auto_gen_schema_file'):
        auto_gen_schema_file = base_dir.joinpath(loaded['auto_gen_schema_file'])
        contents['auto_gen_schema_file'] = auto_gen_schema_file

        # create an empty JSON object file (and parent dirs) if it does not exist yet
        if not auto_gen_schema_file.exists():
            auto_gen_schema_file.parent.mkdir(parents=True, exist_ok=True)
            auto_gen_schema_file.write_text('{}')

    # bypass_optional_elastic_validation
    contents['bypass_optional_elastic_validation'] = loaded.get('bypass_optional_elastic_validation', False)
    if contents['bypass_optional_elastic_validation']:
        set_all_validation_bypass(contents['bypass_optional_elastic_validation'])

    # no_tactic_filename
    contents['no_tactic_filename'] = loaded.get('no_tactic_filename', False)

    # return the config
    try:
        rules_config = RulesConfig(test_config=test_config, **contents)
    except (ValueError, TypeError) as e:
        raise SystemExit(f'Error parsing packages.yaml: {str(e)}')

    return rules_config


@cached
def load_current_package_version() -> str:
    """Load the current package version from config file (``packages['package']['name']``)."""
    return parse_rules_config().packages['package']['name']