# python-threatexchange/threatexchange/cli/fetch_cmd.py

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import collections
import datetime
import logging
import time
import typing as t

from threatexchange.cli.cli_config import CLISettings
from threatexchange.cli.dataset_cmd import DatasetCommand
from threatexchange.fetcher.collab_config import CollaborationConfigBase
from threatexchange.fetcher.fetch_api import SignalExchangeAPI
from threatexchange.fetcher.fetch_state import (
    FetchCheckpointBase,
    FetchedStateStoreBase,
)
from threatexchange.cli import command_base


class FetchCommand(command_base.Command):
    """
    Download content from signal exchange APIs to disk.
    """

    # Minimum seconds between progress lines printed to stderr.
    PROGRESS_PRINT_INTERVAL_SEC = 30

    @classmethod
    def init_argparse(cls, settings: CLISettings, ap) -> None:
        """Register this command's CLI flags on the given argparse parser."""
        ap.add_argument(
            "--clear",
            action="store_true",
            help="delete fetched state and checkpoints "
            "(you almost never need to do this)",
        )
        ap.add_argument(
            "--skip-index-rebuild",
            action="store_true",
            help="don't rebuild indices after fetch",
        )
        ap.add_argument("--limit", type=int, help="stop after fetching this many items")
        ap.add_argument(
            "--time-limit-sec",
            type=int,
            metavar="SEC",
            help="stop fetching after this many seconds",
        )
        ap.add_argument(
            "--only-api",
            choices=[f.get_name() for f in settings.get_fetchers()],
            help="only fetch from this API",
        )
        ap.add_argument(
            "--only-collab",
            metavar="NAME",
            help="only fetch for this collaboration",
        )

    def __init__(
        self,
        # Defaults to make it easier to call from match
        clear: bool = False,
        time_limit_sec: t.Optional[int] = None,
        limit: t.Optional[int] = None,
        skip_index_rebuild: bool = False,
        only_api: t.Optional[str] = None,
        only_collab: t.Optional[str] = None,
    ) -> None:
        self.clear = clear
        self.time_limit_sec = time_limit_sec
        self.limit = limit
        self.skip_index_rebuild = skip_index_rebuild
        self.only_api = only_api
        self.only_collab = only_collab
        self.collabs: t.List[CollaborationConfigBase] = []

        # Limits
        self.total_fetched_count = 0
        self.start_time = time.time()

        # Progress
        self.last_update_time: t.Optional[int] = None
        # Print first update after 5 seconds
        self.last_update_printed = time.time() - self.PROGRESS_PRINT_INTERVAL_SEC + 5
        self.progress_fetched_count = 0
        self.counts: t.Dict[str, int] = collections.Counter()
        # Set per-collab in execute_for_collab; initialized here so
        # _stderr_current's assert fires cleanly instead of AttributeError.
        self.current_collab: t.Optional[str] = None
        self.current_api: t.Optional[str] = None

    def has_hit_limits(self) -> bool:
        """True if either the --limit or --time-limit-sec budget is exhausted."""
        if self.limit is not None and self.total_fetched_count >= self.limit:
            return True
        if self.time_limit_sec is not None:
            if time.time() - self.start_time >= self.time_limit_sec:
                return True
        return False

    def execute(self, settings: CLISettings) -> None:
        """
        Entry point: fetch (or clear) state for all selected collabs.

        Raises CommandError if --only-collab names an unknown collab (code 2)
        or if any collab's fetch failed (code 3).
        """
        fetchers = settings.get_fetchers()

        # Verify collab arguments
        self.collabs = settings.get_all_collabs(default_to_sample=True)
        if self.only_collab:
            self.collabs = [c for c in self.collabs if c.name == self.only_collab]
            if not self.collabs:
                raise command_base.CommandError(
                    f"No such collab '{self.only_collab}'", 2
                )
        if all(not c.enabled for c in self.collabs):
            self.stderr("All collabs are disabled. Nothing to do.")
            return

        # Do work
        if self.clear:
            self.stderr("Clearing fetched state")
            for fetcher in settings.get_fetchers():
                store = settings.get_fetch_store_for_fetcher(fetcher)
                for collab in self.collabs:
                    if self.only_collab not in (None, collab.name):
                        continue
                    logging.info("Clearing %s - %s", fetcher.get_name(), collab.name)
                    store.clear(collab)
            return

        all_succeeded = True
        any_succeeded = False
        for fetcher in fetchers:
            logging.info("Fetching all %s's configs", fetcher.get_name())
            succeeded = self.execute_for_fetcher(settings, fetcher)
            all_succeeded &= succeeded
            any_succeeded |= succeeded

        if any_succeeded and not self.skip_index_rebuild:
            self.stderr("Rebuilding match indices...")
            DatasetCommand().execute_generate_indices(settings)

        if not all_succeeded:
            raise command_base.CommandError("Some collabs had errors!", 3)

    def execute_for_fetcher(
        self, settings: CLISettings, fetcher: SignalExchangeAPI
    ) -> bool:
        """Fetch every enabled collab belonging to this API; True if all succeeded."""
        success = True
        for collab in self.collabs:
            if collab.api != fetcher.get_name():
                continue
            if not collab.enabled:
                # Fix: original passed no argument for the %s placeholder.
                logging.debug("Skipping %s, disabled", collab.name)
                continue
            fetch_ok = self.execute_for_collab(settings, fetcher, collab)
            success &= fetch_ok
        return success

    def execute_for_collab(
        self,
        settings: CLISettings,
        fetcher: SignalExchangeAPI,
        collab: CollaborationConfigBase,
    ) -> bool:
        """
        Fetch one collab to completion (or until limits hit).

        Returns True on success; False (after logging) on any exception.
        State is flushed to the store even on failure.
        """
        store = settings.get_fetch_store_for_fetcher(fetcher.__class__)
        checkpoint = self._verify_store_and_checkpoint(store, collab)

        self.progress_fetched_count = 0
        self.current_collab = collab.name
        self.current_api = fetcher.get_name()
        try:
            while not self.has_hit_limits():
                delta = fetcher.fetch_once(
                    settings.get_all_signal_types(), collab, checkpoint
                )
                logging.info("Fetched %d records", delta.record_count())
                checkpoint = delta.next_checkpoint()
                self._fetch_progress(delta.record_count(), checkpoint)
                assert checkpoint is not None  # Infinite loop protection
                store.merge(collab, delta)
                if not delta.has_more():
                    break
        except Exception:  # Narrowed from bare except: don't eat KeyboardInterrupt
            self._stderr_current("failed to fetch!")
            logging.exception("Failed to fetch %s", collab.name)
            return False
        finally:
            store.flush()
        self._print_progress(done=True)
        return True

    def _verify_store_and_checkpoint(
        self, store: FetchedStateStoreBase, collab: CollaborationConfigBase
    ) -> t.Optional[FetchCheckpointBase]:
        """Return the stored checkpoint, clearing stale state to force a refetch."""
        checkpoint = store.get_checkpoint(collab)

        if checkpoint is not None and checkpoint.is_stale():
            store.clear(collab)
            return None
        return checkpoint

    def _fetch_progress(self, batch_size: int, checkpoint: FetchCheckpointBase) -> None:
        """Record a fetched batch and print progress at most once per interval."""
        self.progress_fetched_count += batch_size
        self.total_fetched_count += batch_size
        progress_ts = checkpoint.get_progress_timestamp()
        if progress_ts is not None:
            self.last_update_time = progress_ts

        now = time.time()
        if now - self.last_update_printed >= self.PROGRESS_PRINT_INTERVAL_SEC:
            self.last_update_printed = now
            self._print_progress()

    def _stderr_current(self, msg: str) -> None:
        """Print a message prefixed with the API/collab currently being fetched."""
        assert self.current_api and self.current_collab
        self.stderr(
            f"[{self.current_api}] {self.current_collab} - {msg}",
        )

    def _print_progress(self, *, done: bool = False) -> None:
        """Print a human-readable progress line for the current collab."""
        processed = "Syncing..."
        if done:
            processed = "Up to date"
        elif self.progress_fetched_count:
            processed = f"Downloaded {self.progress_fetched_count} updates"

        from_time = ""
        if self.last_update_time is not None:
            # Fix: original tested `not from_time` (always true, since it was
            # just set to ""), making the timestamp branches below unreachable.
            # A zero/falsy timestamp means "no meaningful progress time".
            if not self.last_update_time:
                from_time = "ages long past"
            elif self.last_update_time >= time.time() - 1:
                from_time = "moments ago"
            else:
                from_time = datetime.datetime.fromtimestamp(
                    self.last_update_time
                ).isoformat()
            from_time = f", at {from_time}"

        self._stderr_current(f"{processed}{from_time}")