# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.

"""Elasticsearch cli commands."""
import json
import os
import sys
import time
from collections import defaultdict
from typing import List, Union

import click
import elasticsearch
from elasticsearch import Elasticsearch
from elasticsearch.client import AsyncSearchClient

import kql
from .config import parse_rules_config
from .main import root
from .misc import add_params, client_error, elasticsearch_options, get_elasticsearch_client, nested_get
from .rule import TOMLRule

from .rule_loader import RuleCollection
from .utils import format_command_options, normalize_timing_and_sort, unix_time_to_formatted, get_path

COLLECTION_DIR = get_path('collections')
MATCH_ALL = {'bool': {'filter': [{'match_all': {}}]}}
RULES_CONFIG = parse_rules_config()


def add_range_to_dsl(dsl_filter, start_time, end_time='now'):
    """Append an ``@timestamp`` range clause to a list of DSL filter clauses, in place.

    The clause selects timestamps strictly after ``start_time`` and up to and including ``end_time``.
    Nothing is returned; ``dsl_filter`` itself is modified.
    """
    timestamp_bounds = {"gt": start_time, "lte": end_time, "format": "strict_date_optional_time"}
    range_clause = {"range": {"@timestamp": timestamp_bounds}}
    dsl_filter.append(range_clause)


def parse_unique_field_results(rule_type: str, unique_fields: List[str], search_results: dict):
    """Count how often each distinct value of the given fields occurs in search results.

    :param rule_type: rule type; ``'eql'`` results are read from ``hits.events`` / ``hits.sequences``,
        every other type from ``hits.hits``
    :param unique_fields: dotted field names to look up in each hit's ``_source``
    :param search_results: raw search response containing a top-level ``hits`` key
    :return: ``{'results': {field: {value: count}}}``, or an empty dict when nothing was counted
    :raises KeyError: if ``search_results`` has no ``hits`` (or a non-EQL response has no ``hits.hits``)
    """
    hits = search_results['hits']
    if rule_type == 'eql':
        # EQL responses hold plain events or, for sequence queries, sequences of events
        hits = hits.get('events') or hits.get('sequences', [])
    else:
        hits = hits['hits']

    parsed_results = defaultdict(lambda: defaultdict(int))
    for hit in hits:
        for field in unique_fields:
            if 'events' in hit:
                # sequence hit: gather the field from every event in the sequence
                match = []
                for event in hit['events']:
                    matched = nested_get(event['_source'], field)
                    # assumes nested_get returns None for an absent field - TODO confirm
                    if matched is None:
                        continue
                    match.extend(matched if isinstance(matched, list) else [matched])
            else:
                match = nested_get(hit['_source'], field)

            # skip the field for this hit when no value was found
            if not match:
                continue

            if isinstance(match, list):
                # multi-valued fields are collapsed into one stable, comma-separated key
                match = ','.join(sorted(str(value) for value in match))
            parsed_results[field][match] += 1

    return {'results': parsed_results} if parsed_results else {}


class Events:
    """Events collected from Elasticsearch, keyed by source name."""

    def __init__(self, events):
        # mapping of source name -> list of event documents, normalized and sorted on construction
        self.events: dict = self._normalize_event_timing(events)

    @staticmethod
    def _normalize_event_timing(events):
        """Normalize event timestamps and sort each source's events (updates ``events`` in place)."""
        for agent_type, _events in events.items():
            events[agent_type] = normalize_timing_and_sort(_events)

        return events

    @staticmethod
    def _get_dump_dir(rta_name=None, host_id=None, host_os_family=None):
        """Prepare (create if needed) and return the directory events are dumped to.

        With both ``rta_name`` and ``host_os_family`` the unit test true-positive data directory is used;
        otherwise a timestamped directory under the collections directory for ``host_id``.
        """
        if rta_name and host_os_family:
            dump_dir = get_path('unit_tests', 'data', 'true_positives', rta_name, host_os_family)
        else:
            time_str = time.strftime('%Y%m%dT%H%M%SL')
            dump_dir = os.path.join(COLLECTION_DIR, host_id or 'unknown_host', time_str)

        os.makedirs(dump_dir, exist_ok=True)
        return dump_dir

    @staticmethod
    def _find_host_os_family(events, host_id):
        """Return ``host.os.family`` from the first event of a source whose ``host.id`` equals ``host_id``.

        Sources with no events, or whose first event lacks ``host``/``host.os``, are skipped rather than
        raising. Returns ``None`` when no family can be determined.
        """
        for source_events in events.values():
            if not source_events:
                continue

            host = source_events[0].get('host') or {}
            if host.get('id') != host_id:
                continue

            family = (host.get('os') or {}).get('family')
            if family:
                return family

        return None

    def evaluate_against_rule(self, rule_id, verbose=True):
        """Evaluate a rule against the collected events and return the matching events.

        :raises AssertionError: if no rule with ``rule_id`` exists in the default rule collection
        """
        from .utils import combine_sources, evaluate

        rule = RuleCollection.default().id_map.get(rule_id)
        assert rule is not None, f"Unable to find rule with ID {rule_id}"
        merged_events = combine_sources(*self.events.values())
        filtered = evaluate(rule, merged_events, normalize_kql_keywords=RULES_CONFIG.normalize_kql_keywords)

        # only announce matches when the evaluation actually produced some
        if verbose and filtered:
            click.echo('Matching results found')

        return filtered

    def echo_events(self, pager=False, pretty=True):
        """Print events to stdout as JSON, optionally through a pager."""
        echo_fn = click.echo_via_pager if pager else click.echo
        echo_fn(json.dumps(self.events, indent=2 if pretty else None, sort_keys=True))

    def save(self, rta_name=None, dump_dir=None, host_id=None):
        """Save collected events as one ``<source>.ndjson`` file per source.

        :param rta_name: when set (and no ``dump_dir`` is given), save into the unit test data directory for
            this RTA, under the host's OS family
        :param dump_dir: explicit output directory; when empty, one is derived and created
        :param host_id: host the events belong to, used to find ``host.os.family`` and to name the directory
        :raises AssertionError: if there are no events to save
        """
        # kept as an assertion: callers catch AssertionError to report that nothing was collected
        assert self.events, 'Nothing to save. Run Collector.run() method first or verify logging'

        if not dump_dir:
            host_os_family = None
            # host.os.family only influences the path when saving under an RTA name, so only resolve
            # (and possibly prompt for) it in that case
            if rta_name:
                host_os_family = self._find_host_os_family(self.events, host_id)
                if not host_os_family:
                    click.echo('Unable to determine host.os.family for host_id: {}'.format(host_id))
                    host_os_family = click.prompt("Please enter the host.os.family for this host_id",
                                                  type=click.Choice(["windows", "macos", "linux"]),
                                                  default="windows")

            dump_dir = self._get_dump_dir(rta_name=rta_name, host_id=host_id, host_os_family=host_os_family)

        for source, events in self.events.items():
            path = os.path.join(dump_dir, source + '.ndjson')
            with open(path, 'w') as f:
                f.writelines(json.dumps(e, sort_keys=True) + '\n' for e in events)
            click.echo('{} events saved to: {}'.format(len(events), path))


class CollectEvents(object):
    """Event collector for elastic stack."""

    def __init__(self, client, max_events=3000):
        self.client: Elasticsearch = client
        # default result size used by searches when no explicit size is given
        self.max_events = max_events

    def _build_timestamp_map(self, index_str):
        """Build a mapping of index name to its ``@timestamp`` field mapping (empty dict when unmapped)."""
        mappings = self.client.indices.get_mapping(index=index_str)
        timestamp_map = {n: m['mappings'].get('properties', {}).get('@timestamp', {}) for n, m in mappings.items()}
        return timestamp_map

    def _get_last_event_time(self, index_str, dsl=None):
        """Get the timestamp of the most recent event, or ``None`` when nothing matches."""
        last_event = self.client.search(query=dsl, index=index_str, size=1, sort='@timestamp:desc')['hits']['hits']
        if not last_event:
            return None

        last_event = last_event[0]
        index = last_event['_index']
        timestamp = last_event['_source']['@timestamp']

        timestamp_map = self._build_timestamp_map(index_str)
        event_date_format = timestamp_map[index].get('format', '').split('||')

        # there are many native supported date formats and even custom data formats, but most, including beats use the
        # default `strict_date_optional_time`. It would be difficult to try to account for all possible formats, so this
        # will work on the default and unix time.
        if set(event_date_format) & {'epoch_millis', 'epoch_second'}:
            timestamp = unix_time_to_formatted(timestamp)

        return timestamp

    @staticmethod
    def _prep_query(query, language, index, start_time=None, end_time=None):
        """Prep a query for search.

        :param query: query text (kql/kuery, eql, lucene) or a query DSL dict (``dsl``)
        :param language: one of ``kql``, ``kuery``, ``eql``, ``lucene`` or ``dsl``
        :param index: comma separated string, or list/tuple, of index patterns
        :param start_time: optional lower bound (exclusive) for ``@timestamp``
        :param end_time: optional upper bound (inclusive) for ``@timestamp``; ``'now'`` when only a start is given
        :return: tuple of ``(index_str, formatted_dsl, lucene_query)``; ``lucene_query`` is ``None`` unless the
            language is ``lucene``
        :raises ValueError: for an unknown language
        """
        from copy import deepcopy

        index_str = ','.join(index if isinstance(index, (list, tuple)) else index.split(','))
        lucene_query = query if language == 'lucene' else None

        if language in ('kql', 'kuery'):
            formatted_dsl = {'query': kql.to_dsl(query)}
        elif language == 'eql':
            # copy the constant: the range filter appended below must never leak into the shared MATCH_ALL
            formatted_dsl = {'query': query, 'filter': deepcopy(MATCH_ALL)}
        elif language == 'lucene':
            formatted_dsl = {'query': {'bool': {'filter': []}}}
        elif language == 'dsl':
            # copy so that attaching a time range does not modify the caller's DSL
            formatted_dsl = {'query': deepcopy(query)}
        else:
            raise ValueError(f'Unknown search language: {language}')

        if start_time or end_time:
            end_time = end_time or 'now'
            if language == 'eql':
                dsl = formatted_dsl['filter']['bool']['filter']
            else:
                query_dsl = formatted_dsl['query']
                if 'bool' not in query_dsl:
                    # wrap a non-bool query so the range filter can be attached without changing what it matches
                    query_dsl = formatted_dsl['query'] = {'bool': {'must': [query_dsl]}}
                dsl = query_dsl['bool'].setdefault('filter', [])
            add_range_to_dsl(dsl, start_time, end_time)

        return index_str, formatted_dsl, lucene_query

    def search(self, query, language, index: Union[str, list] = '*', start_time=None, end_time=None, size=None,
               **kwargs):
        """Search an elasticsearch instance and return the list of hits.

        For ``eql`` the events (or sequences) are returned; for every other language ``hits.hits``.
        Extra ``kwargs`` are forwarded to the underlying client search call.
        """
        index_str, formatted_dsl, lucene_query = self._prep_query(query=query, language=language, index=index,
                                                                  start_time=start_time, end_time=end_time)
        formatted_dsl.update(size=size or self.max_events)

        if language == 'eql':
            results = self.client.eql.search(body=formatted_dsl, index=index_str, **kwargs)['hits']
            results = results.get('events') or results.get('sequences', [])
        else:
            results = self.client.search(body=formatted_dsl, q=lucene_query, index=index_str,
                                         allow_no_indices=True, ignore_unavailable=True, **kwargs)['hits']['hits']

        return results

    def search_from_rule(self, rules: RuleCollection, start_time=None, end_time='now', size=None):
        """Search an elasticsearch instance using the queries of the given rules.

        Rules without a query are skipped. Returns a dict of rule id to the parsed unique field results, or to
        an ``error_retrieving_results`` marker when the results for that rule could not be read.
        """
        async_client = AsyncSearchClient(self.client)
        survey_results = {}
        multi_search = []
        multi_search_rules = []
        async_searches = []
        eql_searches = []

        for rule in rules:
            if not rule.contents.data.get('query'):
                continue

            language = rule.contents.data.get('language')
            query = rule.contents.data.query
            rule_type = rule.contents.data.type
            index_str, formatted_dsl, lucene_query = self._prep_query(query=query,
                                                                      language=language,
                                                                      index=rule.contents.data.get('index', '*'),
                                                                      start_time=start_time,
                                                                      end_time=end_time)
            formatted_dsl.update(size=size or self.max_events)

            # prep for searches: msearch for kql | async search for lucene | eql client search for eql
            if language in ('kql', 'kuery'):
                multi_search_rules.append(rule)
                multi_search.append({'index': index_str, 'allow_no_indices': 'true', 'ignore_unavailable': 'true'})
                multi_search.append(formatted_dsl)
            elif language == 'lucene':
                # wait for 0 to try and force async with no immediate results (not guaranteed)
                result = async_client.submit(body=formatted_dsl, q=query, index=index_str,
                                             allow_no_indices=True, ignore_unavailable=True,
                                             wait_for_completion_timeout=0)
                if result['is_running'] is True:
                    async_searches.append((rule, result['id']))
                else:
                    # NOTE(review): lucene results are summarized on process.name rather than the rule's
                    # unique_fields - confirm this is intended
                    survey_results[rule.id] = parse_unique_field_results(rule_type, ['process.name'],
                                                                         result['response'])
            elif language == 'eql':
                eql_body = {
                    'index': index_str,
                    'params': {'ignore_unavailable': 'true', 'allow_no_indices': 'true'},
                    'body': {'query': query, 'filter': formatted_dsl['filter']}
                }
                eql_searches.append((rule, eql_body))

        # assemble search results
        # an msearch request without any searches is rejected, so only send it when there is something to run
        if multi_search:
            multi_search_results = self.client.msearch(searches=multi_search)
            for index, result in enumerate(multi_search_results['responses']):
                try:
                    rule = multi_search_rules[index]
                    survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type,
                                                                         rule.contents.data.unique_fields, result)
                except KeyError:
                    # a response without the expected keys (e.g. a per-search error) is recorded, not raised
                    survey_results[multi_search_rules[index].id] = {'error_retrieving_results': True}

        for entry in eql_searches:
            rule: TOMLRule
            search_args: dict
            rule, search_args = entry
            try:
                result = self.client.eql.search(**search_args)
                survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type,
                                                                     rule.contents.data.unique_fields, result)
            except (elasticsearch.NotFoundError, elasticsearch.RequestError) as e:
                survey_results[rule.id] = {'error_retrieving_results': True, 'error': e.info['error']['reason']}

        for entry in async_searches:
            rule: TOMLRule
            rule, async_id = entry
            result = async_client.get(id=async_id)['response']
            survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type, ['process.name'], result)

        return survey_results

    def count(self, query, language, index: Union[str, list], start_time=None, end_time='now'):
        """Get a count of documents from elasticsearch.

        For ``eql`` the count is the number of results of a search limited to 1000 hits.
        """
        # EQL API has no count endpoint
        if language == 'eql':
            results = self.search(query=query, language=language, index=index, start_time=start_time, end_time=end_time,
                                  size=1000)
            return len(results)

        index_str, formatted_dsl, lucene_query = self._prep_query(query=query, language=language, index=index,
                                                                  start_time=start_time, end_time=end_time)
        return self.client.count(body=formatted_dsl, index=index_str, q=lucene_query, allow_no_indices=True,
                                 ignore_unavailable=True)['count']

    def count_from_rule(self, rules: RuleCollection, start_time=None, end_time='now'):
        """Get a count of documents from elasticsearch for each rule that has a query.

        Returns a dict of rule id to ``{'rule_id', 'name', 'search_count'}``; ``search_count`` is ``-1`` when
        the count request failed with a not-found or request error.
        """
        survey_results = {}

        for rule in rules.rules:
            if not rule.contents.data.get('query'):
                continue

            rule_results = {'rule_id': rule.id, 'name': rule.name}

            try:
                rule_results['search_count'] = self.count(query=rule.contents.data.query,
                                                          language=rule.contents.data.language,
                                                          index=rule.contents.data.get('index', '*'),
                                                          start_time=start_time,
                                                          end_time=end_time)
            except (elasticsearch.NotFoundError, elasticsearch.RequestError):
                rule_results['search_count'] = -1

            survey_results[rule.id] = rule_results

        return survey_results


class CollectEventsWithDSL(CollectEvents):
    """Collector that gathers events from elasticsearch with a query DSL search."""

    @staticmethod
    def _group_events_by_type(events):
        """Bucket the ``_source`` of every hit under its ``agent.type`` value."""
        grouped = defaultdict(list)

        for hit in events:
            agent_type = hit['_source']['agent']['type']
            grouped[agent_type].append(hit['_source'])

        # hand back a plain dict so later lookups of unknown keys do not create entries
        return dict(grouped)

    def run(self, dsl, indexes, start_time):
        """Search from ``start_time`` until now, oldest first, and wrap the grouped hits in ``Events``."""
        oldest_first = [{'@timestamp': {'order': 'asc'}}]
        hits = self.search(dsl, language='dsl', index=indexes, start_time=start_time, end_time='now', size=5000,
                           sort=oldest_first)
        return Events(self._group_events_by_type(hits))


@root.command('normalize-data')
@click.argument('events-file', type=click.File('r'))
def normalize_data(events_file):
    """Normalize Elasticsearch data timestamps and sort.

    Reads one JSON document per line from EVENTS_FILE and saves the result next to it.
    """
    # resolve to an absolute path first: the dirname of a bare file name is '' which save() would treat as
    # "no directory given" and write somewhere else entirely
    file_path = os.path.abspath(events_file.name)
    file_name = os.path.splitext(os.path.basename(file_path))[0]

    # blank lines (e.g. a trailing newline) are not JSON documents, so skip them
    documents = [json.loads(line) for line in events_file if line.strip()]
    events = Events({file_name: documents})
    events.save(dump_dir=os.path.dirname(file_path))


@root.group('es')
@add_params(*elasticsearch_options)
@click.pass_context
def es_group(ctx: click.Context, **kwargs):
    """Commands for integrating with Elasticsearch."""
    ctx.ensure_object(dict)

    # only initialize an es client if the subcommand is invoked without help (hacky)
    help_requested = sys.argv[-1] in ctx.help_option_names
    if not help_requested:
        ctx.obj['es'] = get_elasticsearch_client(ctx=ctx, **kwargs)
        return

    # help was requested: describe the client options instead of connecting
    click.echo('Elasticsearch client:')
    click.echo(format_command_options(ctx))


@es_group.command('collect-events')
@click.argument('host-id')
@click.option('--query', '-q', help='KQL query to scope search')
@click.option('--index', '-i', multiple=True, help='Index(es) to search against (default: all indexes)')
@click.option('--rta-name', '-r', help='Name of RTA in order to save events directly to unit tests data directory')
@click.option('--rule-id', help='Updates rule mapping in rule-mapping.yaml file (requires --rta-name)')
@click.option('--view-events', is_flag=True, help='Print events after saving')
@click.pass_context
def collect_events(ctx, host_id, query, index, rta_name, rule_id, view_events):
    """Collect events from Elasticsearch."""
    from copy import deepcopy

    client: Elasticsearch = ctx.obj['es']
    # copy the constant: the host filter added below must not leak into the shared MATCH_ALL
    dsl = kql.to_dsl(query) if query else deepcopy(MATCH_ALL)
    dsl['bool'].setdefault('filter', []).append({'bool': {'should': [{'match_phrase': {'host.id': host_id}}]}})

    try:
        collector = CollectEventsWithDSL(client)
        start = time.time()
        click.pause('Press any key once detonation is complete ...')
        # search back over the time spent waiting for the key press, plus a 5 second margin
        start_time = f'now-{round(time.time() - start) + 5}s'
        events = collector.run(dsl, index or '*', start_time)
        events.save(rta_name=rta_name, host_id=host_id)

        if rta_name and rule_id:
            events.evaluate_against_rule(rule_id)

        if view_events and events.events:
            events.echo_events(pager=True)

        return events
    except AssertionError as e:
        # raised by Events when there is nothing to save (or the rule id is unknown)
        error_msg = 'No events collected! Verify events are streaming and that the agent-hostname is correct'
        client_error(error_msg, e, ctx=ctx)


@es_group.command('index-rules')
@click.option('--query', '-q', help='Optional KQL query to limit to specific rules')
@click.option('--from-file', '-f', type=click.File('r'), help='Load a previously saved uploadable bulk file')
@click.option('--save_files', '-s', is_flag=True, help='Optionally save the bulk request to a file')
@click.pass_context
def index_repo(ctx: click.Context, query, from_file, save_files):
    """Index rules based on KQL search results to an elasticsearch instance."""
    from .main import generate_rules_index

    es_client: Elasticsearch = ctx.obj['es']

    if from_file:
        bulk_upload_docs = from_file.read()

        # light validation only
        try:
            index_body = [json.loads(line) for line in bulk_upload_docs.splitlines()]
            click.echo(f'{sum(1 for r in index_body if "rule" in r)} rules included')
        except json.JSONDecodeError:
            client_error(f'Improperly formatted bulk request file: {from_file.name}')
    else:
        bulk_upload_docs, _ = ctx.invoke(generate_rules_index, query=query, save_files=save_files)

    # pass the payload by keyword, like every other client call in this module: newer clients reject
    # positional arguments on API methods
    es_client.bulk(body=bulk_upload_docs)
