detection_rules/misc.py (344 lines of code) (raw):
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
"""Misc support."""
import os
import re
import time
import unittest
import uuid
from pathlib import Path
from functools import wraps
from typing import NoReturn, Optional
import click
import requests
# this is primarily for type hinting - all use of the github client should come from GithubClient class
try:
from github import Github
from github.Repository import Repository
from github.GitRelease import GitRelease
from github.GitReleaseAsset import GitReleaseAsset
except ImportError:
# for type hinting
Github = None # noqa: N806
Repository = None # noqa: N806
GitRelease = None # noqa: N806
GitReleaseAsset = None # noqa: N806
from .utils import add_params, cached, get_path, load_etc_dump
# Module-level dict, initially empty.
_CONFIG = {}

# Canonical license text; the remaining license constants are derived from it.
LICENSE_HEADER = """
Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
or more contributor license agreements. Licensed under the Elastic License
2.0; you may not use this file except in compliance with the Elastic License
2.0.
""".strip()
LICENSE_LINES = LICENSE_HEADER.splitlines()

# The license rendered as `#`-prefixed comment lines.
PYTHON_LICENSE = "\n".join(f"# {text}" for text in LICENSE_LINES)

# The license rendered as a `/* ... */` block comment with ` * ` prefixed lines.
JS_LICENSE = """
/*
{}
*/
""".strip().format("\n".join(f" * {text}" for text in LICENSE_LINES))

# Two directory levels above this file.
ROOT_DIR = Path(__file__).parent.parent
class ClientError(click.ClickException):
    """Custom CLI error that remembers the exception which triggered it.

    Args:
        message: Description of the failure shown to the user.
        original_error: Optional underlying exception; its type name is included in the rendered output.
    """

    def __init__(self, message, original_error=None):
        super().__init__(message)
        self.original_error = original_error
        # compare against None explicitly: an exception type may define `__bool__`/`__len__` and be falsy
        self.original_error_type = type(original_error).__name__ if original_error is not None else ''

    def show(self, file=None, err=True):
        """Print the styled error to the console.

        Args:
            file: Optional stream forwarded to `click.echo`.
            err: Forwarded to `click.echo`; defaults to True.
        """
        # only add the parenthesised type when there is one, so the label never renders as "CLI Error ()"
        label = f'CLI Error ({self.original_error_type})' if self.original_error_type else 'CLI Error'
        msg = f'{click.style(label, fg="red", bold=True)}: {self.format_message()}'
        click.echo(msg, err=err, file=file)
def client_error(message, exc: Optional[Exception] = None, debug=None, ctx: Optional[click.Context] = None, file=None,
                 err=None) -> NoReturn:
    """Report a CLI failure; this function always raises.

    Args:
        message: Description of the failure.
        exc: Optional underlying exception associated with the failure.
        debug: Explicit debug switch; when None, the `debug` entry of `ctx.obj` decides.
        ctx: Optional click context whose object dict may carry a `debug` flag.
        file: Stream forwarded to `click.echo` for the debug message.
        err: Forwarded to `click.echo` for the debug message.

    Raises:
        ClientError: When not in debug mode, or in debug mode with no exception available to raise.
        Exception: In debug mode, the exception currently being handled (re-raised as is), or else `exc`.
    """
    import sys

    config_debug = bool(ctx and ctx.ensure_object(dict) and ctx.obj.get('debug') is True)
    debug = debug if debug is not None else config_debug

    if not debug:
        raise ClientError(message, original_error=exc)

    click.echo(click.style('DEBUG: ', fg='yellow') + message, err=err, file=file)

    if sys.exc_info()[1] is not None:
        # called from inside an `except` block: re-raise the active exception unchanged
        raise
    if exc is not None:
        raise exc
    # a bare `raise` with no active exception would surface as an unrelated RuntimeError
    raise ClientError(message, original_error=exc)
def nested_get(_dict, dot_key, default=None):
    """Get a nested field from a nested dict with dot notation.

    Args:
        _dict: Mapping to read from; None yields `default`.
        dot_key: Key path such as ``"a.b.c"``; None yields `default`.
        default: Value returned when the path cannot be followed to the end.

    Returns:
        The value stored at the path (which may itself be None), or `default` when a key is missing
        or an intermediate value cannot be traversed.
    """
    if _dict is None or dot_key is None:
        return default

    missing = object()  # sentinel so a stored None is distinguishable from an absent key
    current = _dict
    for key in dot_key.split('.'):
        # a scalar (or None) part-way down the path means the path does not exist
        if not hasattr(current, 'get'):
            return default
        current = current.get(key, missing)
        if current is missing:
            return default
    return current
def nested_set(_dict, dot_key, value):
    """Set a nested field from a key in dot notation, creating intermediate dicts as needed.

    Args:
        _dict: Dict that is modified in place.
        dot_key: Key path such as ``"a.b.c"``.
        value: Value stored at the final key.

    Raises:
        ValueError: If any level along the path holds a value that cannot contain further keys.
    """
    *parents, leaf = dot_key.split('.')

    node = _dict
    for key in parents:
        # fail with the same ValueError at every depth instead of an AttributeError from `setdefault`
        if not hasattr(node, 'setdefault'):
            raise ValueError('dict cannot set a value to a non-dict for {}'.format(dot_key))
        node = node.setdefault(key, {})

    if not isinstance(node, dict):
        raise ValueError('dict cannot set a value to a non-dict for {}'.format(dot_key))
    node[leaf] = value
def nest_from_dot(dots, value):
    """Nest a dotted field and set the innermost value.

    Args:
        dots: Key path such as ``"a.b.c"``.
        value: Value stored under the last key of the path.

    Returns:
        A nested dict with one key per path segment, or an empty dict when `dots` is empty.
    """
    # `str.split` never returns an empty list, so the emptiness check has to be on the input itself
    if not dots:
        return {}

    *outer, innermost = dots.split('.')
    nested = {innermost: value}
    for field_ in reversed(outer):
        nested = {field_: nested}
    return nested
def schema_prompt(name, value=None, is_required=False, **options):
    """Interactively prompt for a field value and validate it against schema-style options.

    Args:
        name: Field name; shown in the prompt. A name containing ``date`` defaults to today's date
            and ``rule_id`` defaults to a fresh UUID.
        value: Optional pre-supplied answer. It is validated like typed input and, if rejected for a
            required field, discarded so the user is prompted instead.
        is_required: When True, keep prompting until a valid answer is given.
        **options: Recognised keys are ``type``, ``pattern``, ``enum``, ``minimum``, ``maximum``,
            ``min_items``, ``max_items`` and ``default``.

    Returns:
        The converted value (int for number/integer, bool for boolean, list for array, otherwise the
        raw answer). For optional fields, None when nothing valid was given, or an empty list for an
        invalid array answer.
    """
    name = str(name)
    field_type = options.get('type')
    pattern = options.get('pattern')
    enum = options.get('enum', [])
    minimum = options.get('minimum')
    maximum = options.get('maximum')
    min_item = options.get('min_items', 0)
    max_items = options.get('max_items', 9999)

    default = options.get('default')
    if default is not None and str(default).lower() in ('true', 'false'):
        default = str(default).lower()

    if 'date' in name:
        default = time.strftime('%Y/%m/%d')

    if name == 'rule_id':
        default = str(uuid.uuid4())

    # a required field with exactly one allowed value needs no prompt
    if len(enum) == 1 and is_required and field_type != "array":
        return enum[0]

    is_numeric = field_type in ('number', 'integer')

    def _check_type(_val):
        """Return True when a single answer satisfies every configured constraint."""
        if is_numeric and not str(_val).isdigit():
            print('Number expected but got: {}'.format(_val))
            return False
        # fullmatch: the whole answer has to match, not just a prefix of it
        if pattern and not re.fullmatch(pattern, _val):
            print('{} did not match pattern: {}!'.format(_val, pattern))
            return False
        if enum and _val not in enum:
            print('{} not in valid options: {}'.format(_val, ', '.join(enum)))
            return False

        # typed answers arrive as strings, so numeric fields are converted before the bounds are applied
        if type(_val) is int:
            number = _val
        elif is_numeric:
            number = int(_val)
        else:
            number = None

        if number is not None:
            if minimum and number < minimum:
                print('{} is less than the minimum: {}'.format(str(_val), str(minimum)))
                return False
            if maximum and number > maximum:
                print('{} is greater than the maximum: {}'.format(str(_val), str(maximum)))
                return False

        if field_type == 'boolean' and _val.lower() not in ('true', 'false'):
            print('Boolean expected but got: {}'.format(str(_val)))
            return False
        return True

    def _convert_type(_val):
        """Convert a validated answer to the Python type implied by `field_type`."""
        if field_type == 'boolean' and not type(_val) is bool:
            _val = True if _val.lower() == 'true' else False
        return int(_val) if is_numeric else _val

    prompt = '{name}{default}{required}{multi}'.format(
        name=name,
        default=' [{}] ("n/a" to leave blank) '.format(default) if default else '',
        required=' (required) ' if is_required else '',
        multi=' (multi, comma separated) ' if field_type == 'array' else '').strip() + ': '

    while True:
        result = value or input(prompt) or default
        if result == 'n/a':
            result = None

        if not result:
            if is_required:
                value = None
                continue
            return None

        if field_type == 'array':
            items = result.split(',')
            # both bounds are inclusive: exactly `min_items` or `max_items` entries is acceptable
            count_ok = min_item <= len(items) <= max_items
            if count_ok and all(_check_type(item) for item in items):
                return [_convert_type(item) for item in items]
            if not is_required:
                return []
        else:
            if _check_type(result):
                return _convert_type(result)
            if not is_required:
                return None

        # required but invalid: drop any pre-supplied value so the next iteration prompts the user
        value = None
def get_kibana_rules_map(repo='elastic/kibana', branch='master'):
    """Map prepackaged rule names in a Kibana repo branch to their download URLs.

    Args:
        repo: GitHub repository in ``owner/name`` form.
        branch: Branch to inspect.

    Returns:
        Dict of rule file name (without the ``.json`` extension) to the file's ``download_url``.

    Raises:
        requests.HTTPError: If the branch lookup returns an error status.
        ValueError: If none of the known rule directory layouts exists on the branch.
    """
    # ensure branch exists
    branch_response = requests.get(f'https://api.github.com/repos/{repo}/branches/{branch}')
    branch_response.raise_for_status()

    url = ('https://api.github.com/repos/{repo}/contents/x-pack/{legacy}plugins/{app}/server/lib/'
           'detection_engine/rules/prepackaged_rules?ref={branch}')

    # directory layouts, newest first: current app, pre-7.9 app (siem), pre-7.8 siem under the legacy directory
    layouts = (
        ('', 'security_solution'),
        ('', 'siem'),
        ('legacy/', 'siem'),
    )

    for legacy, app in layouts:
        gh_rules = requests.get(url.format(legacy=legacy, app=app, branch=branch, repo=repo)).json()
        not_found = isinstance(gh_rules, dict) and gh_rules.get('message', '') == 'Not Found'
        if not not_found:
            break
    else:
        raise ValueError(f'rules directory does not exist for {repo} branch: {branch}')

    rules_map = {}
    for entry in gh_rules:
        if entry['name'].endswith('.json'):
            rules_map[os.path.splitext(entry['name'])[0]] = entry['download_url']
    return rules_map
def get_kibana_rules(*rule_paths, repo='elastic/kibana', branch='master', verbose=True, threads=50):
    """Retrieve prepackaged rules from kibana repo.

    Args:
        *rule_paths: Optional rule paths or names; only their base names without extension are used
            to select rules. When empty, every available rule is downloaded.
        repo: GitHub repository in ``owner/name`` form.
        branch: Branch to download from.
        verbose: Echo a progress message before downloading.
        threads: Number of worker threads in the download pool.

    Returns:
        Dict of rule name to the parsed JSON body of that rule.
    """
    from multiprocessing.pool import ThreadPool

    kibana_rules = {}

    if verbose:
        thread_use = ''
        if threads > 1:
            thread_use = f' using {threads} threads'
        click.echo(f'Downloading rules from {repo} {branch} branch in kibana repo{thread_use} ...')

    wanted = [os.path.splitext(os.path.basename(path))[0] for path in rule_paths]
    available = get_kibana_rules_map(repo=repo, branch=branch)
    if wanted:
        rules_mapping = [(name, url) for name, url in available.items() if name in wanted]
    else:
        rules_mapping = available.items()

    def download_worker(rule_info):
        # each worker stores its result under the rule's own name
        name, url = rule_info
        kibana_rules[name] = requests.get(url).json()

    pool = ThreadPool(processes=threads)
    pool.map(download_worker, rules_mapping)
    pool.close()
    pool.join()

    return kibana_rules
@cached
def load_current_package_version() -> str:
    """Return the ``package.name`` entry of the ``packages.yaml`` etc file."""
    packages = load_etc_dump('packages.yaml')
    package_section = packages['package']
    return package_section['name']
def get_default_config() -> Optional[Path]:
    """Return the first ``.detection-rules-cfg.*`` path found under `get_path()`, or None if there is none."""
    for candidate in get_path().glob('.detection-rules-cfg.*'):
        return candidate
    return None
@cached
def parse_user_config():
    """Load the default config file, if one exists.

    Returns:
        The parsed config, or an empty dict when no config file is found.
    """
    import eql

    config_file = get_default_config()
    if not (config_file and config_file.exists()):
        return {}

    loaded = eql.utils.load_dump(str(config_file))
    click.secho(f'Loaded config file: {config_file}', fg='yellow')
    return loaded
def discover_tests(start_dir: str = 'tests', pattern: str = 'test*.py', top_level_dir: Optional[str] = None):
"""Discover all unit tests in a directory."""
def list_tests(s, tests=None):
if tests is None:
tests = []
for test in s:
if isinstance(test, unittest.TestSuite):
list_tests(test, tests)
else:
tests.append(test.id())
return tests
loader = unittest.defaultTestLoader
suite = loader.discover(start_dir, pattern=pattern, top_level_dir=top_level_dir or str(ROOT_DIR))
return list_tests(suite)
def getdefault(name):
    """Build a callback for an option's `default`.

    The user config is parsed when this function is called; the environment is only consulted when
    the returned callback runs.

    Args:
        name: Option name; the matching environment variable is ``DR_<NAME>`` in upper case.

    Returns:
        A zero-argument callable returning the environment variable's value, or else the config
        entry for `name`.
    """
    envvar = f"DR_{name.upper()}"
    config = parse_user_config()

    def _resolve():
        return os.environ.get(envvar, config.get(name))

    return _resolve
def get_elasticsearch_client(cloud_id: str = None, elasticsearch_url: str = None, es_user: str = None,
                             es_password: str = None, ctx: click.Context = None, api_key: str = None, **kwargs):
    """Get an authenticated elasticsearch client.

    Args:
        cloud_id: Cloud instance id; either this or `elasticsearch_url` is required.
        elasticsearch_url: URL of the Elasticsearch host.
        es_user: User name; prompted for when missing and no API key is given.
        es_password: Password; prompted for (hidden input) when missing and no API key is given.
        ctx: Click context forwarded to `client_error` on authentication failure.
        api_key: API key used instead of user name and password.
        **kwargs: Extra client arguments. ``timeout`` (default 60) and ``ignore_ssl_errors``
            (default False, inverted into ``verify_certs``) are consumed here.

    Returns:
        The client, after a successful `info()` call.
    """
    from elasticsearch import AuthenticationException, Elasticsearch

    if not cloud_id and not elasticsearch_url:
        client_error("Missing required --cloud-id or --elasticsearch-url")

    # don't prompt for these until there's a cloud id or elasticsearch URL
    basic_auth = None
    if not api_key:
        es_user = es_user or click.prompt("es_user")
        es_password = es_password or click.prompt("es_password", hide_input=True)
        basic_auth = (es_user, es_password)

    hosts = None
    if elasticsearch_url:
        hosts = [elasticsearch_url]

    timeout = kwargs.pop('timeout', 60)
    ignore_ssl_errors = kwargs.pop('ignore_ssl_errors', False)
    kwargs['verify_certs'] = not ignore_ssl_errors

    try:
        client = Elasticsearch(hosts=hosts, cloud_id=cloud_id, http_auth=basic_auth, timeout=timeout,
                               api_key=api_key, **kwargs)
        # force login to test auth
        client.info()
    except AuthenticationException as e:
        client_error(f'Failed authentication for {elasticsearch_url or cloud_id}', e, ctx=ctx, err=True)
    else:
        return client
def get_kibana_client(cloud_id: str, kibana_url: str, kibana_user: str, kibana_password: str, kibana_cookie: str,
                      space: str, ignore_ssl_errors: bool, provider_type: str, provider_name: str, api_key: str,
                      **kwargs):
    """Get an authenticated Kibana client.

    Authentication uses, in order of preference, the session cookie, the API key, or a login with
    user name and password.

    Args:
        cloud_id: Cloud instance id; either this or `kibana_url` is required.
        kibana_url: URL of the Kibana host.
        kibana_user: User name; prompted for when missing and neither cookie nor API key is given.
        kibana_password: Password; prompted for (hidden input) under the same condition.
        kibana_cookie: Cookie from an already authenticated session.
        space: Kibana space passed to the client.
        ignore_ssl_errors: When truthy, certificate verification is disabled.
        provider_type: Provider type forwarded to `login`.
        provider_name: Provider name forwarded to `login`.
        api_key: API key used instead of a login.
        **kwargs: Extra arguments forwarded to the `Kibana` constructor.

    Returns:
        The Kibana client.
    """
    from requests import HTTPError
    from kibana import Kibana

    if not (cloud_id or kibana_url):
        client_error("Missing required --cloud-id or --kibana-url")

    if not (kibana_cookie or api_key):
        # don't prompt for these until there's a cloud id or Kibana URL
        kibana_user = kibana_user or click.prompt("kibana_user")
        kibana_password = kibana_password or click.prompt("kibana_password", hide_input=True)

    verify = not ignore_ssl_errors

    with Kibana(cloud_id=cloud_id, kibana_url=kibana_url, space=space, verify=verify, **kwargs) as kibana:
        if kibana_cookie:
            kibana.add_cookie(kibana_cookie)
            return kibana
        elif api_key:
            kibana.add_api_key(api_key)
            return kibana

        try:
            kibana.login(kibana_user, kibana_password, provider_type=provider_type, provider_name=provider_name)
        except HTTPError as exc:
            # an HTTPError may carry no response at all; only a 401 is an authentication failure
            if exc.response is not None and exc.response.status_code == 401:
                # name whichever target was actually given, so a cloud-id login does not report "None"
                target = kibana_url or cloud_id
                err_msg = f'Authentication failed for {target}. If credentials are valid, check --provider-name'
                client_error(err_msg, exc, err=True)
            else:
                raise

        return kibana
# CLI options per client type, keyed by the keyword argument name each option provides.
client_options = {
    'kibana': dict(
        cloud_id=click.Option(['--cloud-id'], default=getdefault('cloud_id'),
                              help="ID of the cloud instance."),
        api_key=click.Option(['--api-key'], default=getdefault('api_key')),
        kibana_cookie=click.Option(['--kibana-cookie', '-kc'], default=getdefault('kibana_cookie'),
                                   help='Cookie from an authed session'),
        kibana_password=click.Option(['--kibana-password', '-kp'], default=getdefault('kibana_password')),
        kibana_url=click.Option(['--kibana-url'], default=getdefault('kibana_url')),
        kibana_user=click.Option(['--kibana-user', '-ku'], default=getdefault('kibana_user')),
        provider_type=click.Option(['--provider-type'], default=getdefault('provider_type'),
                                   help="Elastic Cloud providers: basic and saml (for SSO)"),
        provider_name=click.Option(['--provider-name'], default=getdefault('provider_name'),
                                   help="Elastic Cloud providers: cloud-basic and cloud-saml (for SSO)"),
        space=click.Option(['--space'], default=None, help='Kibana space'),
        ignore_ssl_errors=click.Option(['--ignore-ssl-errors'], default=getdefault('ignore_ssl_errors')),
    ),
    'elasticsearch': dict(
        cloud_id=click.Option(['--cloud-id'], default=getdefault("cloud_id")),
        api_key=click.Option(['--api-key'], default=getdefault('api_key')),
        elasticsearch_url=click.Option(['--elasticsearch-url'], default=getdefault("elasticsearch_url")),
        es_user=click.Option(['--es-user', '-eu'], default=getdefault("es_user")),
        es_password=click.Option(['--es-password', '-ep'], default=getdefault("es_password")),
        timeout=click.Option(['--timeout', '-et'], default=60, help='Timeout for elasticsearch client'),
        ignore_ssl_errors=click.Option(['--ignore-ssl-errors'], default=getdefault('ignore_ssl_errors')),
    ),
}

# Flat option lists for each client type, in declaration order.
kibana_options = [*client_options['kibana'].values()]
elasticsearch_options = [*client_options['elasticsearch'].values()]
def add_client(*client_type, add_to_ctx=True, add_func_arg=True):
    """Decorator factory that adds client CLI options and injects authenticated clients.

    Args:
        *client_type: Client names to provide; each must be a key of `client_options`
            (``'elasticsearch'`` and/or ``'kibana'``).
        add_to_ctx: When True and a click context is among the positional arguments, store the
            clients on ``ctx.obj`` under ``'es'`` / ``'kibana'``.
        add_func_arg: When True, pass the clients to the wrapped function as
            ``elasticsearch_client`` / ``kibana_client`` keyword arguments.

    Returns:
        A decorator for the command function.

    Raises:
        ValueError: At decoration time, if a requested client type is unknown or none was requested.
    """
    from elasticsearch import Elasticsearch
    from elasticsearch.exceptions import AuthenticationException
    from kibana import Kibana

    def _wrapper(func):
        client_ops_dict = {}
        client_ops_keys = {}
        for c_type in client_type:
            ops = client_options.get(c_type)
            # reject unknown names here; `dict.update(None)` would otherwise fail with a TypeError
            if ops is None:
                raise ValueError(f'Unknown client: {c_type} in {func.__name__}')
            client_ops_dict.update(ops)
            client_ops_keys[c_type] = list(ops)

        if not client_ops_dict:
            raise ValueError(f'Unknown client: {client_type} in {func.__name__}')

        client_ops = list(client_ops_dict.values())

        @wraps(func)
        @add_params(*client_ops)
        def _wrapped(*args, **kwargs):
            ctx: click.Context = next((a for a in args if isinstance(a, click.Context)), None)
            es_client_args = {k: kwargs.pop(k, None) for k in client_ops_keys.get('elasticsearch', [])}
            # shared args like cloud_id
            kibana_client_args = {k: kwargs.pop(k, es_client_args.get(k)) for k in client_ops_keys.get('kibana', [])}

            if ctx and add_to_ctx and ctx.obj is None:
                # make sure there is a dict to store the clients on
                ctx.obj = {}

            if 'elasticsearch' in client_type:
                # for nested ctx invocation, no need to re-auth if an existing client is already passed
                elasticsearch_client: Elasticsearch = kwargs.get('elasticsearch_client')
                try:
                    if elasticsearch_client and isinstance(elasticsearch_client, Elasticsearch) and \
                            elasticsearch_client.info():
                        pass
                    else:
                        elasticsearch_client = get_elasticsearch_client(**es_client_args)
                except AuthenticationException:
                    elasticsearch_client = get_elasticsearch_client(**es_client_args)

                if add_func_arg:
                    kwargs['elasticsearch_client'] = elasticsearch_client
                if ctx and add_to_ctx:
                    ctx.obj['es'] = elasticsearch_client

            if 'kibana' in client_type:
                # for nested ctx invocation, no need to re-auth if an existing client is already passed
                kibana_client: Kibana = kwargs.get('kibana_client')
                if kibana_client and isinstance(kibana_client, Kibana):
                    try:
                        with kibana_client:
                            if kibana_client.version:
                                pass  # kibana_client is valid and can be used directly
                    except (requests.HTTPError, AttributeError):
                        kibana_client = get_kibana_client(**kibana_client_args)
                else:
                    # Instantiate a new Kibana client if none was provided or if the provided one is not usable
                    kibana_client = get_kibana_client(**kibana_client_args)

                if add_func_arg:
                    kwargs['kibana_client'] = kibana_client
                if ctx and add_to_ctx:
                    ctx.obj['kibana'] = kibana_client

            return func(*args, **kwargs)

        return _wrapped

    return _wrapper