# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


from collections import defaultdict
from dataclasses import dataclass, field
import logging
import typing as t
from threatexchange.fetcher.fetch_api import SignalExchangeAPI

from threatexchange.signal_type.signal_base import SignalType
from threatexchange.fetcher import fetch_state
from threatexchange.fetcher.collab_config import CollaborationConfigBase


@dataclass
class SimpleFetchedSignalMetadata(fetch_state.FetchedSignalMetadata):
    """
    Simple dataclass for fetched data: a flat list of opinions.

    Merge by addition rather than replacement: when two records for the same
    signal are merged, an opinion from the newer record replaces the older
    opinion with the same owner, and opinions from owners that the older
    record did not have are added.
    """

    opinions: t.List[fetch_state.SignalOpinion] = field(default_factory=list)

    def get_as_opinions(self) -> t.List[fetch_state.SignalOpinion]:
        """Return the stored opinions (the list itself, not a copy)."""
        return self.opinions

    # NOTE(review): _StateTracker in this module calls `merge_metadata`, not
    # `merge`, when combining records. Confirm the base class routes
    # `merge_metadata` to this method; otherwise it is never used there.
    @classmethod
    def merge(
        cls, older: "SimpleFetchedSignalMetadata", newer: "SimpleFetchedSignalMetadata"
    ) -> "SimpleFetchedSignalMetadata":
        """
        Combine two records for the same signal, keyed by opinion owner.

        Older opinions keep their position; any whose owner also appears in
        `newer` is replaced by the newer opinion. Opinions in `newer` whose
        owner does not appear in `older` are appended, in `newer`'s order.

        If `older` has no opinions, `newer` is returned unchanged.
        """
        if not older.opinions:
            return newer

        newer_by_owner = {o.owner: o for o in newer.opinions}
        merged = [newer_by_owner.get(o.owner, o) for o in older.opinions]
        # Without this step, opinions from owners new to this signal would be
        # silently discarded, which is replacement rather than addition.
        older_owners = {o.owner for o in older.opinions}
        merged.extend(o for o in newer.opinions if o.owner not in older_owners)
        return cls(merged)

    @classmethod
    def get_trivial(cls) -> "SimpleFetchedSignalMetadata":
        """Build an instance holding a single trivial opinion."""
        return cls([fetch_state.SignalOpinion.get_trivial()])


@dataclass
class SimpleFetchDelta(fetch_state.FetchDeltaWithUpdateStream):
    """
    A fetch delta that carries its updates as a ready-made mapping.

    Keys are (signal type name, signal string) pairs. A value of None means
    the record for that key should be removed if it is present.
    """

    updates: t.Mapping[t.Tuple[str, str], t.Optional[fetch_state.FetchedSignalMetadata]]
    checkpoint: fetch_state.FetchCheckpointBase
    done: bool  # True once fetching is complete; has_more() is its negation

    def record_count(self) -> int:
        """Number of entries (updates and deletions) carried by this delta."""
        entry_total = len(self.updates)
        return entry_total

    def next_checkpoint(self) -> fetch_state.FetchCheckpointBase:
        """Checkpoint to record once this delta has been merged."""
        return self.checkpoint

    def has_more(self) -> bool:
        """Whether further deltas remain to be fetched after this one."""
        return False if self.done else True

    def get_as_update_dict(
        self,
    ) -> t.Mapping[t.Tuple[str, str], t.Optional[fetch_state.FetchedSignalMetadata]]:
        """Expose the update mapping as-is (no copy is made)."""
        return self.updates


@dataclass
class _StateTracker:
    """
    In-memory fetched state for a single collaboration.

    Holds every known record grouped by signal type name and then by signal
    string, plus the checkpoint of the most recently merged delta. `dirty`
    is set whenever this state changes, so the owner knows it must be
    written out.
    """

    # signal type name -> signal string -> metadata
    updates_by_type: t.Dict[str, t.Dict[str, fetch_state.FetchedSignalMetadata]]
    # checkpoint of the last merged delta; None when no checkpoint is known yet
    checkpoint: t.Optional[fetch_state.FetchCheckpointBase]
    # True when there are changes that have not been persisted
    dirty: bool = False

    def merge(self, newer: fetch_state.FetchDeltaWithUpdateStream) -> None:
        """
        Apply a fetch delta on top of the current state.

        A None record deletes that signal. Any other record is combined with
        an existing record for the same signal via `merge_metadata`, or is
        stored as-is if there was none.

        The checkpoint always advances to the delta's checkpoint, even when
        the delta carries no records. Otherwise an empty page would leave the
        old checkpoint in place and the same range would be fetched again. A
        delta with no records and no new checkpoint is ignored.
        """
        updates = newer.get_as_update_dict()
        next_checkpoint = newer.next_checkpoint()
        if not updates and next_checkpoint in (None, self.checkpoint):
            # No data and no progress: nothing to record.
            return

        newer_by_type: t.DefaultDict[
            str, t.List[t.Tuple[str, t.Optional[fetch_state.FetchedSignalMetadata]]]
        ] = defaultdict(list)
        for (stype, signal_str), record in updates.items():
            newer_by_type[stype].append((signal_str, record))

        for n_type, n_updates in newer_by_type.items():
            o_updates = self.updates_by_type.setdefault(n_type, {})
            for sig_str, new_record in n_updates:
                if new_record is None:
                    o_updates.pop(sig_str, None)
                    continue
                old_record = o_updates.get(sig_str)
                # Compare against None explicitly: a metadata object's
                # truthiness is not a reliable "exists" check.
                if old_record is not None:
                    new_record = new_record.merge_metadata(old_record, new_record)
                o_updates[sig_str] = new_record
        self.checkpoint = next_checkpoint
        self.dirty = True


class SimpleFetchedStateStore(fetch_state.FetchedStateStoreBase):
    """
    Fetched-state store that keys records on (signal type, signal string)
    and merges deltas in memory.

    Subclasses provide persistence by implementing `_read_state` and
    `_write_state`. A collaboration's state is loaded on first access and is
    only written back by `flush()`.
    """

    def __init__(
        self,
        api_cls: t.Type[SignalExchangeAPI],
    ) -> None:
        self.api_cls = api_cls
        # collab name -> in-memory state, filled in lazily by _get_state
        self._state: t.Dict[str, _StateTracker] = {}

    def _read_state(
        self,
        collab_name: str,
    ) -> t.Optional[
        t.Tuple[
            t.Dict[str, t.Dict[str, fetch_state.FetchedSignalMetadata]],
            t.Optional[fetch_state.FetchCheckpointBase],
        ]
    ]:
        """
        Load stored state for `collab_name` as (updates_by_type, checkpoint).

        Return None when nothing is stored. Must be implemented by subclasses.
        """
        raise NotImplementedError

    def _write_state(
        self,
        collab_name: str,
        updates_by_type: t.Dict[str, t.Dict[str, fetch_state.FetchedSignalMetadata]],
        checkpoint: fetch_state.FetchCheckpointBase,
    ) -> None:
        """Persist state for `collab_name`. Must be implemented by subclasses."""
        raise NotImplementedError

    def get_checkpoint(
        self, collab: CollaborationConfigBase
    ) -> t.Optional[fetch_state.FetchCheckpointBase]:
        """Return the latest known checkpoint for `collab`, loading it if needed."""
        tracker = self._get_state(collab.name)
        return tracker.checkpoint

    def _get_state(self, collab_name: str) -> _StateTracker:
        """Return the cached tracker for `collab_name`, reading it on first use."""
        if collab_name in self._state:
            return self._state[collab_name]
        # Nothing stored yet: start from an empty record set and no checkpoint.
        stored = self._read_state(collab_name) or ({}, None)
        tracker = self._state[collab_name] = _StateTracker(*stored)
        return tracker

    def merge(  # type: ignore[override]  # fix with generics on base
        self,
        collab: CollaborationConfigBase,
        delta: fetch_state.FetchDeltaWithUpdateStream,
    ) -> None:
        """
        Fold a fetched delta into the in-memory state for `collab`.

        A delta with no records whose checkpoint is missing or unchanged is
        logged and skipped. This implementation never writes to storage
        here; call flush() to persist.
        """
        tracker = self._get_state(collab.name)

        is_noop = delta.record_count() == 0 and delta.next_checkpoint() in (
            None,
            tracker.checkpoint,
        )
        if is_noop:
            logging.warning("No op update for %s", collab.name)
            return

        tracker.merge(delta)

    def flush(self):
        """Write out every collaboration whose state changed since the last flush."""
        for name, tracker in self._state.items():
            if not tracker.dirty:
                continue
            assert tracker.checkpoint
            self._write_state(name, tracker.updates_by_type, tracker.checkpoint)
            tracker.dirty = False

    def get_for_signal_type(
        self, collabs: t.List[CollaborationConfigBase], signal_type: t.Type[SignalType]
    ) -> t.Dict[str, t.Dict[str, fetch_state.FetchedSignalMetadata]]:
        """
        Map each collab name to its records of `signal_type`, keyed by signal.

        The inner dicts may be the store's live state rather than copies.
        """
        type_name = signal_type.get_name()
        return {
            collab.name: self._get_state(collab.name).updates_by_type.get(type_name, {})
            for collab in collabs
        }
