# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import datetime
from decimal import Decimal
import typing as t
from dataclasses import dataclass, field

from mypy_boto3_dynamodb.service_resource import Table
from boto3.dynamodb.conditions import Key

from threatexchange.content_type.meta import get_signal_types_by_name
from threatexchange.signal_type.signal_base import SignalType
from hmalib.common.timebucketizer import CSViable

from hmalib.common.models.models_base import (
    DynamoDBItem,
    DynamoDBCursorKey,
    PaginatedResponse,
)

"""
Data transfer object classes to be used with dynamodbstore
Classes in this module should implement methods `to_dynamodb_item(self)` and
`to_sqs_message(self)`
"""


@dataclass
class PipelineRecordBase(DynamoDBItem):
    """
    Base Class for records of pieces of content going through the
    hashing/matching pipeline.

    Subclasses must implement `to_dynamodb_item`, `to_sqs_message` and
    `get_recent_items_page`.
    """

    # Identifier of the content item this record refers to.
    content_id: str

    # The SignalType *class* (not an instance) this record's hash belongs to.
    signal_type: t.Type[SignalType]
    content_hash: str

    # Last write time; subclasses serialize it with isoformat().
    updated_at: datetime.datetime

    def to_dynamodb_item(self) -> dict:
        """Serialize this record into a DynamoDB item dict."""
        raise NotImplementedError

    def to_sqs_message(self) -> dict:
        """Serialize this record into an SQS message payload dict."""
        raise NotImplementedError

    @classmethod
    def get_recent_items_page(
        cls, table: Table, exclusive_start_key: t.Optional[DynamoDBCursorKey] = None
    ) -> PaginatedResponse:
        """
        Get a paginated list of recent items. The API is purposefully kept
        the same across subclasses. NOTE(review): parameter renamed from
        `ExclusiveStartKey` to `exclusive_start_key` to match the concrete
        implementations in PipelineHashRecord and MatchRecord.
        """
        raise NotImplementedError


@dataclass
class PipelineRecordDefaultsBase:
    """
    Hash and match records may have signal_type specific attributes that are not
    universal. eg. PDQ hashes have quality and PDQ matches have distance while
    MD5 has neither. Assuming such signal_type specific attributes will not be
    indexed, we put them into a single bag of variables. See
    [de]serialize_signal_specific_attributes() below for how they are stored.

    Ideally, this would be an attribute with defaults, but that would make
    inheritance complicated because default_values would precede non-default
    values in the sub class.
    """

    signal_specific_attributes: t.Dict[str, t.Union[int, float, str]] = field(
        default_factory=dict
    )

    def serialize_signal_specific_attributes(self) -> dict:
        """
        Converts signal_specific_attributes into a dict. Uses the signal_type as
        a prefix.

        So for PDQ hash records, `item.signal_specific_attributes.quality` will
        become `item.pdq_quality`. Storing as top-level item attributes allows
        indexing if we need it later. You can't do that with nested elements.
        """
        # Note on Typing: this class is meant to be mixed in with
        # PipelineRecordBase, so instances will always carry a signal_type
        # field. That relationship is (impossible?) to express in mypy, hence
        # the ignore on the signal_type access.
        prefix = self.signal_type.get_name()  # type:ignore
        serialized = {}
        for attr_name, attr_value in self.signal_specific_attributes.items():
            serialized[f"{prefix}_{attr_name}"] = attr_value
        return serialized

    @staticmethod
    def _signal_specific_attribute_remove_prefix(prefix: str, k: str) -> str:
        # Drop the leading `len(prefix)` characters. Callers only pass keys
        # that start with `prefix`.
        return k[len(prefix) :]

    @classmethod
    def deserialize_signal_specific_attributes(
        cls, d: t.Dict[str, t.Any]
    ) -> t.Dict[str, t.Union[int, float, str]]:
        """
        Reverses serialize_signal_specific_attributes.
        """
        prefix = f"{d['SignalType']}_"

        attributes: t.Dict[str, t.Union[int, float, str]] = {}
        for item_key, item_value in d.items():
            if item_key.startswith(prefix):
                short_key = cls._signal_specific_attribute_remove_prefix(
                    prefix, item_key
                )
                attributes[short_key] = item_value
        return attributes


@dataclass
class PipelineHashRecord(PipelineRecordDefaultsBase, PipelineRecordBase):
    """
    Successful execution at the hasher produces this record.
    """

    def to_dynamodb_item(self) -> dict:
        """
        Serialize into a DynamoDB item. signal_specific_attributes are stored
        as prefixed top-level attributes (see
        PipelineRecordDefaultsBase.serialize_signal_specific_attributes).
        """
        top_level_overrides = self.serialize_signal_specific_attributes()
        return dict(
            **top_level_overrides,
            **{
                "PK": self.get_dynamodb_content_key(self.content_id),
                "SK": self.get_dynamodb_type_key(self.signal_type.get_name()),
                "ContentHash": self.content_hash,
                "SignalType": self.signal_type.get_name(),
                # GSI-2 partitions by record class name so recent items can be
                # listed per record type.
                "GSI2-PK": self.get_dynamodb_type_key(self.__class__.__name__),
                "UpdatedAt": self.updated_at.isoformat(),
            },
        )

    def to_legacy_sqs_message(self) -> dict:
        """
        Prior to supporting MD5, the hash message was simplistic and did not
        support all fields in the PipelineHashRecord. This is inconsistent with
        almost all other message models.

        We can remove this once pdq_hasher and pdq_matcher are removed.
        """
        return {
            "hash": self.content_hash,
            "type": self.signal_type.get_name(),
            "key": self.content_id,
        }

    def to_sqs_message(self) -> dict:
        """Serialize into an SQS message payload. Reversed by from_sqs_message."""
        return {
            "ContentId": self.content_id,
            "SignalType": self.signal_type.get_name(),
            "ContentHash": self.content_hash,
            "SignalSpecificAttributes": self.signal_specific_attributes,
            "UpdatedAt": self.updated_at.isoformat(),
        }

    @classmethod
    def from_sqs_message(cls, d: dict) -> "PipelineHashRecord":
        """Reverses to_sqs_message."""
        return cls(
            content_id=d["ContentId"],
            signal_type=get_signal_types_by_name()[d["SignalType"]],
            content_hash=d["ContentHash"],
            signal_specific_attributes=d["SignalSpecificAttributes"],
            updated_at=datetime.datetime.fromisoformat(d["UpdatedAt"]),
        )

    @classmethod
    def could_be(cls, d: dict) -> bool:
        """
        Return True if this dict can be converted to a PipelineHashRecord
        """
        return "ContentId" in d and "SignalType" in d and "ContentHash" in d

    @classmethod
    def get_from_content_id(
        cls,
        table: Table,
        content_id: str,
        signal_type: t.Optional[t.Type[SignalType]] = None,
    ) -> t.List["PipelineHashRecord"]:
        """
        Returns all available PipelineHashRecords for a content_id. If
        signal_type is given, restricts to that signal_type's single record.
        """
        expected_pk = cls.get_dynamodb_content_key(content_id)

        if signal_type is None:
            condition_expression = Key("PK").eq(expected_pk) & Key("SK").begins_with(
                DynamoDBItem.TYPE_PREFIX
            )
        else:
            condition_expression = Key("PK").eq(expected_pk) & Key("SK").eq(
                DynamoDBItem.get_dynamodb_type_key(signal_type.get_name())
            )

        return cls._result_items_to_records(
            table.query(
                KeyConditionExpression=condition_expression,
            ).get("Items", [])
        )

    @classmethod
    def get_recent_items_page(
        cls, table: Table, exclusive_start_key: t.Optional[DynamoDBCursorKey] = None
    ) -> PaginatedResponse["PipelineHashRecord"]:
        """
        Get a paginated list of recent items. Subsequent calls must use
        `return_value.last_evaluated_key`.
        """
        # Evidently, https://github.com/boto/boto3/issues/2813 boto is able
        # to distinguish fun(Parameter=None) from fun(). So, we can't pass
        # ExclusiveStartKey=None; only include the key when we have one.
        query_kwargs: t.Dict[str, t.Any] = {
            "IndexName": "GSI-2",
            "ScanIndexForward": False,
            "Limit": 100,
            "KeyConditionExpression": Key("GSI2-PK").eq(
                DynamoDBItem.get_dynamodb_type_key(cls.__name__)
            ),
        }
        if exclusive_start_key:
            query_kwargs["ExclusiveStartKey"] = exclusive_start_key

        result = table.query(**query_kwargs)

        return PaginatedResponse(
            t.cast(DynamoDBCursorKey, result.get("LastEvaluatedKey", None)),
            cls._result_items_to_records(result["Items"]),
        )

    @classmethod
    def _result_items_to_records(
        cls,
        items: t.List[t.Dict],
    ) -> t.List["PipelineHashRecord"]:
        """
        Convert raw DynamoDB items (eg. a query response's "Items") into
        PipelineHashRecords.
        """
        return [
            cls(
                content_id=cls.remove_content_key_prefix(item["PK"]),
                signal_type=get_signal_types_by_name()[item["SignalType"]],
                content_hash=item["ContentHash"],
                updated_at=datetime.datetime.fromisoformat(item["UpdatedAt"]),
                signal_specific_attributes=cls.deserialize_signal_specific_attributes(
                    item
                ),
            )
            for item in items
        ]


@dataclass
class _MatchRecord(PipelineRecordBase):
    """
    Successful execution at the matcher produces this record.

    Intermediate base: holds MatchRecord's non-default fields so MatchRecord
    can also inherit the defaulted fields from PipelineRecordDefaultsBase
    with a legal MRO (see MatchRecord's docstring).
    """

    # Identifier of the matched signal within its source.
    signal_id: str
    # Source the signal came from; combined with signal_id to form the
    # signal's DynamoDB key.
    signal_source: str
    signal_hash: str
    # Distance between content_hash and signal_hash. None when the signal
    # type has no distance concept (eg. MD5).
    match_distance: t.Optional[int] = None


@dataclass
class MatchRecord(PipelineRecordDefaultsBase, _MatchRecord):
    """
    Successful execution at the matcher produces this record.

    Weird, innit? You can't introduce non-default fields after default fields.
    All default fields in PipelineRecordBase are actually in
    PipelineRecordDefaultsBase and this complex inheritance chain allows you to
    create an MRO that is legal.

    H/T:
    https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
    """

    def to_dynamodb_item(self) -> dict:
        """
        Serialize into a DynamoDB item. signal_specific_attributes are stored
        as prefixed top-level attributes (see
        PipelineRecordDefaultsBase.serialize_signal_specific_attributes).
        """
        top_level_overrides = self.serialize_signal_specific_attributes()
        return dict(
            **top_level_overrides,
            **{
                "PK": self.get_dynamodb_content_key(self.content_id),
                "SK": self.get_dynamodb_signal_key(self.signal_source, self.signal_id),
                "ContentHash": self.content_hash,
                "UpdatedAt": self.updated_at.isoformat(),
                "SignalHash": self.signal_hash,
                "SignalSource": self.signal_source,
                "SignalType": self.signal_type.get_name(),
                # GSI-1 inverts the key-pair so matches can be looked up by
                # signal as well as by content.
                "GSI1-PK": self.get_dynamodb_signal_key(
                    self.signal_source, self.signal_id
                ),
                "GSI1-SK": self.get_dynamodb_content_key(self.content_id),
                "HashType": self.signal_type.get_name(),
                # GSI-2 partitions by record class name so recent items can be
                # listed per record type.
                "GSI2-PK": self.get_dynamodb_type_key(self.__class__.__name__),
                "MatchDistance": self.match_distance,
            },
        )

    def to_sqs_message(self) -> dict:
        # TODO add method for when matches are added to a sqs
        raise NotImplementedError

    @classmethod
    def get_from_content_id(
        cls, table: Table, content_id: str
    ) -> t.List["MatchRecord"]:
        """
        Return all matches for a content_id.
        """

        content_key = DynamoDBItem.get_dynamodb_content_key(content_id)
        source_prefix = DynamoDBItem.SIGNAL_KEY_PREFIX

        return cls._result_items_to_records(
            table.query(
                KeyConditionExpression=Key("PK").eq(content_key)
                & Key("SK").begins_with(source_prefix),
            ).get("Items", [])
        )

    @classmethod
    def get_from_signal(
        cls, table: Table, signal_id: t.Union[str, int], signal_source: str
    ) -> t.List["MatchRecord"]:
        """
        Return all matches for a signal. Needs source and id to uniquely
        identify a signal.
        """

        signal_key = DynamoDBItem.get_dynamodb_signal_key(signal_source, signal_id)

        return cls._result_items_to_records(
            table.query(
                IndexName="GSI-1",
                KeyConditionExpression=Key("GSI1-PK").eq(signal_key),
            ).get("Items", [])
        )

    @classmethod
    def get_recent_items_page(
        cls, table: Table, exclusive_start_key: t.Optional[DynamoDBCursorKey] = None
    ) -> PaginatedResponse["MatchRecord"]:
        """
        Get a paginated list of recent match records. Subsequent calls must use
        `return_value.last_evaluated_key`.
        """
        # Evidently, https://github.com/boto/boto3/issues/2813 boto is able
        # to distinguish fun(Parameter=None) from fun(). So, we can't pass
        # ExclusiveStartKey=None; only include the key when we have one.
        query_kwargs: t.Dict[str, t.Any] = {
            "IndexName": "GSI-2",
            "ScanIndexForward": False,
            "Limit": 100,
            "KeyConditionExpression": Key("GSI2-PK").eq(
                DynamoDBItem.get_dynamodb_type_key(cls.__name__)
            ),
        }
        if exclusive_start_key:
            query_kwargs["ExclusiveStartKey"] = exclusive_start_key

        result = table.query(**query_kwargs)

        return PaginatedResponse(
            t.cast(DynamoDBCursorKey, result.get("LastEvaluatedKey", None)),
            cls._result_items_to_records(result["Items"]),
        )

    @classmethod
    def _result_items_to_records(
        cls,
        items: t.List[t.Dict],
    ) -> t.List["MatchRecord"]:
        """
        Convert raw DynamoDB items (eg. a query response's "Items") into
        MatchRecords.
        """
        return [
            cls(
                content_id=cls.remove_content_key_prefix(item["PK"]),
                content_hash=item["ContentHash"],
                updated_at=datetime.datetime.fromisoformat(item["UpdatedAt"]),
                signal_type=get_signal_types_by_name()[item["SignalType"]],
                signal_id=cls.remove_signal_key_prefix(
                    item["SK"], item["SignalSource"]
                ),
                signal_source=item["SignalSource"],
                signal_hash=item["SignalHash"],
                signal_specific_attributes=cls.deserialize_signal_specific_attributes(
                    item
                ),
                match_distance=item.get("MatchDistance"),
            )
            for item in items
        ]


@dataclass(eq=True)
class HashRecord(CSViable):
    """
    We are getting these records, with content_hashes and content_ids from the
    hashing process with intent to build a PDQIndex.
    """

    content_hash: str
    content_id: str

    def to_csv(self) -> t.List[t.Union[str, int]]:
        """Return a CSV row: [content_hash, content_id]."""
        return [self.content_hash, self.content_id]

    @classmethod
    def from_csv(cls: t.Type["HashRecord"], value: t.List[str]) -> "HashRecord":
        """
        Reverses to_csv. `value` must contain the content_hash at index 0 and
        the content_id at index 1.
        """
        # Use cls (not the hard-coded class name) so subclasses round-trip
        # into instances of themselves.
        return cls(value[0], value[1])
