#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Config command to setup the CLI and settings.
"""

import argparse
from dataclasses import is_dataclass, Field, fields, MISSING
import itertools
import json
import importlib
import logging
import typing as t


try:
    from typing import ForwardRef  # >= 3.7
except ImportError:
    # <3.7
    from typing import _ForwardRef as ForwardRef  # type: ignore


from threatexchange.extensions.manifest import ThreatExchangeExtensionManifest
from threatexchange import meta as tx_meta
from threatexchange import common
from threatexchange.cli.cli_config import CLISettings
from threatexchange.cli import command_base
from threatexchange.cli.exceptions import CommandError
from threatexchange.fetcher.apis.fb_threatexchange_api import (
    FBThreatExchangeSignalExchangeAPI,
)
from threatexchange.fetcher.fetch_api import SignalExchangeAPI
from threatexchange.fetcher.apis.static_sample import StaticSampleSignalExchangeAPI
from threatexchange.signal_type.signal_base import SignalType


class ConfigCollabListCommand(command_base.Command):
    """List collaborations"""

    @classmethod
    def get_name(cls) -> str:
        return "list"

    @classmethod
    def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> None:
        # No options are registered yet. Ideas:
        #  * Filter by API
        return None

    def execute(self, settings: CLISettings) -> None:
        # One line per collab: the collab name, then its API's name in parentheses.
        for collab in settings.get_all_collabs():
            owning_api = settings.get_api_for_collab(collab)
            api_label = f"({owning_api.get_name()})"
            print(collab.name, api_label)


class _UpdateCollabCommand(command_base.Command):
    """
    Create or edit collaborations for this API

    Programmatically generated by inspecting the config class, so not everything will be
    documented.
    """

    # The API whose collab config this command edits. Per-API subclasses override
    # it (see ConfigCollabForAPICommand._create_command_for_api).
    _API_CLS = SignalExchangeAPI

    # Config fields that have hand-written flags below (or are set implicitly in
    # __init__), so no flag is auto-generated for them from the config dataclass.
    _IGNORE_FIELDS = {
        "name",
        "api",
        "enabled",
        "only_signal_types",
        "not_signal_types",
        "only_owners",
        "not_owners",
        "only_tags",
        "not_tags",
    }

    # Collection-valued filter fields. The flag path always stores these as sets,
    # so values arriving via --json (JSON only has lists) are converted to sets too.
    _SET_FIELDS = (
        "only_signal_types",
        "not_signal_types",
        "only_owners",
        "not_owners",
        "only_tags",
        "not_tags",
    )

    @classmethod
    def get_name(cls) -> str:
        return cls._API_CLS.get_name()

    @classmethod
    def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> None:
        cfg_cls = cls._API_CLS.get_config_class()
        assert is_dataclass(cfg_cls)

        signal_type_names = [s.get_name() for s in settings.get_all_signal_types()]

        ap.add_argument("collab_name", help="the name of the collab")
        ap.set_defaults(api_name=cls._API_CLS.get_name())
        on_off = ap.add_mutually_exclusive_group()
        ap.add_argument(
            "--create",
            "-C",
            action="store_true",
            help="indicate you intend to create a config",
        )
        # This goofy syntax allows --enable, --enable=1, and enable=0 to disable
        on_off.add_argument(
            "--enable",
            nargs="?",
            type=int,
            const=1,
            choices=[0, 1],
            help="enable the config (default on create)",
        )
        on_off.add_argument(
            "--disable",
            dest="enable",
            action="store_const",
            const=0,
            help="disable the config",
        )
        ap.add_argument(
            "--only-signal-types",
            "-s",
            nargs="*",
            type=common.argparse_choices_pre_type(
                signal_type_names,
                settings.get_signal_type,
            ),
            metavar="NAME",
            help="limit to these signal types",
        )
        ap.add_argument(
            "--not-signal-types",
            "-S",
            nargs="*",
            type=common.argparse_choices_pre_type(
                signal_type_names,
                settings.get_signal_type,
            ),
            metavar="NAME",
            help="dont use these signal types",
        )
        ap.add_argument(
            "--only-owners",
            "-o",
            nargs="*",
            type=int,
            metavar="ID",
            help="only use signals from these owner ids",
        )
        ap.add_argument(
            "--not-owners",
            "-O",
            nargs="*",
            type=int,
            metavar="ID",
            help="dont use signals from these owner ids",
        )
        ap.add_argument(
            "--only-tags",
            "-t",
            nargs="*",
            metavar="TAG",
            help="use only signals with one of these tags",
        )
        ap.add_argument(
            "--not-tags",
            "-T",
            nargs="*",
            metavar="TAG",
            help="don't use signals with one of these tags",
        )
        ap.add_argument(
            "--json",
            "-J",
            dest="is_json",
            action="store_true",
            help="instead, interpret the argument as JSON and use that to edit the config",
        )

        for field in fields(cfg_cls):
            cls._add_argument(ap, field)

    @classmethod
    def _add_argument(cls, ap: argparse.ArgumentParser, field: Field) -> None:
        """Add a `--field-name` flag for one init field of the API's config dataclass."""
        if not field.init or field.name in cls._IGNORE_FIELDS:
            return
        assert not isinstance(
            field.type, ForwardRef
        ), "rework class to not have forward ref"

        target_type = field.type
        # Unwrap generic aliases, e.g. t.Optional[int] -> int
        if hasattr(field.type, "__args__"):
            target_type = field.type.__args__[0]

        converter = target_type
        if target_type is bool:
            # argparse's type=bool maps every non-empty string (even "False") to True
            converter = cls._parse_bool

        ap.add_argument(
            f"--{field.name.replace('_', '-')}",
            type=converter,
            metavar=target_type.__name__,
            help="[auto generated from config class]",
        )

    @staticmethod
    def _parse_bool(value: str) -> bool:
        """Parse a command-line boolean such as true/false/yes/no/1/0."""
        lowered = value.strip().lower()
        if lowered in ("1", "true", "yes", "on"):
            return True
        if lowered in ("0", "false", "no", "off"):
            return False
        raise argparse.ArgumentTypeError(
            f"expected a boolean (true/false/1/0), got {value!r}"
        )

    @staticmethod
    def _parse_json_kwargs(raw: str) -> t.Dict[str, t.Any]:
        """
        Parse the --json argument into config kwargs.

        Raises CommandError if the text is not valid JSON, or is not an object
        with a "name" key.
        """
        try:
            parsed = json.loads(raw)
        except ValueError as e:  # json.JSONDecodeError subclasses ValueError
            raise CommandError(f"invalid JSON for collab config: {e}", 2) from e
        if not isinstance(parsed, dict) or "name" not in parsed:
            raise CommandError('collab JSON must be an object with a "name" key', 2)
        return parsed

    def __init__(
        self,
        full_argparse_namespace,
        create: bool,
        collab_name: str,
        enable: t.Optional[int],
        only_signal_types: t.Optional[t.List[SignalType]],
        not_signal_types: t.Optional[t.List[SignalType]],
        only_owners: t.Optional[t.List[int]],
        not_owners: t.Optional[t.List[int]],
        only_tags: t.Optional[t.List[str]],
        not_tags: t.Optional[t.List[str]],
        is_json: bool,
    ) -> None:
        """
        Collect the config edits requested on the command line into self.edit_kwargs.

        Precedence: explicit flags > JSON (with --json) > create-time defaults.
        """
        self.namespace = full_argparse_namespace
        self.create = create
        self.edit_kwargs: t.Dict[str, t.Any] = {}
        self.collab_name = collab_name
        if is_json:
            self.edit_kwargs = self._parse_json_kwargs(collab_name)
            self.collab_name = self.edit_kwargs["name"]

        # Technically you could combine the flags and JSON, but you'd be weird
        if create:
            # Use the resolved name: with --json, `collab_name` is the raw JSON text
            self.edit_kwargs["name"] = self.collab_name
            self.edit_kwargs["api"] = self._API_CLS.get_name()
            self.edit_kwargs.setdefault("enabled", True)

        if enable is not None:
            self.edit_kwargs["enabled"] = bool(enable)

        # Signal types are stored by name
        flag_values: t.Dict[str, t.Optional[t.List[t.Any]]] = {
            "only_signal_types": (
                None
                if only_signal_types is None
                else [s.get_name() for s in only_signal_types]
            ),
            "not_signal_types": (
                None
                if not_signal_types is None
                else [s.get_name() for s in not_signal_types]
            ),
            "only_owners": only_owners,
            "not_owners": not_owners,
            "only_tags": only_tags,
            "not_tags": not_tags,
        }
        for key in self._SET_FIELDS:
            from_flag = flag_values[key]
            if from_flag is not None:
                self.edit_kwargs[key] = set(from_flag)
            elif key in self.edit_kwargs:
                # Came from JSON: normalize list (or null) to a set
                self.edit_kwargs[key] = set(self.edit_kwargs[key] or ())
            elif create:
                self.edit_kwargs[key] = set()

        # Auto-generated flags: only fields the user actually passed are edited
        for field in fields(self._API_CLS.get_config_class()):
            if not field.init or field.name in self._IGNORE_FIELDS:
                continue
            val = getattr(full_argparse_namespace, field.name)
            if val is not None:
                self.edit_kwargs[field.name] = val

    def execute(self, settings: CLISettings) -> None:
        """
        Apply self.edit_kwargs to an existing collab, or create a new one.

        Raises CommandError if creating over an existing name, if the existing
        collab belongs to a different API, or if editing a collab that doesn't exist.
        """
        existing = settings.get_collab(self.collab_name)

        if existing:
            if self.create:
                raise CommandError(
                    f'there\'s an existing collaboration named "{self.collab_name}"', 2
                )
            if existing.api != self._API_CLS.get_name():
                raise CommandError(
                    f"the existing collab is for the {existing.api} api, delete that one first",
                    2,
                )
            assert (
                existing.__class__ == self._API_CLS.get_config_class()
            ), "api name the same, but class different?"
            for name, val in self.edit_kwargs.items():
                setattr(existing, name, val)
            settings._state.update_collab(existing)
        elif self.create:
            logging.debug("Creating config with args: %s", self.edit_kwargs)
            to_create = self._API_CLS.get_config_class()(**self.edit_kwargs)
            settings._state.update_collab(to_create)
        else:
            raise CommandError("no such config! Did you mean to use --create?", 2)


class ConfigCollabForAPICommand(command_base.CommandWithSubcommands):
    """Create and edit collaborations for APIs"""

    @classmethod
    def get_name(cls) -> str:
        return "edit"

    @classmethod
    def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> None:
        # The available subcommands depend on the fetchers settings knows about, so
        # they are generated here (one per API) instead of being declared statically.
        # The static sample API is deliberately excluded.
        editable_apis = (
            fetcher
            for fetcher in settings.get_fetchers()
            if fetcher.__class__ is not StaticSampleSignalExchangeAPI
        )
        cls._SUBCOMMANDS = list(map(cls._create_command_for_api, editable_apis))

    @classmethod
    def _create_command_for_api(
        cls, api: SignalExchangeAPI
    ) -> t.Type[command_base.Command]:
        """
        Build a `_UpdateCollabCommand` subclass bound to the class of `api`.

        Don't try this at home!
        """
        bound_api_cls = api.__class__

        class _GeneratedUpdateCommand(_UpdateCollabCommand):
            _API_CLS = bound_api_cls

        # Give each generated class a distinct, readable name
        generated = _GeneratedUpdateCommand
        generated.__name__ = f"{generated.__name__}_{api.get_name()}"
        return generated


class ConfigCollabDeleteCommand(command_base.Command):
    """Delete collaborations"""

    @classmethod
    def get_name(cls) -> str:
        return "delete"

    @classmethod
    def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> None:
        ap.add_argument("collab_name", help="the collab to delete")

    def __init__(self, collab_name: str) -> None:
        self.collab_name = collab_name

    def execute(self, settings: CLISettings) -> None:
        # Look up the collab by name; a missing collab is a user error (exit code 2).
        target = settings.get_collab(self.collab_name)
        if target is not None:
            settings._state.delete_collab(target)  # TODO clean private member access
            return
        raise CommandError("No such collab", 2)


class ConfigCollabCommand(command_base.CommandWithSubcommands):
    """Configure collaborations"""

    # Child commands: list / edit (per-API) / delete
    _SUBCOMMANDS = [
        ConfigCollabListCommand,
        ConfigCollabForAPICommand,
        ConfigCollabDeleteCommand,
    ]

    @classmethod
    def get_name(cls) -> str:
        return "collab"

    def execute(self, settings: CLISettings) -> None:
        # Reuse the `list` subcommand's output for this command's own execution.
        lister = ConfigCollabListCommand()
        lister.execute(settings)


class ConfigExtensionsCommand(command_base.Command):
    """Configure extensions"""

    @classmethod
    def get_name(cls) -> str:
        return "extensions"

    @classmethod
    def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> None:
        ap.add_argument(
            "action",
            # A positional without nargs="?" is always required, which would make
            # the "list" default below unreachable.
            nargs="?",
            choices=["list", "add", "remove"],
            default="list",
            help="what to do",
        )
        ap.add_argument(
            "module",
            nargs="?",
            help="the module path to the extension. foo.bar.baz",
        )

    def __init__(self, action: str, module: t.Optional[str]) -> None:
        # Resolve the handler up front; argparse `choices` restricts `action` to these keys
        self.action = {
            "list": self.execute_list,
            "add": self.execute_add,
            "remove": self.execute_remove,
        }[action]
        self.module = module

    def execute(self, settings: CLISettings) -> None:
        self.action(settings)

    def execute_list(self, settings: CLISettings) -> None:
        """
        Print the contents of the given extension module or, if none was given,
        of every extension recorded in the persistent config (sorted by name).
        """
        if self.module:
            self.print_extension(self.get_manifest(self.module))
            return
        for module_name in sorted(settings.get_persistent_config().extensions):
            print(module_name)
            self.print_extension(self.get_manifest(module_name), indent=2)

    def get_manifest(self, module_name: str) -> ThreatExchangeExtensionManifest:
        """Load an extension manifest, turning a ValueError into a CommandError."""
        try:
            return ThreatExchangeExtensionManifest.load_from_module_name(module_name)
        except ValueError as ve:
            raise CommandError(str(ve), 2) from ve

    def execute_add(self, settings: CLISettings) -> None:
        """
        Validate the extension against what is already loaded, print its contents,
        and record it in the persistent config.
        """
        if not self.module:
            raise CommandError("module is required", 2)

        manifest = self.get_manifest(self.module)

        # Validate our new setups by pretending to create a new mapping with the
        # new classes; the mapping itself is discarded (presumably its constructor
        # rejects conflicts - confirm in tx_meta).
        tx_meta.SignalTypeMapping(
            list(
                itertools.chain(
                    settings.get_all_content_types(), manifest.content_types
                )
            ),
            list(
                itertools.chain(settings.get_all_signal_types(), manifest.signal_types)
            ),
        )

        # For APIs, we also need to make sure they can be instantiated without args for the CLI
        apis = []
        for new_api in manifest.apis:
            try:
                instance = new_api()
            except Exception as e:
                logging.exception("Failed to instantiate API %s", new_api.get_name())
                raise CommandError(
                    f"Not able to instantiate API {new_api.get_name()} - throws {e}"
                ) from e
            apis.append(instance)
        apis.extend(settings.get_fetchers())
        tx_meta.FetcherMapping(apis)

        self.print_extension(manifest)

        config = settings.get_persistent_config()
        config.extensions.add(self.module)
        settings.set_persistent_config(config)

    def execute_remove(self, settings: CLISettings) -> None:
        """Remove a previously added extension from the persistent config."""
        if not self.module:
            raise CommandError("module is required", 2)
        config = settings.get_persistent_config()
        if self.module not in config.extensions:
            raise CommandError(f"You haven't added {self.module}", 2)
        config.extensions.remove(self.module)
        settings.set_persistent_config(config)

    def print_extension(
        self, manifest: ThreatExchangeExtensionManifest, indent: int = 0
    ) -> None:
        """
        Print each non-empty section of the manifest (signal types, content types,
        APIs) as "<get_name()> - <class name>" lines, indented by `indent` spaces.
        """
        space = " " * indent
        level2 = f"\n{space}  "
        sections = (
            ("Signal", manifest.signal_types),
            ("Content", manifest.content_types),
            ("API", manifest.apis),
        )
        for label, classes in sections:
            if not classes:
                continue
            print(f"{space}{label}:{level2}", end="")
            print(level2.join(f"{c.get_name()} - {c.__name__}" for c in classes))


class ConfigSignalCommand(command_base.Command):
    """Configure signal types"""

    @classmethod
    def get_name(cls) -> str:
        return "signal"

    @classmethod
    def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> None:
        ap.add_argument(
            "action",
            # Without nargs="?" the positional is required and the default never applies
            nargs="?",
            choices=["list"],
            default="list",
            help="what to do",
        )

    def __init__(self, action: str) -> None:
        self.action = {
            "list": self.execute_list,
        }[action]

    def execute(self, settings: CLISettings) -> None:
        self.action(settings)

    def execute_list(self, settings: CLISettings) -> None:
        """Print the name of every available signal type, one per line, sorted."""
        # Previously this printed collaborations, which `config collab list` already covers
        signal_types = settings.get_all_signal_types()
        for name in sorted(s.get_name() for s in signal_types):
            print(name)


class ConfigContentCommand(command_base.Command):
    """Configure content types"""

    @classmethod
    def get_name(cls) -> str:
        return "content"

    @classmethod
    def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> None:
        # Listing is the only behavior, so this flag is accepted but not read.
        ap.add_argument(
            "--list",
            action="store_true",
            help="list the names of Content Types (default action)",
        )

    def execute(self, settings: CLISettings) -> None:
        # Print every content type name, one per line, in sorted order.
        names = [ct.get_name() for ct in settings.get_all_content_types()]
        names.sort()
        for content_name in names:
            print(content_name)


class ConfigThreatExchangeAPICommand(command_base.Command):
    """Configure apis"""

    @classmethod
    def get_name(cls) -> str:
        # Named after the FB ThreatExchange API this command configures
        return FBThreatExchangeSignalExchangeAPI.get_name()

    @classmethod
    def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> None:
        ap.add_argument(
            "--api-token",
            help="set the default api token",
        )

    def __init__(self, api_token: t.Optional[str]) -> None:
        self.api_token = api_token

    def execute(self, settings: CLISettings) -> None:
        # Without --api-token there is nothing to change.
        if self.api_token is None:
            return
        persisted = settings.get_persistent_config()
        persisted.fb_threatexchange_api_token = self.api_token
        settings.set_persistent_config(persisted)


class ConfigAPICommand(command_base.CommandWithSubcommands):
    """Configure apis"""

    _SUBCOMMANDS = [ConfigThreatExchangeAPICommand]

    @classmethod
    def get_name(cls) -> str:
        return "api"

    def execute(self, settings: CLISettings) -> None:
        # Print the name of every fetcher API settings knows about, sorted.
        api_names = [fetcher.get_name() for fetcher in settings.get_fetchers()]
        for api_name in sorted(api_names):
            print(api_name)


class ConfigCommand(command_base.CommandWithSubcommands):
    """Configure the CLI"""

    # Top-level `config` children: collabs, signal types, content types,
    # APIs, and extensions.
    _SUBCOMMANDS = [
        ConfigCollabCommand,
        ConfigSignalCommand,
        ConfigContentCommand,
        ConfigAPICommand,
        ConfigExtensionsCommand,
    ]
