# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.

"""Misc support."""
import os
import re
import time
import unittest
import uuid
from pathlib import Path
from functools import wraps
from typing import NoReturn, Optional

import click
import requests


# this is primarily for type hinting - all use of the github client should come from GithubClient class
try:
    from github import Github
    from github.Repository import Repository
    from github.GitRelease import GitRelease
    from github.GitReleaseAsset import GitReleaseAsset
except ImportError:
    # for type hinting
    Github = None  # noqa: N806
    Repository = None  # noqa: N806
    GitRelease = None  # noqa: N806
    GitReleaseAsset = None  # noqa: N806

from .utils import add_params, cached, get_path, load_etc_dump

_CONFIG = {}

# Elastic License 2.0 header text; the language-specific variants below are derived from it.
LICENSE_HEADER = (
    "Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one\n"
    "or more contributor license agreements. Licensed under the Elastic License\n"
    "2.0; you may not use this file except in compliance with the Elastic License\n"
    "2.0."
)

LICENSE_LINES = LICENSE_HEADER.splitlines()
# Header rendered as Python `#` line comments.
PYTHON_LICENSE = "\n".join(f"# {line}" for line in LICENSE_LINES)
# Header rendered as a JavaScript `/* ... */` block comment.
JS_LICENSE = "/*\n" + "\n".join(f" * {line}" for line in LICENSE_LINES) + "\n */"


# The directory two levels above this file.
ROOT_DIR = Path(__file__).parent.parent


class ClientError(click.ClickException):
    """Custom CLI error to format output or full debug stacktrace.

    :param message: human-readable error message shown to the user
    :param original_error: the underlying exception, if any; its type name is shown in the output
    """

    def __init__(self, message, original_error=None):
        super().__init__(message)
        self.original_error = original_error
        self.original_error_type = type(original_error).__name__ if original_error is not None else ''

    def show(self, file=None, err=True):
        """Print the styled error to the console (stderr by default)."""
        # only add the "(ErrorType)" suffix when an underlying error is attached; otherwise the
        # output would read "CLI Error ()"
        label = f'CLI Error ({self.original_error_type})' if self.original_error_type else 'CLI Error'
        msg = f'{click.style(label, fg="red", bold=True)}: {self.format_message()}'
        click.echo(msg, err=err, file=file)


def client_error(message, exc: Exception = None, debug=None, ctx: click.Context = None, file=None,
                 err=None) -> NoReturn:
    """Raise a user-facing CLI error, or re-raise the underlying exception in debug mode.

    :param message: error message for the user
    :param exc: underlying exception; attached to the raised ClientError, or re-raised in debug mode
    :param debug: force debug mode on/off; when None it is read from ``ctx.obj['debug']``
    :param ctx: click context whose ``obj`` dict may carry a ``debug`` flag
    :param file: stream for the debug message (passed to ``click.echo``)
    :param err: whether the debug message goes to stderr (passed to ``click.echo``)
    :raises ClientError: when not debugging, or when debugging with no exception to re-raise
    """
    import sys

    config_debug = bool(ctx and ctx.ensure_object(dict) and ctx.obj.get('debug') is True)
    debug = debug if debug is not None else config_debug

    if not debug:
        raise ClientError(message, original_error=exc)

    click.echo(click.style('DEBUG: ', fg='yellow') + message, err=err, file=file)
    # A bare ``raise`` fails with RuntimeError when no exception is being handled (e.g. when this is
    # called for a validation failure rather than from an ``except`` block). Re-raise the explicit
    # exception, else the one currently being handled, else fall back to a ClientError.
    active_error = exc if exc is not None else sys.exc_info()[1]
    if active_error is None:
        raise ClientError(message, original_error=exc)
    raise active_error


def nested_get(_dict, dot_key, default=None):
    """Get a nested field from a nested dict with dot notation.

    Each dot-separated segment descends one level into nested dicts, so ``nested_get(d, 'a.b')``
    is ``d['a']['b']``. Returns ``default`` when ``_dict`` or ``dot_key`` is None, when a segment is
    missing, or when a value along the path cannot be looked up by key (e.g. a string or list).
    """
    if _dict is None or dot_key is None:
        return default
    if '.' in dot_key and isinstance(_dict, dict):
        this_key, remaining = dot_key.split('.', 1)
        # Don't thread ``default`` through intermediate lookups: a missing key must yield ``default``
        # itself rather than a lookup *into* ``default``.
        return nested_get(_dict.get(this_key), remaining, default)
    if not hasattr(_dict, 'get'):
        # reached a non-mapping value partway down the path
        return default
    return _dict.get(dot_key, default)


def nested_set(_dict, dot_key, value):
    """Set a nested field from a key in dot notation, creating intermediate dicts as needed.

    :raises ValueError: if a value along the path (or the target container) is not a dict
    """
    keys = dot_key.split('.')
    for key in keys[:-1]:
        _dict = _dict.setdefault(key, {})
        if not isinstance(_dict, dict):
            # an existing non-dict value blocks the path; fail with the documented ValueError instead
            # of an AttributeError from the next ``setdefault``
            raise ValueError('dict cannot set a value to a non-dict for {}'.format(dot_key))

    if not isinstance(_dict, dict):
        raise ValueError('dict cannot set a value to a non-dict for {}'.format(dot_key))
    _dict[keys[-1]] = value


def nest_from_dot(dots, value):
    """Build a nested dict from a dotted path, placing ``value`` under the innermost key.

    ``nest_from_dot('a.b', 1)`` returns ``{'a': {'b': 1}}``.
    """
    # str.split always yields at least one element, so the unpacking below cannot fail
    *outer_keys, innermost = dots.split('.')
    tree = {innermost: value}
    for key in outer_keys[::-1]:
        tree = {key: tree}
    return tree


def schema_prompt(name, value=None, is_required=False, **options):
    """Interactively prompt for a field value based on schema requirements.

    :param name: field name used in the prompt; names containing ``date`` default to today's date
        (``%Y/%m/%d``) and ``rule_id`` defaults to a fresh UUID4
    :param value: initial value to validate before prompting; when it is blank or invalid and the
        field is required, the user is prompted instead
    :param is_required: keep prompting until a valid, non-empty value is given
    :param options: schema constraints: ``type``, ``pattern``, ``enum``, ``minimum``, ``maximum``,
        ``min_items``, ``max_items`` (inclusive item-count bounds for arrays) and ``default``
    :return: the converted value (``int`` for number/integer, ``bool`` for boolean, ``list`` for
        array); ``None`` when an optional field is left blank; ``None`` (``[]`` for arrays) when an
        optional value fails validation
    """
    name = str(name)
    field_type = options.get('type')
    pattern = options.get('pattern')
    enum = options.get('enum', [])
    minimum = options.get('minimum')
    maximum = options.get('maximum')
    min_item = options.get('min_items', 0)
    max_items = options.get('max_items', 9999)
    is_numeric = field_type in ('number', 'integer')

    default = options.get('default')
    if default is not None and str(default).lower() in ('true', 'false'):
        default = str(default).lower()

    if 'date' in name:
        default = time.strftime('%Y/%m/%d')

    if name == 'rule_id':
        default = str(uuid.uuid4())

    if len(enum) == 1 and is_required and field_type != "array":
        return enum[0]

    def _check_type(_val):
        # isdecimal (not isdigit) matches exactly what int() can parse, e.g. rejects '²'
        if is_numeric and not str(_val).isdecimal():
            print('Number expected but got: {}'.format(_val))
            return False
        # fullmatch: the pattern must cover the whole value; re.match may settle on a shorter
        # alternative even when a full match exists
        if pattern and not re.fullmatch(pattern, str(_val)):
            print('{} did not match pattern: {}!'.format(_val, pattern))
            return False
        if enum and _val not in enum:
            print('{} not in valid options: {}'.format(_val, ', '.join(map(str, enum))))
            return False
        # typed input arrives as a string, so bounds must be compared numerically; ``is not None``
        # keeps a bound of 0 effective
        if is_numeric or type(_val) is int:
            number = int(_val)
            if minimum is not None and number < minimum:
                print('{} is less than the minimum: {}'.format(str(_val), str(minimum)))
                return False
            if maximum is not None and number > maximum:
                print('{} is greater than the maximum: {}'.format(str(_val), str(maximum)))
                return False
        if field_type == 'boolean' and str(_val).lower() not in ('true', 'false'):
            print('Boolean expected but got: {}'.format(str(_val)))
            return False
        return True

    def _convert_type(_val):
        if field_type == 'boolean' and type(_val) is not bool:
            return str(_val).lower() == 'true'
        return int(_val) if is_numeric else _val

    prompt = '{name}{default}{required}{multi}'.format(
        name=name,
        default=' [{}] ("n/a" to leave blank) '.format(default) if default else '',
        required=' (required) ' if is_required else '',
        multi=' (multi, comma separated) ' if field_type == 'array' else '').strip() + ': '

    while True:
        result = value or input(prompt) or default
        # a caller-supplied value is only tried once; any retry must come from the user
        value = None
        if result == 'n/a':
            result = None

        if not result:
            if is_required:
                continue
            return None

        if field_type == 'array':
            items = [item.strip() for item in result.split(',')]

            # item-count bounds are inclusive
            if not min_item <= len(items) <= max_items:
                print('Expected between {} and {} items but got: {}'.format(min_item, max_items, len(items)))
                if is_required:
                    continue
                return []

            if all(_check_type(item) for item in items):
                return [_convert_type(item) for item in items]
            if is_required:
                continue
            return []

        if _check_type(result):
            return _convert_type(result)
        if is_required:
            continue
        return None


def get_kibana_rules_map(repo='elastic/kibana', branch='master'):
    """Get the prepackaged rules available in a Kibana repo branch.

    :return: dict mapping each rule file name (without ``.json``) to its GitHub ``download_url``
    :raises requests.HTTPError: if the branch lookup fails
    :raises ValueError: if no rules directory exists for the branch, or GitHub returns an unexpected
        payload (e.g. an error message instead of a directory listing)
    """
    # ensure branch exists
    r = requests.get(f'https://api.github.com/repos/{repo}/branches/{branch}')
    r.raise_for_status()

    url = ('https://api.github.com/repos/{repo}/contents/x-pack/{legacy}plugins/{app}/server/lib/'
           'detection_engine/rules/prepackaged_rules?ref={branch}')

    # The rules directory moved over time; try the newest location first:
    # security_solution, then siem (pre-7.9), then legacy/siem (pre-7.8).
    candidates = (('', 'security_solution'), ('', 'siem'), ('legacy/', 'siem'))
    gh_rules = None
    for legacy, app in candidates:
        gh_rules = requests.get(url.format(legacy=legacy, app=app, branch=branch, repo=repo)).json()
        if not (isinstance(gh_rules, dict) and gh_rules.get('message', '') == 'Not Found'):
            break
    else:
        raise ValueError(f'rules directory does not exist for {repo} branch: {branch}')

    if not isinstance(gh_rules, list):
        # e.g. a rate-limit or auth error payload; iterating it below would fail with an obscure TypeError
        detail = gh_rules.get('message') if isinstance(gh_rules, dict) else gh_rules
        raise ValueError(f'unexpected response listing rules for {repo} branch {branch}: {detail}')

    return {os.path.splitext(r['name'])[0]: r['download_url'] for r in gh_rules if r['name'].endswith('.json')}


def get_kibana_rules(*rule_paths, repo='elastic/kibana', branch='master', verbose=True, threads=50):
    """Retrieve prepackaged rules from the Kibana repo.

    :param rule_paths: optional rule paths/names restricting the download; matched against the rule
        file name without directory or extension. When omitted, every rule is downloaded.
    :param repo: GitHub ``owner/name`` of the Kibana repository
    :param branch: branch to read rules from
    :param verbose: echo a progress message before downloading
    :param threads: number of download threads (values below 1 are treated as 1)
    :return: dict mapping rule name to its parsed JSON content
    """
    from multiprocessing.pool import ThreadPool

    if verbose:
        thread_use = f' using {threads} threads' if threads > 1 else ''
        click.echo(f'Downloading rules from {repo} {branch} branch in kibana repo{thread_use} ...')

    wanted = {os.path.splitext(os.path.basename(p))[0] for p in rule_paths}
    rules_mapping = get_kibana_rules_map(repo=repo, branch=branch)
    if wanted:
        rules_mapping = {n: u for n, u in rules_mapping.items() if n in wanted}

    def download_worker(rule_info):
        rule_name, download_url = rule_info
        return rule_name, requests.get(download_url).json()

    # map() blocks until every download finishes; the context manager tears the pool down even
    # when a download raises (the original close()/join() were skipped on error)
    with ThreadPool(processes=max(1, threads)) as pool:
        return dict(pool.map(download_worker, rules_mapping.items()))


@cached
def load_current_package_version() -> str:
    """Return the current package version, read from the ``package.name`` entry of ``packages.yaml``."""
    packages_config = load_etc_dump('packages.yaml')
    package_section = packages_config['package']
    return package_section['name']


def get_default_config() -> Optional[Path]:
    """Return the first ``.detection-rules-cfg.*`` match under ``get_path()``, or None if there is none."""
    for candidate in get_path().glob('.detection-rules-cfg.*'):
        return candidate
    return None


@cached
def parse_user_config():
    """Load the default user config file, returning an empty dict when no file is found."""
    import eql

    config_path = get_default_config()
    if not (config_path and config_path.exists()):
        return {}

    loaded_config = eql.utils.load_dump(str(config_path))
    click.secho(f'Loaded config file: {config_path}', fg='yellow')
    return loaded_config


def discover_tests(start_dir: str = 'tests', pattern: str = 'test*.py', top_level_dir: Optional[str] = None):
    """Discover all unit tests in a directory."""
    def list_tests(s, tests=None):
        if tests is None:
            tests = []
        for test in s:
            if isinstance(test, unittest.TestSuite):
                list_tests(test, tests)
            else:
                tests.append(test.id())
        return tests

    loader = unittest.defaultTestLoader
    suite = loader.discover(start_dir, pattern=pattern, top_level_dir=top_level_dir or str(ROOT_DIR))
    return list_tests(suite)


def getdefault(name):
    """Build a zero-argument callable usable as a click ``default``.

    The callable returns the ``DR_<NAME>`` environment variable when set, otherwise ``name`` from
    the user config (which is loaded once, when ``getdefault`` itself is called).
    """
    env_key = "DR_" + name.upper()
    user_config = parse_user_config()

    def _default():
        return os.environ.get(env_key, user_config.get(name))

    return _default


def get_elasticsearch_client(cloud_id: str = None, elasticsearch_url: str = None, es_user: str = None,
                             es_password: str = None, ctx: click.Context = None, api_key: str = None, **kwargs):
    """Get an authenticated elasticsearch client.

    Prompts for ``es_user``/``es_password`` when no ``api_key`` is given. ``timeout`` (default 60)
    and ``ignore_ssl_errors`` are consumed from ``kwargs``; the rest are passed to ``Elasticsearch``.
    Missing connection info and authentication failures are reported through ``client_error``.
    """
    from elasticsearch import AuthenticationException, Elasticsearch

    if not (cloud_id or elasticsearch_url):
        client_error("Missing required --cloud-id or --elasticsearch-url")

    # don't prompt for these until there's a cloud id or elasticsearch URL
    basic_auth: Optional[tuple] = None
    if not api_key:
        es_user = es_user or click.prompt("es_user")
        es_password = es_password or click.prompt("es_password", hide_input=True)
        basic_auth = (es_user, es_password)

    hosts = [elasticsearch_url] if elasticsearch_url else None
    timeout = kwargs.pop('timeout', 60)
    ignore_ssl_errors = kwargs.pop('ignore_ssl_errors', False)
    if isinstance(ignore_ssl_errors, str):
        # Env vars and untyped CLI options arrive as strings; a literal 'false' is truthy and would
        # otherwise silently disable certificate verification.
        ignore_ssl_errors = ignore_ssl_errors.strip().lower() in ('true', 't', '1', 'yes', 'y', 'on')
    kwargs['verify_certs'] = not ignore_ssl_errors

    try:
        client = Elasticsearch(hosts=hosts, cloud_id=cloud_id, http_auth=basic_auth, timeout=timeout, api_key=api_key,
                               **kwargs)
        # force login to test auth
        client.info()
        return client
    except AuthenticationException as e:
        error_msg = f'Failed authentication for {elasticsearch_url or cloud_id}'
        client_error(error_msg, e, ctx=ctx, err=True)


def get_kibana_client(cloud_id: str, kibana_url: str, kibana_user: str, kibana_password: str, kibana_cookie: str,
                      space: str, ignore_ssl_errors: bool, provider_type: str, provider_name: str, api_key: str,
                      **kwargs):
    """Get an authenticated Kibana client.

    Authentication is tried in order: session cookie, API key, then username/password login
    (prompting for missing credentials). Remaining ``kwargs`` are passed to ``Kibana``.
    A 401 on login is reported through ``client_error``; other HTTP errors propagate.
    """
    from requests import HTTPError
    from kibana import Kibana

    if not (cloud_id or kibana_url):
        client_error("Missing required --cloud-id or --kibana-url")

    if not (kibana_cookie or api_key):
        # don't prompt for these until there's a cloud id or Kibana URL
        kibana_user = kibana_user or click.prompt("kibana_user")
        kibana_password = kibana_password or click.prompt("kibana_password", hide_input=True)

    if isinstance(ignore_ssl_errors, str):
        # Env vars and untyped CLI options arrive as strings; a literal 'false' is truthy and would
        # otherwise silently disable certificate verification.
        ignore_ssl_errors = ignore_ssl_errors.strip().lower() in ('true', 't', '1', 'yes', 'y', 'on')
    verify = not ignore_ssl_errors

    with Kibana(cloud_id=cloud_id, kibana_url=kibana_url, space=space, verify=verify, **kwargs) as kibana:
        if kibana_cookie:
            kibana.add_cookie(kibana_cookie)
            return kibana
        elif api_key:
            kibana.add_api_key(api_key)
            return kibana

        try:
            kibana.login(kibana_user, kibana_password, provider_type=provider_type, provider_name=provider_name)
        except HTTPError as exc:
            # exc.response can be None for HTTPErrors raised without a response object
            if exc.response is not None and exc.response.status_code == 401:
                # kibana_url is None when connecting by cloud id; report whichever target was used
                err_msg = (f'Authentication failed for {kibana_url or cloud_id}. '
                           'If credentials are valid, check --provider-name')
                client_error(err_msg, exc, err=True)
            else:
                raise

        return kibana


# click options shared by the client decorators. ``getdefault`` resolves the ``DR_<NAME>`` env var
# first, then the user config file.
client_options = {
    'kibana': {
        'cloud_id': click.Option(['--cloud-id'], default=getdefault('cloud_id'),
                                 help="ID of the cloud instance."),
        'api_key': click.Option(['--api-key'], default=getdefault('api_key')),
        'kibana_cookie': click.Option(['--kibana-cookie', '-kc'], default=getdefault('kibana_cookie'),
                                      help='Cookie from an authed session'),
        'kibana_password': click.Option(['--kibana-password', '-kp'], default=getdefault('kibana_password')),
        'kibana_url': click.Option(['--kibana-url'], default=getdefault('kibana_url')),
        'kibana_user': click.Option(['--kibana-user', '-ku'], default=getdefault('kibana_user')),
        'provider_type': click.Option(['--provider-type'], default=getdefault('provider_type'),
                                      help="Elastic Cloud providers: basic and saml (for SSO)"),
        'provider_name': click.Option(['--provider-name'], default=getdefault('provider_name'),
                                      help="Elastic Cloud providers: cloud-basic and cloud-saml (for SSO)"),
        'space': click.Option(['--space'], default=None, help='Kibana space'),
        # typed as BOOL so string values such as 'false' (CLI or env var) parse to False instead of
        # staying a truthy string that would disable certificate verification
        'ignore_ssl_errors': click.Option(['--ignore-ssl-errors'], type=click.BOOL,
                                          default=getdefault('ignore_ssl_errors'),
                                          help='Disable SSL certificate verification (true/false)')
    },
    'elasticsearch': {
        'cloud_id': click.Option(['--cloud-id'], default=getdefault("cloud_id")),
        'api_key': click.Option(['--api-key'], default=getdefault('api_key')),
        'elasticsearch_url': click.Option(['--elasticsearch-url'], default=getdefault("elasticsearch_url")),
        'es_user': click.Option(['--es-user', '-eu'], default=getdefault("es_user")),
        'es_password': click.Option(['--es-password', '-ep'], default=getdefault("es_password")),
        'timeout': click.Option(['--timeout', '-et'], default=60, help='Timeout for elasticsearch client'),
        'ignore_ssl_errors': click.Option(['--ignore-ssl-errors'], type=click.BOOL,
                                          default=getdefault('ignore_ssl_errors'),
                                          help='Disable SSL certificate verification (true/false)')
    }
}
kibana_options = list(client_options['kibana'].values())
elasticsearch_options = list(client_options['elasticsearch'].values())


def add_client(*client_type, add_to_ctx=True, add_func_arg=True):
    """Wrapper to add authed client.

    Adds the CLI options for each requested client type (keys of ``client_options``) and, on call,
    builds the client(s) from those options unless a usable one is already passed in.

    :param client_type: one or more of ``'elasticsearch'``, ``'kibana'``
    :param add_to_ctx: also store clients on ``ctx.obj`` as ``'es'`` / ``'kibana'`` when a
        ``click.Context`` is among the positional args
    :param add_func_arg: pass clients as ``elasticsearch_client`` / ``kibana_client`` keyword args
    :raises ValueError: at decoration time if a client type is unknown or none is given
    """
    from elasticsearch import Elasticsearch
    from elasticsearch.exceptions import AuthenticationException
    from kibana import Kibana

    def _wrapper(func):
        client_ops_dict = {}
        client_ops_keys = {}
        for c_type in client_type:
            ops = client_options.get(c_type)
            if ops is None:
                # fail with the intended ValueError instead of a TypeError from dict.update(None)
                raise ValueError(f'Unknown client: {c_type} in {func.__name__}')
            client_ops_dict.update(ops)
            client_ops_keys[c_type] = list(ops)

        if not client_ops_dict:
            raise ValueError(f'Unknown client: {client_type} in {func.__name__}')

        client_ops = list(client_ops_dict.values())

        @wraps(func)
        @add_params(*client_ops)
        def _wrapped(*args, **kwargs):
            ctx: click.Context = next((a for a in args if isinstance(a, click.Context)), None)
            es_client_args = {k: kwargs.pop(k, None) for k in client_ops_keys.get('elasticsearch', [])}
            # shared args like cloud_id were already popped for elasticsearch; fall back to those values
            kibana_client_args = {k: kwargs.pop(k, es_client_args.get(k)) for k in client_ops_keys.get('kibana', [])}

            if 'elasticsearch' in client_type:
                # for nested ctx invocation, no need to re-auth if an existing client is already passed
                elasticsearch_client: Elasticsearch = kwargs.get('elasticsearch_client')
                try:
                    if elasticsearch_client and isinstance(elasticsearch_client, Elasticsearch) and \
                            elasticsearch_client.info():
                        pass
                    else:
                        elasticsearch_client = get_elasticsearch_client(**es_client_args)
                except AuthenticationException:
                    elasticsearch_client = get_elasticsearch_client(**es_client_args)

                if add_func_arg:
                    kwargs['elasticsearch_client'] = elasticsearch_client
                if ctx and add_to_ctx:
                    # ensure_object creates ctx.obj when it was never initialised instead of failing on None
                    ctx.ensure_object(dict)['es'] = elasticsearch_client

            if 'kibana' in client_type:
                # for nested ctx invocation, no need to re-auth if an existing client is already passed
                kibana_client: Kibana = kwargs.get('kibana_client')
                if kibana_client and isinstance(kibana_client, Kibana):

                    try:
                        with kibana_client:
                            if kibana_client.version:
                                pass  # kibana_client is valid and can be used directly
                    except (requests.HTTPError, AttributeError):
                        kibana_client = get_kibana_client(**kibana_client_args)
                else:
                    # Instantiate a new Kibana client if none was provided or if the provided one is not usable
                    kibana_client = get_kibana_client(**kibana_client_args)

                if add_func_arg:
                    kwargs['kibana_client'] = kibana_client
                if ctx and add_to_ctx:
                    ctx.ensure_object(dict)['kibana'] = kibana_client

            return func(*args, **kwargs)

        return _wrapped

    return _wrapper
