python-threatexchange/threatexchange/cli/cli_state.py (150 lines of code) (raw):

#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
A wrapper around loading and storing ThreatExchange data from files.

There are a few categories of state that this wraps:
  1. Checkpoints - state about previous fetches
  2. Collaboration Indicator Dumps - Raw output from threat_updates
  3. Index state - serializations of indexes for SignalType
"""

from enum import Enum
import json
import pathlib
import typing as t
import dataclasses
import logging

import dacite

from threatexchange.signal_type.index import SignalTypeIndex
from threatexchange.signal_type.signal_base import SignalType
from threatexchange.cli.exceptions import CommandError
from threatexchange.fetcher.collab_config import CollaborationConfigBase
from threatexchange.fetcher.fetch_state import (
    FetchCheckpointBase,
    FetchedSignalMetadata,
)
from threatexchange.fetcher.simple import state as simple_state
from threatexchange.fetcher.fetch_api import SignalExchangeAPI
from threatexchange.signal_type import signal_base
from threatexchange.signal_type import index


class CliIndexStore:
    """
    Persistence layer for SignalTypeIndex objects for the cli.

    They are just stored to a file directory, with names based on their type:
    each index lives at `<dir>/<SignalType.get_name()><FILE_EXTENSION>`.
    """

    FILE_EXTENSION = ".index"

    def __init__(self, indice_dir: pathlib.Path) -> None:
        self.dir = indice_dir

    @classmethod
    def _name_from_file(cls, file: pathlib.Path) -> str:
        """The SignalType name encoded in an index file's name"""
        # Only look at the final path component: str(file) would include the
        # directory, which never equals a bare SignalType name.
        return file.name[: -len(cls.FILE_EXTENSION)]

    def get_available(self) -> t.List[str]:
        """Return the names (SignalType.get_name()) of stored indices"""
        return [
            self._name_from_file(f)
            for f in self.dir.glob(f"*{self.FILE_EXTENSION}")
        ]

    def clear(
        self, only_types: t.Optional[t.Iterable[t.Type[SignalType]]] = None
    ) -> None:
        """
        Clear persisted indices.

        @param only_types: if given, only delete the indices stored for these
          SignalTypes; if None, delete every stored index.
        """
        only_names = None
        if only_types is not None:
            only_names = {st.get_name() for st in only_types}
        for file in self.dir.glob(f"*{self.FILE_EXTENSION}"):
            if only_names is None or self._name_from_file(file) in only_names:
                logging.info("Removing index %s", file)
                file.unlink()

    def _index_file(self, signal_type: t.Type[signal_base.SignalType]) -> pathlib.Path:
        """The expected path for the index for a signal type"""
        return self.dir / f"{signal_type.get_name()}{self.FILE_EXTENSION}"

    def store_index(
        self, signal_type: t.Type[signal_base.SignalType], index: SignalTypeIndex
    ) -> None:
        """
        Persist a SignalTypeIndex to disk, overwriting any previous one.

        The index must be exactly the class signal_type.get_index_cls().
        """
        assert signal_type.get_index_cls() == index.__class__
        path = self._index_file(signal_type)
        with path.open("wb") as fout:
            index.serialize(fout)

    def load_index(
        self, signal_type: t.Type[signal_base.SignalType]
    ) -> t.Optional[index.SignalTypeIndex]:
        """
        Load the SignalTypeIndex for this type from disk.

        Returns None if no index file has been stored for this type.
        """
        path = self._index_file(signal_type)
        if not path.exists():
            return None
        with path.open("rb") as fin:
            return signal_type.get_index_cls().deserialize(fin)


class CliSimpleState(simple_state.SimpleFetchedStateStore):
    """
    A simple on-disk storage format for the CLI.

    Ideally, it should be easy to read manually (for debugging),
    but compact enough to handle very large sets of data.

    Each collaboration is one JSON file holding the serialized checkpoint
    under JSON_CHECKPOINT_KEY and the serialized records (stype -> signal ->
    record) under JSON_RECORDS_KEY.
    """

    JSON_CHECKPOINT_KEY = "checkpoint"
    JSON_RECORDS_KEY = "records"

    def __init__(
        self, api_cls: t.Type[SignalExchangeAPI], fetched_state_dir: pathlib.Path
    ) -> None:
        super().__init__(api_cls)
        self.dir = fetched_state_dir

    def collab_file(self, collab_name: str) -> pathlib.Path:
        """The file location for collaboration state"""
        return self.dir / f"{collab_name}.state.json"

    @staticmethod
    def _tmp_file(file: pathlib.Path) -> pathlib.Path:
        """Scratch path that _write_state writes to before replacing `file`"""
        return file.with_name(f"{file.name}.tmp")

    def clear(self, collab: CollaborationConfigBase) -> None:
        """
        Delete a collaboration's state file, and its directory if now empty.
        """
        file = self.collab_file(collab.name)
        if file.is_file():
            logging.info("Removing %s", file)
            file.unlink(missing_ok=True)
        # A write interrupted mid-way may have left a scratch file behind,
        # which would otherwise keep the directory from being removed.
        self._tmp_file(file).unlink(missing_ok=True)
        if file.parent.is_dir():
            if next(file.parent.iterdir(), None) is None:
                logging.info("Removing directory %s", file.parent)
                file.parent.rmdir()

    def _read_state(
        self,
        collab_name: str,
    ) -> t.Optional[
        t.Tuple[
            t.Dict[str, t.Dict[str, FetchedSignalMetadata]],
            FetchCheckpointBase,
        ]
    ]:
        """
        Load the stored (records, checkpoint) for a collaboration.

        Returns None if no state file exists for it.

        @raises CommandError: if the state file exists but can't be parsed
          into the api_cls checkpoint/record classes.
        """
        file = self.collab_file(collab_name)
        if not file.is_file():
            return None
        try:
            with file.open("r") as f:
                json_dict = json.load(f)

            checkpoint = dacite.from_dict(
                data_class=self.api_cls.get_checkpoint_cls(),
                data=json_dict[self.JSON_CHECKPOINT_KEY],
                config=dacite.Config(cast=[Enum]),
            )
            records = json_dict[self.JSON_RECORDS_KEY]

            # Minor stab at lowering memory footprint by converting kinda
            # inline
            for stype in list(records):
                records[stype] = {
                    signal: dacite.from_dict(
                        data_class=self.api_cls.get_record_cls(),
                        data=json_record,
                        config=dacite.Config(cast=[Enum]),
                    )
                    for signal, json_record in records[stype].items()
                }

            return records, checkpoint
        except Exception as e:
            logging.exception("Failed to read state for %s", collab_name)
            raise CommandError(
                f"Failed to read state for {collab_name}. "
                "You might have to delete it with `threatexchange fetch --clear`"
            ) from e

    def _write_state(  # type: ignore[override]  # fix with generics on base
        self,
        collab_name: str,
        updates_by_type: t.Dict[str, t.Dict[str, FetchedSignalMetadata]],
        checkpoint: FetchCheckpointBase,
    ) -> None:
        """
        Persist the records and checkpoint for a collaboration as JSON.

        The JSON is written to a sibling scratch file first and then moved
        over the real file, so a failed or interrupted dump can't leave a
        truncated state file that _read_state would reject.
        """
        file = self.collab_file(collab_name)
        file.parent.mkdir(parents=True, exist_ok=True)

        record_sanity_check = next(
            (
                record
                for records in updates_by_type.values()
                for record in records.values()
            ),
            None,
        )

        if record_sanity_check is not None:
            assert (  # Not isinstance - we want exactly this class
                record_sanity_check.__class__ == self.api_cls.get_record_cls()
            ), (
                f"Record cls: want {self.api_cls.get_record_cls().__name__} "
                f"got {record_sanity_check.__class__.__name__}"
            )

        json_dict = {
            self.JSON_CHECKPOINT_KEY: dataclasses.asdict(checkpoint),
            self.JSON_RECORDS_KEY: {
                stype: {
                    s: dataclasses.asdict(record)
                    for s, record in signal_to_record.items()
                }
                for stype, signal_to_record in updates_by_type.items()
            },
        }

        tmp_file = self._tmp_file(file)
        try:
            with tmp_file.open("w") as f:
                json.dump(json_dict, f, indent=2)
            # Path.replace has os.replace semantics: overwrites the target,
            # atomically on POSIX (same directory => same filesystem).
            tmp_file.replace(file)
        finally:
            # No-op after a successful replace; cleans up after a failed dump.
            tmp_file.unlink(missing_ok=True)