detection_rules/config.py (221 lines of code) (raw):
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
"""Configuration support for custom components."""
import fnmatch
import os
from dataclasses import dataclass, field
from pathlib import Path
from functools import cached_property
from typing import Dict, List, Optional
import yaml
from eql.utils import load_dump
from .misc import discover_tests
from .utils import cached, load_etc_dump, get_etc_path, set_all_validation_bypass
# Value of the CUSTOM_RULES_DIR environment variable, or None when it is not set.
CUSTOM_RULES_DIR = os.environ.get('CUSTOM_RULES_DIR')

# Directory two levels above this module file.
ROOT_DIR = Path(__file__).parent.parent
@dataclass
class UnitTest:
    """Unit test selection: test names to skip (`bypass`) or the only ones to run (`test_only`)."""

    bypass: Optional[List[str]] = None
    test_only: Optional[List[str]] = None

    def __post_init__(self):
        """Reject a config that defines both lists; a list counts as defined unless it is None."""
        both_defined = self.bypass is not None and self.test_only is not None
        assert not both_defined, 'Cannot set both `test_only` and `bypass` in test_config!'
@dataclass
class RuleValidation:
    """Rule validation selection: rule ids to skip (`bypass`) or the only ones to validate (`test_only`)."""

    bypass: Optional[List[str]] = None
    test_only: Optional[List[str]] = None

    def __post_init__(self):
        """Reject a config where both lists are non-empty (empty lists and None are both tolerated)."""
        both_populated = bool(self.bypass) and bool(self.test_only)
        assert not both_populated, 'Cannot use both test_only and bypass'
@dataclass
class ConfigFile:
    """Structural view of a loaded config mapping, used to check that required keys exist."""

    @dataclass
    class FilePaths:
        # required entries
        packages_file: str
        stack_schema_map_file: str
        # optional entries
        deprecated_rules_file: Optional[str] = None
        version_lock_file: Optional[str] = None

    @dataclass
    class TestConfigPath:
        config: str

    files: FilePaths
    rule_dir: List[str]
    testing: Optional[TestConfigPath] = None

    @classmethod
    def from_dict(cls, obj: dict) -> 'ConfigFile':
        """Build a ConfigFile from a raw mapping.

        Raises KeyError when `rule_dirs`, `files.packages`, `files.stack_schema_map` or (if a `testing`
        section is present) `testing.config` is missing.
        """
        # Lookups happen in the same order as the keyword arguments were originally evaluated so that
        # malformed input surfaces the same exception.
        file_section = obj.get('files', {})
        deprecated_rules = file_section.get('deprecated_rules')
        packages = file_section['packages']
        stack_schema_map = file_section['stack_schema_map']
        version_lock = file_section.get('version_lock')
        file_paths = cls.FilePaths(
            packages_file=packages,
            stack_schema_map_file=stack_schema_map,
            deprecated_rules_file=deprecated_rules,
            version_lock_file=version_lock,
        )

        rule_dirs = obj['rule_dirs']

        testing_section = obj.get('testing')
        if testing_section:
            testing_paths = cls.TestConfigPath(config=testing_section['config'])
        else:
            testing_paths = None

        return cls(files=file_paths, rule_dir=rule_dirs, testing=testing_paths)
@dataclass
class TestConfig:
    """Detection rules test config file.

    Bundles the unit test selection (`unit_tests`) and the rule validation selection (`rule_validation`)
    with the path of the file they were loaded from (`test_file`).
    """

    test_file: Optional[Path] = None
    unit_tests: Optional[UnitTest] = None
    rule_validation: Optional[RuleValidation] = None

    @classmethod
    def from_dict(cls, test_file: Optional[Path] = None, unit_tests: Optional[dict] = None,
                  rule_validation: Optional[dict] = None) -> 'TestConfig':
        """Build a TestConfig from raw mappings.

        A missing or empty section falls back to a default `UnitTest` / `RuleValidation`, so both
        attributes are always populated on the returned instance.
        """
        return cls(test_file=test_file or None, unit_tests=UnitTest(**(unit_tests or {})),
                   rule_validation=RuleValidation(**(rule_validation or {})))

    @cached_property
    def all_tests(self):
        """Get the list of all test names, as returned by `discover_tests()` (computed once per instance)."""
        return discover_tests()

    def tests_by_patterns(self, *patterns: str) -> List[str]:
        """Get the sorted, de-duplicated list of test names matching any of the fnmatch patterns."""
        tests = set()
        for pattern in patterns:
            tests.update(fnmatch.filter(self.all_tests, pattern))
        return sorted(tests)

    @staticmethod
    def parse_out_patterns(names: List[str]) -> (List[str], List[str]):
        """Split configured names into `(patterns, tests)`.

        An entry is a pattern only when it starts with `pattern:` and contains `*`; the prefix is removed.
        Every other entry is returned unchanged as an explicit test name.
        """
        prefix = 'pattern:'
        patterns = []
        tests = []
        for name in names:
            if name.startswith(prefix) and '*' in name:
                patterns.append(name[len(prefix):])
            else:
                tests.append(name)
        return patterns, tests

    @staticmethod
    def format_tests(tests: List[str]) -> List[str]:
        """Format dotted unit test names as `path/to/module.py::Class::method` for direct calling.

        Each name must contain at least two dots (module path, class, method); otherwise unpacking
        raises ValueError.
        """
        formatted = []
        for test in tests:
            path, clazz, method = test.rsplit('.', maxsplit=2)
            path = f'{path.replace(".", os.path.sep)}.py'
            formatted.append('::'.join([path, clazz, method]))
        return formatted

    def get_test_names(self, formatted: bool = False) -> (List[str], List[str]):
        """Get the test names to run and the test names to skip, as a `(tests, skipped)` pair.

        - `test_only` defined (even as an empty list): run only the listed/matched tests, skip the rest.
        - otherwise, non-empty `bypass`: skip the listed/matched tests, run the rest.
        - otherwise: run everything.

        With `formatted=True` both lists are passed through `format_tests`.
        Raises AssertionError when an explicitly named test is not part of `all_tests`.
        """
        patterns_t, tests_t = self.parse_out_patterns(self.unit_tests.test_only or [])
        patterns_b, tests_b = self.parse_out_patterns(self.unit_tests.bypass or [])
        defined_tests = tests_t + tests_b
        patterns = patterns_t + patterns_b

        unknowns = sorted(set(defined_tests) - set(self.all_tests))
        assert not unknowns, f'Unrecognized test names in config ({self.test_file}): {unknowns}'

        combined_tests = sorted(set(defined_tests + self.tests_by_patterns(*patterns)))
        # set built once so the membership checks below are O(1) instead of scanning a list per test
        selected = set(combined_tests)

        if self.unit_tests.test_only is not None:
            tests = combined_tests
            skipped = [t for t in self.all_tests if t not in selected]
        elif self.unit_tests.bypass:
            tests = [t for t in self.all_tests if t not in selected]
            skipped = [t for t in self.all_tests if t in selected]
        else:
            tests = self.all_tests
            skipped = []

        if formatted:
            return self.format_tests(tests), self.format_tests(skipped)
        return tests, skipped

    def check_skip_by_rule_id(self, rule_id: str) -> bool:
        """Check if a rule_id should be skipped.

        A rule is skipped when it is listed in a non-empty `bypass`, or when a non-empty `test_only`
        exists and does not list it. With neither list populated nothing is skipped.
        Always returns a real bool (never None or a list leaked from the `and`/`or` chain).
        """
        bypass = self.rule_validation.bypass
        test_only = self.rule_validation.test_only

        in_bypass = bool(bypass) and rule_id in bypass
        outside_test_only = bool(test_only) and rule_id not in test_only
        return in_bypass or outside_test_only
@dataclass
class RulesConfig:
    """Fully resolved detection rules configuration (paths plus the loaded contents of config files)."""

    # required: each `*_file` path is paired with the data loaded from it
    deprecated_rules_file: Path
    deprecated_rules: Dict[str, dict]
    packages_file: Path
    packages: Dict[str, dict]
    rule_dirs: List[Path]
    stack_schema_map_file: Path
    stack_schema_map: Dict[str, dict]
    test_config: TestConfig
    version_lock_file: Path
    version_lock: Dict[str, dict]

    # optional settings
    action_dir: Optional[Path] = None
    action_connector_dir: Optional[Path] = None
    auto_gen_schema_file: Optional[Path] = None
    bbr_rules_dirs: Optional[List[Path]] = field(default_factory=list)
    bypass_version_lock: bool = False
    exception_dir: Optional[Path] = None
    normalize_kql_keywords: bool = True
    bypass_optional_elastic_validation: bool = False

    def __post_init__(self):
        """Ensure the packages data contains `package` and, inside it, `name`; raise ValueError otherwise."""
        has_package_section = 'package' in self.packages
        if not has_package_section:
            raise ValueError('Missing the `package` field defined in packages.yaml.')

        package_section = self.packages['package']
        if 'name' in package_section:
            return
        raise ValueError('Missing the `name` field defined in packages.yaml.')
@cached
def parse_rules_config(path: Optional[Path] = None) -> RulesConfig:
    """Parse the _config.yaml file for default or custom rules.

    The config file is chosen in this order:
      1. the explicit `path` argument,
      2. `_config.yaml` inside the directory named by the CUSTOM_RULES_DIR environment variable,
      3. the `_config.yaml` located via `get_etc_path` / `load_etc_dump`.

    All relative paths found in the config are resolved against the directory containing the config file.
    The test config path is taken from the DETECTION_RULES_TEST_CONFIG environment variable when set,
    otherwise from `testing.config` in the config file.

    Side effects: creates `auto_gen_schema_file` (containing `{}`) when configured and missing, and calls
    `set_all_validation_bypass` when `bypass_optional_elastic_validation` is enabled.

    Raises:
        AssertionError: an explicit `path` was given but does not exist.
        FileNotFoundError: CUSTOM_RULES_DIR is set but holds no `_config.yaml`.
        SystemExit: the config is empty/malformed, misses required keys, or cannot be turned into a
            `RulesConfig`.
    """
    if path:
        assert path.exists(), f'rules config file does not exist: {path}'
        loaded = yaml.safe_load(path.read_text())
    elif CUSTOM_RULES_DIR:
        path = Path(CUSTOM_RULES_DIR) / '_config.yaml'
        if not path.exists():
            raise FileNotFoundError(
                "\nConfiguration file not found.\n"
                "Please create a configuration file. You can use the 'custom-rules setup-config' command\n"
                "and update the 'CUSTOM_RULES_DIR' environment variable as needed.\n"
            )
        loaded = yaml.safe_load(path.read_text())
    else:
        path = Path(get_etc_path('_config.yaml'))
        loaded = load_etc_dump('_config.yaml')

    # structural validation only; the result is discarded, we just want readable errors for bad input
    try:
        ConfigFile.from_dict(loaded)
    except KeyError as e:
        raise SystemExit(f'Missing key `{str(e)}` in _config.yaml file.') from e
    except (AttributeError, TypeError) as e:
        raise SystemExit(f'No data properly loaded from {path}') from e
    except ValueError as e:
        raise SystemExit(e) from e

    base_dir = path.resolve().parent

    # testing
    # precedence to the environment variable
    # environment variable is absolute path and config file is relative to the _config.yaml file
    test_config_ev = os.getenv('DETECTION_RULES_TEST_CONFIG', None)
    if test_config_ev:
        test_config_path = Path(test_config_ev)
    else:
        # `testing:` may be present but null, which passes the structural validation above
        test_config_file = (loaded.get('testing') or {}).get('config')
        test_config_path = base_dir.joinpath(test_config_file) if test_config_file else None

    if test_config_path:
        # an empty YAML document loads as None; treat it as an empty test config
        test_config_data = yaml.safe_load(test_config_path.read_text()) or {}

        # overwrite None with empty list to allow implicit exemption of all tests with `test_only` defined to None in
        # test config
        if 'unit_tests' in test_config_data and test_config_data['unit_tests'] is not None:
            test_config_data['unit_tests'] = {k: v or [] for k, v in test_config_data['unit_tests'].items()}
        test_config = TestConfig.from_dict(test_file=test_config_path, **test_config_data)
    else:
        test_config = TestConfig.from_dict()

    # files
    # paths are relative
    files = {f'{k}_file': base_dir.joinpath(v) for k, v in loaded['files'].items()}
    contents = {k: load_dump(str(base_dir.joinpath(v).resolve())) for k, v in loaded['files'].items()}
    contents.update(**files)

    # directories
    # paths are relative; entries left null in the YAML are skipped so the RulesConfig default applies
    directories = loaded.get('directories')
    if directories:
        contents.update({k: base_dir.joinpath(v).resolve() for k, v in directories.items() if v is not None})

    # rule_dirs
    # paths are relative
    contents['rule_dirs'] = [base_dir.joinpath(d).resolve() for d in loaded.get('rule_dirs')]

    # version strategy
    contents['bypass_version_lock'] = loaded.get('bypass_version_lock', False)

    # bbr_rules_dirs
    # paths are relative
    if loaded.get('bbr_rules_dirs'):
        contents['bbr_rules_dirs'] = [base_dir.joinpath(d).resolve() for d in loaded.get('bbr_rules_dirs', [])]

    # kql keyword normalization
    contents['normalize_kql_keywords'] = loaded.get('normalize_kql_keywords', True)

    if loaded.get('auto_gen_schema_file'):
        auto_gen_schema_file = base_dir.joinpath(loaded['auto_gen_schema_file'])
        contents['auto_gen_schema_file'] = auto_gen_schema_file

        # If the file doesn't exist, create the necessary directories and an empty JSON object file
        if not auto_gen_schema_file.exists():
            auto_gen_schema_file.parent.mkdir(parents=True, exist_ok=True)
            auto_gen_schema_file.write_text('{}')

    # bypass_optional_elastic_validation
    contents['bypass_optional_elastic_validation'] = loaded.get('bypass_optional_elastic_validation', False)
    if contents['bypass_optional_elastic_validation']:
        set_all_validation_bypass(contents['bypass_optional_elastic_validation'])

    try:
        rules_config = RulesConfig(test_config=test_config, **contents)
    except (ValueError, TypeError) as e:
        raise SystemExit(f'Error parsing packages.yaml: {str(e)}') from e

    return rules_config
@cached
def load_current_package_version() -> str:
    """Return the `name` value of the `package` section from the parsed rules config's packages data."""
    rules_config = parse_rules_config()
    package_section = rules_config.packages['package']
    return package_section['name']