python-threatexchange/threatexchange/fetcher/simple/state.py (128 lines of code) (raw):

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved from collections import defaultdict from dataclasses import dataclass, field import logging import typing as t from threatexchange.fetcher.fetch_api import SignalExchangeAPI from threatexchange.signal_type.signal_base import SignalType from threatexchange.fetcher import fetch_state from threatexchange.fetcher.collab_config import CollaborationConfigBase @dataclass class SimpleFetchedSignalMetadata(fetch_state.FetchedSignalMetadata): """ Simple dataclass for fetched data. Merge by addition rather than replacement. """ opinions: t.List[fetch_state.SignalOpinion] = field(default_factory=list) def get_as_opinions(self) -> t.List[fetch_state.SignalOpinion]: return self.opinions @classmethod def merge( cls, older: "SimpleFetchedSignalMetadata", newer: "SimpleFetchedSignalMetadata" ) -> "SimpleFetchedSignalMetadata": if not older.opinions: return newer by_owner = {o.owner: o for o in newer.opinions} return cls([by_owner.get(o.owner, o) for o in older.opinions]) @classmethod def get_trivial(cls): return cls([fetch_state.SignalOpinion.get_trivial()]) @dataclass class SimpleFetchDelta(fetch_state.FetchDeltaWithUpdateStream): """ Simple class for deltas. If the record is set to None, this indicates the record should be deleted if it exists. """ updates: t.Mapping[t.Tuple[str, str], t.Optional[fetch_state.FetchedSignalMetadata]] checkpoint: fetch_state.FetchCheckpointBase done: bool # powers has_more def record_count(self) -> int: return len(self.updates) def next_checkpoint(self) -> fetch_state.FetchCheckpointBase: return self.checkpoint def has_more(self) -> bool: return not self.done def get_as_update_dict( self, ) -> t.Mapping[t.Tuple[str, str], t.Optional[fetch_state.FetchedSignalMetadata]]: return self.updates @dataclass class _StateTracker: updates_by_type: t.Dict[str, t.Dict[str, fetch_state.FetchedSignalMetadata]] checkpoint: t.Optional[fetch_state.FetchCheckpointBase] dirty: bool = False def merge(self, newer: fetch_state.FetchDeltaWithUpdateStream) -> None: updates = newer.get_as_update_dict() if not updates: return newer_by_type: t.DefaultDict[ str, t.List[t.Tuple[str, t.Optional[fetch_state.FetchedSignalMetadata]]] ] = defaultdict(list) for (stype, signal_str), record in updates.items(): newer_by_type[stype].append((signal_str, record)) for n_type, n_updates in newer_by_type.items(): o_updates = self.updates_by_type.setdefault(n_type, {}) for sig_str, new_record in n_updates: if new_record is None: o_updates.pop(sig_str, None) else: old_record = o_updates.get(sig_str) if old_record: new_record = new_record.merge_metadata(old_record, new_record) o_updates[sig_str] = new_record self.checkpoint = newer.next_checkpoint() self.dirty = True class SimpleFetchedStateStore(fetch_state.FetchedStateStoreBase): """ Standardizes on merging on (type, indicator), merges in memory. """ def __init__( self, api_cls: t.Type[SignalExchangeAPI], ) -> None: self.api_cls = api_cls self._state: t.Dict[str, _StateTracker] = {} def _read_state( self, collab_name: str, ) -> t.Optional[ t.Tuple[ t.Dict[str, t.Dict[str, fetch_state.FetchedSignalMetadata]], t.Optional[fetch_state.FetchCheckpointBase], ] ]: raise NotImplementedError def _write_state( self, collab_name: str, updates_by_type: t.Dict[str, t.Dict[str, fetch_state.FetchedSignalMetadata]], checkpoint: fetch_state.FetchCheckpointBase, ) -> None: raise NotImplementedError def get_checkpoint( self, collab: CollaborationConfigBase ) -> t.Optional[fetch_state.FetchCheckpointBase]: return self._get_state(collab.name).checkpoint def _get_state(self, collab_name: str) -> _StateTracker: if collab_name not in self._state: read_state = self._read_state(collab_name) or ({}, None) ret = _StateTracker(*read_state) self._state[collab_name] = ret return ret return self._state[collab_name] def merge( # type: ignore[override] # fix with generics on base self, collab: CollaborationConfigBase, delta: fetch_state.FetchDeltaWithUpdateStream, ) -> None: """ Merge a FetchDeltaBase into the state. At the implementation's discretion, it may call flush() or the equivalent work. """ state = self._get_state(collab.name) if delta.record_count() == 0 and delta.next_checkpoint() in ( None, state.checkpoint, ): logging.warning("No op update for %s", collab.name) return state.merge(delta) def flush(self): for collab_name, state in self._state.items(): if state.dirty: assert state.checkpoint self._write_state(collab_name, state.updates_by_type, state.checkpoint) state.dirty = False def get_for_signal_type( self, collabs: t.List[CollaborationConfigBase], signal_type: t.Type[SignalType] ) -> t.Dict[str, t.Dict[str, fetch_state.FetchedSignalMetadata]]: st_name = signal_type.get_name() ret = {} for collab in collabs: state = self._get_state(collab.name) ret[collab.name] = state.updates_by_type.get(st_name, {}) return ret