# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import boto3
import bottle
import functools
import dataclasses

from dataclasses import dataclass, asdict
from mypy_boto3_dynamodb.service_resource import Table
import typing as t
from enum import Enum
from urllib.error import HTTPError
from mypy_boto3_sqs.client import SQSClient

from threatexchange.descriptor import ThreatDescriptor
from threatexchange.signal_type.signal_base import SignalType
from threatexchange.signal_type.md5 import VideoMD5Signal
from threatexchange.signal_type.pdq import PdqSignal
from threatexchange.content_type.content_base import ContentType
from threatexchange.content_type.photo import PhotoContent
from threatexchange.content_type.video import VideoContent
from threatexchange.content_type.meta import (
    get_signal_types_by_name,
    get_content_types_by_name,
)

from hmalib.common.models.pipeline import MatchRecord
from hmalib.common.models.signal import (
    ThreatExchangeSignalMetadata,
    PendingThreatExchangeOpinionChange,
)
from hmalib.common.logging import get_logger
from hmalib.common.messages.match import BankedSignal
from hmalib.common.messages.writeback import WritebackMessage
from hmalib.indexers.metadata import (
    BANKS_SOURCE_SHORT_CODE,
    BankedSignalIndexMetadata,
)
from hmalib.lambdas.api.middleware import (
    jsoninator,
    JSONifiable,
    DictParseable,
    SubApp,
)
from hmalib.common.config import HMAConfig
from hmalib.common.models.bank import BankMember, BanksTable
from hmalib.matchers.matchers_base import Matcher

from hmalib.hashing.unified_hasher import UnifiedHasher
from hmalib.common.content_sources import URLContentSource


logger = get_logger(__name__)


@functools.lru_cache(maxsize=None)
def _get_sqs_client() -> SQSClient:
    """
    Build the boto3 SQS client.

    Memoized with lru_cache, so the client is created on the first call and
    the same object is handed back on every later call.
    """
    sqs_client = boto3.client("sqs")
    return sqs_client


# Module-level slot for the shared Matcher; filled by the first _get_matcher() call.
_matcher = None


def _get_matcher(index_bucket_name: str, banks_table: BanksTable) -> Matcher:
    """
    Return the shared Matcher, constructing it on first use.

    The instance is kept in the module-level `_matcher`. Once it exists, later
    calls return it unchanged and their arguments are not consulted.
    """
    global _matcher
    if _matcher is not None:
        return _matcher

    _matcher = Matcher(
        banks_table=banks_table,
        supported_signal_types=[PdqSignal, VideoMD5Signal],
        index_bucket_name=index_bucket_name,
    )
    return _matcher


# Module-level slot for the shared UnifiedHasher; filled by the first _get_hasher() call.
_hasher = None


def _get_hasher() -> UnifiedHasher:
    """
    Return the shared UnifiedHasher, constructing it on first use.

    The hasher is configured for photo and video content with the PDQ and
    video-MD5 signal types, and is given an empty output queue URL.
    """
    global _hasher
    if _hasher is not None:
        return _hasher

    _hasher = UnifiedHasher(
        output_queue_url="",
        supported_signal_types=[PdqSignal, VideoMD5Signal],
        supported_content_types=[PhotoContent, VideoContent],
    )
    return _hasher


@dataclass
class MatchSummary(JSONifiable):
    """
    One entry of the match listing: a content id paired with the signal it
    matched, the signal's source, and an updated_at string.
    """

    content_id: str
    signal_id: t.Union[str, int]
    signal_source: str
    updated_at: str

    def to_json(self) -> t.Dict:
        """Serialize every field into a plain dict keyed by field name."""
        return dataclasses.asdict(self)


@dataclass
class MatchSummariesResponse(JSONifiable):
    """Response envelope carrying a list of MatchSummary objects."""

    match_summaries: t.List[MatchSummary]

    def to_json(self) -> t.Dict:
        """Serialize as {"match_summaries": [...]} via each summary's to_json()."""
        serialized = []
        for entry in self.match_summaries:
            serialized.append(entry.to_json())
        return {"match_summaries": serialized}


@dataclass
class BankedSignalDetailsMetadata(JSONifiable):
    """Identifies the bank and the bank member behind a banked signal."""

    bank_member_id: str
    bank_id: str

    def to_json(self) -> t.Dict:
        """Serialize both ids into a plain dict keyed by field name."""
        return dataclasses.asdict(self)


@dataclass
class ThreatExchangeSignalDetailsMetadata(JSONifiable):
    """
    Per-privacy-group details of a ThreatExchange signal, in the string form
    returned to API clients.
    """

    privacy_group_id: str
    tags: t.List[str]
    # String value of an opinion enum (see get_te_signal_details).
    opinion: str
    # String value of the pending opinion change enum (see get_te_signal_details).
    pending_opinion_change: str

    def to_json(self) -> t.Dict:
        """Serialize every field into a plain dict keyed by field name."""
        return dataclasses.asdict(self)


@dataclass
class MatchDetail(JSONifiable):
    """
    Full description of a single match between a piece of content and a signal.

    Note: te_signal_details should eventually be folded into
    banked_signal_details once ThreatExchange signals function the same as
    locally banked ones.
    """

    content_id: str
    content_hash: str
    signal_id: t.Union[str, int]
    signal_hash: str
    signal_source: str
    signal_type: str
    updated_at: str
    match_distance: t.Optional[int]
    te_signal_details: t.List[ThreatExchangeSignalDetailsMetadata]
    banked_signal_details: t.List[BankedSignalDetailsMetadata]

    def to_json(self) -> t.Dict:
        """
        Serialize to a dict keyed by field name.

        The two detail lists are replaced with the output of each element's
        own to_json() rather than the generic asdict() conversion.
        """
        payload = asdict(self)
        payload["te_signal_details"] = [
            te_detail.to_json() for te_detail in self.te_signal_details
        ]
        payload["banked_signal_details"] = [
            banked_detail.to_json() for banked_detail in self.banked_signal_details
        ]
        return payload


@dataclass
class MatchDetailsResponse(JSONifiable):
    """Response envelope carrying a list of MatchDetail objects."""

    match_details: t.List[MatchDetail]

    def to_json(self) -> t.Dict:
        """Serialize as {"match_details": [...]} via each detail's to_json()."""
        serialized = []
        for entry in self.match_details:
            serialized.append(entry.to_json())
        return {"match_details": serialized}


@dataclass
class ChangeSignalOpinionResponse(JSONifiable):
    """Outcome of a request to change the opinion on a signal."""

    success: bool

    def to_json(self) -> t.Dict:
        """Serialize the `success` flag under the key "change_requested"."""
        return dict(change_requested=self.success)


def get_match_details(
    datastore_table: Table, banks_table: BanksTable, content_id: str
) -> t.List[MatchDetail]:
    """
    Build a MatchDetail for every match record stored for `content_id`.

    Returns an empty list when `content_id` is falsy. For each record, both
    the ThreatExchange and the banked signal details are looked up; each of
    those helpers returns an empty list when the record's signal source is
    not the one it handles.
    """
    details: t.List[MatchDetail] = []
    if not content_id:
        return details

    for record in MatchRecord.get_from_content_id(datastore_table, f"{content_id}"):
        raw_distance = record.match_distance
        details.append(
            MatchDetail(
                content_id=record.content_id,
                content_hash=record.content_hash,
                signal_id=record.signal_id,
                signal_hash=record.signal_hash,
                signal_source=record.signal_source,
                signal_type=record.signal_type.get_name(),
                updated_at=record.updated_at.isoformat(),
                # A record may carry no distance; keep None rather than coercing.
                match_distance=None if raw_distance is None else int(raw_distance),
                te_signal_details=get_te_signal_details(
                    datastore_table=datastore_table,
                    signal_id=record.signal_id,
                    signal_source=record.signal_source,
                ),
                banked_signal_details=get_banked_signal_details(
                    banks_table=banks_table,
                    signal_id=record.signal_id,
                    signal_source=record.signal_source,
                ),
            )
        )

    return details


def get_te_signal_details(
    datastore_table: Table,
    signal_id: str,
    signal_source: str,
) -> t.List[ThreatExchangeSignalDetailsMetadata]:
    """
    Collect the ThreatExchange metadata stored for `signal_id`.

    Returns an empty list when either argument is falsy or when the signal
    comes from a bank (source equals BANKS_SOURCE_SHORT_CODE). Tags listed in
    ThreatDescriptor.SPECIAL_TAGS are left out of the returned tag list, but
    the opinion is derived from the full, unfiltered tag list.

    Note: te_signal_details should eventually be folded into
    banked_signal_details once ThreatExchange signals function the same as
    locally banked ones.
    """
    if not (signal_id and signal_source and signal_source != BANKS_SOURCE_SHORT_CODE):
        return []

    details = []
    for metadata in ThreatExchangeSignalMetadata.get_from_signal(
        datastore_table, signal_id
    ):
        visible_tags = [
            tag for tag in metadata.tags if tag not in ThreatDescriptor.SPECIAL_TAGS
        ]
        details.append(
            ThreatExchangeSignalDetailsMetadata(
                privacy_group_id=metadata.privacy_group_id,
                tags=visible_tags,
                opinion=get_opinion_from_tags(metadata.tags).value,
                pending_opinion_change=metadata.pending_opinion_change.value,
            )
        )
    return details


def get_banked_signal_details(
    banks_table: BanksTable,
    signal_id: str,
    signal_source: str,
) -> t.List[BankedSignalDetailsMetadata]:
    """
    Collect the bank / bank-member ids recorded for a banked signal.

    Returns an empty list when either argument is falsy or when the signal's
    source is anything other than BANKS_SOURCE_SHORT_CODE.
    """
    if not (signal_id and signal_source and signal_source == BANKS_SOURCE_SHORT_CODE):
        return []

    details = []
    for member_signal in banks_table.get_bank_member_signal_from_id(signal_id):
        details.append(
            BankedSignalDetailsMetadata(
                bank_member_id=member_signal.bank_member_id,
                bank_id=member_signal.bank_id,
            )
        )
    return details


class OpinionString(Enum):
    """
    Display strings for the opinion attached to a signal.

    get_opinion_from_tags() maps a signal's tags onto one of these members;
    UNKNOWN is what it returns when no opinion tag is present.
    """

    TRUE_POSITIVE = "True Positive"
    FALSE_POSITIVE = "False Positive"
    DISPUTED = "Disputed"
    UNKNOWN = "Unknown"


def get_opinion_from_tags(tags: t.List[str]) -> OpinionString:
    """
    Translate a signal's tags into an OpinionString.

    Tags are checked in a fixed precedence: true positive, then false
    positive, then disputed. The first one present wins; with none present
    the opinion is UNKNOWN.
    """
    # see python-threatexchange descriptor.py for origins
    precedence = (
        (ThreatDescriptor.TRUE_POSITIVE, OpinionString.TRUE_POSITIVE),
        (ThreatDescriptor.FALSE_POSITIVE, OpinionString.FALSE_POSITIVE),
        (ThreatDescriptor.DISPUTED, OpinionString.DISPUTED),
    )
    for marker_tag, opinion in precedence:
        if marker_tag in tags:
            return opinion
    return OpinionString.UNKNOWN


@dataclass
class MatchesForHashRequest(DictParseable):
    """Request for index matches of a single signal value of a given type."""

    signal_value: str
    signal_type: t.Type[SignalType]

    @classmethod
    def from_dict(cls, d):
        """
        Build the request from a dict-like `d`.

        Every dataclass field is read from `d` by name (missing keys become
        None). `signal_type` arrives as a name and is swapped for the class
        registered under that name in get_signal_types_by_name().
        """
        kwargs = {field.name: d.get(field.name, None) for field in dataclasses.fields(cls)}
        # The request carries the signal type by name; resolve it to the SignalType class.
        kwargs["signal_type"] = get_signal_types_by_name()[kwargs["signal_type"]]
        return cls(**kwargs)


@dataclass
class MatchesForMediaRequest(DictParseable):
    """Request for index matches of the media found at a URL."""

    content_url: str
    content_type: t.Type[ContentType]

    @classmethod
    def from_dict(cls, d):
        """
        Build the request from a dict-like `d`.

        Every dataclass field is read from `d` by name (missing keys become
        None). `content_type` arrives as a name and is swapped for the class
        registered under that name in get_content_types_by_name().
        """
        kwargs = {field.name: d.get(field.name, None) for field in dataclasses.fields(cls)}
        # The request carries the content type by name; resolve it to the ContentType class.
        kwargs["content_type"] = get_content_types_by_name()[kwargs["content_type"]]
        return cls(**kwargs)


@dataclass
class MatchesForHash(JSONifiable):
    """A single index match: the distance plus the signal that was matched."""

    match_distance: int

    # TODO: Once ThreatExchange data flows into Banks, we can Use BankMember
    # alone.
    matched_signal: t.Union[
        ThreatExchangeSignalMetadata, BankMember
    ]  # or matches signal from other sources

    # Keys stripped from the serialized matched signal (see _remove_unsupported_fields).
    UNSUPPORTED_FIELDS = ["updated_at", "pending_opinion_change"]

    def to_json(self) -> t.Dict:
        """Serialize the match, dropping unsupported keys from the matched signal."""
        signal_json = self.matched_signal.to_json()
        return {
            "match_distance": self.match_distance,
            "matched_signal": self._remove_unsupported_fields(signal_json),
        }

    @classmethod
    def _remove_unsupported_fields(cls, matched_signal: t.Dict) -> t.Dict:
        """
        Drop every key named in UNSUPPORTED_FIELDS from `matched_signal`.

        ThreatExchangeSignalMetadata is used to store metadata in dynamodb
        and handle opinion changes on said signal. However the request this object
        responds to only handles directly accessing the index. Because of this
        not all fields of the object are relevant or accurate.

        The dict is modified in place and also returned; keys that are not
        present are ignored.
        """
        for unsupported_key in cls.UNSUPPORTED_FIELDS:
            matched_signal.pop(unsupported_key, None)
        return matched_signal


@dataclass
class MatchesForHashResponse(JSONifiable):
    """Response envelope carrying every match found for one hash."""

    matches: t.List[MatchesForHash]

    def to_json(self) -> t.Dict:
        """Serialize as {"matches": [...]} via each match's to_json()."""
        serialized = []
        for entry in self.matches:
            serialized.append(entry.to_json())
        return {"matches": serialized}


@dataclass
class MatchesForMediaResponse(JSONifiable):
    """
    Response envelope for media matching: matches grouped first by signal
    type name and then by the hash value computed for that signal type.
    """

    signal_to_matches: t.Dict[str, t.Dict[str, t.List[MatchesForHash]]]
    # example: { "pdq" : { "<hash>" : [<matches?>] } }

    def to_json(self) -> t.Dict:
        """
        Serialize as {"signal_to_matches": {<signal type>: {<hash>: [<match>, ...]}}}.

        The "signal_to_matches" key is always present, even when there is
        nothing to report, and every signal type and every hash is kept.
        """
        # The wrapper key must sit OUTSIDE both comprehensions: if it is used
        # as a comprehension key, every iteration writes to the same key and
        # only the last signal type / hash survives.
        return {
            "signal_to_matches": {
                signal_type: {
                    signal_val: [match.to_json() for match in matches]
                    for (signal_val, matches) in hash_to_match.items()
                }
                for (signal_type, hash_to_match) in self.signal_to_matches.items()
            }
        }


@dataclass
class MediaFetchError(JSONifiable):
    """Error body returned when media could not be fetched from the given URL."""

    def to_json(self) -> t.Dict:
        """Serialize as a dict holding a fixed human-readable message."""
        return dict(message="Failed to fetch media from provided url")


def get_matches_api(
    datastore_table: Table,
    hma_config_table: str,
    indexes_bucket_name: str,
    writeback_queue_url: str,
    bank_table: Table,
) -> bottle.Bottle:
    """
    A Closure that includes all dependencies that MUST be provided by the root
    API that this API plugs into. Declare dependencies here, but initialize in
    the root API alone.

    Routes registered on the returned sub-app (relative to the prefix the
    root API mounts it under):

        GET  /                                -> matches
        GET  /match/                          -> match_details
        POST /request-signal-opinion-change/  -> request_signal_opinion_change
        GET  /for-hash/                       -> for_hash
        POST /for-media/                      -> for_media
    """

    # A prefix to all routes must be provided by the api_root app
    # The documentation below expects prefix to be '/matches/'
    matches_api = SubApp()
    HMAConfig.initialize(hma_config_table)

    banks_table = BanksTable(table=bank_table)

    @matches_api.get("/", apply=[jsoninator])
    def matches() -> MatchSummariesResponse:
        """
        Return all, or a filtered list of matches based on query params.

        `content_q` takes precedence over `signal_q` (optionally narrowed by
        `signal_source`); with neither, a page of recent matches is returned.
        """
        signal_q = bottle.request.query.signal_q or None  # type: ignore # ToDo refactor to use `jsoninator(<requestObj>, from_query=True)``
        signal_source = bottle.request.query.signal_source or None  # type: ignore # ToDo refactor to use `jsoninator(<requestObj>, from_query=True)``
        content_q = bottle.request.query.content_q or None  # type: ignore # ToDo refactor to use `jsoninator(<requestObj>, from_query=True)``

        if content_q:
            records = MatchRecord.get_from_content_id(datastore_table, content_q)
        elif signal_q:
            records = MatchRecord.get_from_signal(
                datastore_table, signal_q, signal_source or ""
            )
        else:
            # TODO: Support pagination after implementing in UI.
            records = MatchRecord.get_recent_items_page(datastore_table).items

        return MatchSummariesResponse(
            match_summaries=[
                MatchSummary(
                    content_id=record.content_id,
                    signal_id=record.signal_id,
                    signal_source=record.signal_source,
                    updated_at=record.updated_at.isoformat(),
                )
                for record in records
            ]
        )

    @matches_api.get("/match/", apply=[jsoninator])
    def match_details() -> MatchDetailsResponse:
        """
        Return the match details for a given content id.

        Without a `content_id` query param the response holds an empty list.
        """
        results: t.List[MatchDetail] = []
        if content_id := bottle.request.query.content_id or None:  # type: ignore # ToDo refactor to use `jsoninator(<requestObj>, from_query=True)``
            results = get_match_details(
                datastore_table=datastore_table,
                banks_table=banks_table,
                content_id=content_id,
            )
        return MatchDetailsResponse(match_details=results)

    @matches_api.post("/request-signal-opinion-change/", apply=[jsoninator])
    def request_signal_opinion_change() -> ChangeSignalOpinionResponse:
        """
        Request a change to the opinion for a signal in a given privacy_group.

        Responds with a failed ChangeSignalOpinionResponse when a required
        query param is missing, when `opinion_change` is not a recognised
        value, when the signal cannot be found, or when the table update
        reports failure.
        """
        signal_id = bottle.request.query.signal_id or None  # type: ignore # ToDo refactor to use `jsoninator(<requestObj>, from_query=True)``
        signal_source = bottle.request.query.signal_source or None  # type: ignore # ToDo refactor to use `jsoninator(<requestObj>, from_query=True)``
        privacy_group_id = bottle.request.query.privacy_group_id or None  # type: ignore # ToDo refactor to use `jsoninator(<requestObj>, from_query=True)``
        opinion_change = bottle.request.query.opinion_change or None  # type: ignore # ToDo refactor to use `jsoninator(<requestObj>, from_query=True)``

        if (
            not signal_id
            or not signal_source
            or not privacy_group_id
            or not opinion_change
        ):
            return ChangeSignalOpinionResponse(False)

        signal_id = str(signal_id)
        try:
            pending_opinion_change = PendingThreatExchangeOpinionChange(opinion_change)
        except ValueError:
            # Treat an unrecognised value like a missing parameter instead of
            # letting the lookup error escape the handler.
            logger.error(f"Unrecognized opinion_change requested: {opinion_change!r}")
            return ChangeSignalOpinionResponse(False)

        # NOTE(review): the writeback message is enqueued before the signal is
        # looked up, so it is sent even if the lookup below finds nothing.
        # That ordering is kept as-is; confirm it is intended.
        writeback_message = WritebackMessage.from_banked_signal_and_opinion_change(
            BankedSignal(signal_id, privacy_group_id, signal_source),
            pending_opinion_change,
        )
        writeback_message.send_to_queue(_get_sqs_client(), writeback_queue_url)
        logger.info(
            f"Opinion change enqueued for {signal_source}:{signal_id} in {privacy_group_id} change={opinion_change}"
        )

        signal = ThreatExchangeSignalMetadata.get_from_signal_and_privacy_group(
            datastore_table, signal_id=signal_id, privacy_group_id=privacy_group_id
        )

        if not signal:
            # Nothing to update; bail out rather than dereferencing a missing signal.
            logger.error("Signal not found.")
            return ChangeSignalOpinionResponse(False)

        signal.pending_opinion_change = pending_opinion_change
        success = signal.update_pending_opinion_change_in_table_if_exists(
            datastore_table
        )

        if not success:
            logger.error(f"Attempting to update {signal} in db failed")

        return ChangeSignalOpinionResponse(success)

    def _matches_for_hash(
        signal_type: t.Type[SignalType], signal_value: str
    ) -> t.List[MatchesForHash]:
        """
        Query the index for `signal_value` and wrap every hit as MatchesForHash.

        ThreatExchange-backed results come first, followed by bank-backed ones.
        """
        matches = _get_matcher(indexes_bucket_name, banks_table=banks_table).match(
            signal_type, signal_value
        )

        match_objects: t.List[MatchesForHash] = []

        # First get all threatexchange objects
        for match in matches:
            match_objects.extend(
                MatchesForHash(
                    match_distance=int(match.distance),
                    matched_signal=signal_metadata,
                )
                for signal_metadata in Matcher.get_te_metadata_objects_from_match(
                    signal_type, match
                )
            )

        # now get all bank objects
        for match in matches:
            for metadata_obj in match.metadata:
                if metadata_obj.get_source() != BANKS_SOURCE_SHORT_CODE:
                    continue
                banked_metadata = t.cast(BankedSignalIndexMetadata, metadata_obj)
                match_objects.append(
                    MatchesForHash(
                        match_distance=int(match.distance),
                        matched_signal=banks_table.get_bank_member(
                            banked_metadata.bank_member_id
                        ),
                    )
                )

        return match_objects

    @matches_api.get(
        "/for-hash/", apply=[jsoninator(MatchesForHashRequest, from_query=True)]
    )
    def for_hash(request: MatchesForHashRequest) -> MatchesForHashResponse:
        """
        For a given hash/signal check the index(es) for matches and return the details.

        This does not change system state, metadata returned will not be written any tables
        unlike when matches are found for submissions.
        """

        return MatchesForHashResponse(
            matches=_matches_for_hash(request.signal_type, request.signal_value)
        )

    @matches_api.post("/for-media/", apply=[jsoninator(MatchesForMediaRequest)])
    def for_media(
        request: MatchesForMediaRequest,
    ) -> t.Union[MatchesForMediaResponse, MediaFetchError]:
        """
        For a given piece of media hash it, check the index(es) for matches, and return the details.

        This does not change system state, metadata returned will not be written any tables
        unlike when matches are found for submissions.

        If the media cannot be fetched, the response status is set to 400 and
        a MediaFetchError is returned.
        """
        try:
            bytes_: bytes = URLContentSource().get_bytes(request.content_url)
        except Exception as e:
            # Any fetch failure is reported to the caller as a 400, but leave
            # a trace in the logs so the cause can be diagnosed.
            logger.error(f"Failed to fetch media from provided url: {e!r}")
            bottle.response.status = 400
            return MediaFetchError()

        signal_to_matches: t.Dict[str, t.Dict[str, t.List[MatchesForHash]]] = {}
        for signal in _get_hasher().get_hashes(request.content_type, bytes_):
            # Accumulate per signal type so that several hashes of the same
            # type do not overwrite one another.
            signal_to_matches.setdefault(signal.signal_type.get_name(), {})[
                signal.signal_value
            ] = _matches_for_hash(
                signal_type=signal.signal_type, signal_value=signal.signal_value
            )

        return MatchesForMediaResponse(signal_to_matches=signal_to_matches)

    return matches_api
