python-threatexchange/threatexchange/cli/dataset_cmd.py (205 lines of code) (raw):

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from argparse import ArgumentParser
import collections
import csv
import sys
import typing as t
import logging

from threatexchange.signal_type.signal_base import SignalType
from threatexchange import signal_type
from threatexchange.cli.cli_config import CLISettings
from threatexchange.content_type.content_base import ContentType
from threatexchange.fetcher.collab_config import CollaborationConfigBase
from threatexchange.fetcher.fetch_state import FetchedSignalMetadata
from threatexchange.cli import command_base
from threatexchange import common


class DatasetCommand(command_base.Command):
    """
    Introspect fetched data.

    Can print out contents in simple formats (ideal for sending to another
    system), or regenerate index files (ideal if distributing indices for
    some reason).
    """

    @classmethod
    def init_argparse(cls, settings: CLISettings, ap: ArgumentParser) -> None:
        """Register this command's mode flags and signal filters on `ap`."""
        # Mode selection: at most one of these may be given; with none,
        # a per-signal-type count summary is printed.
        actions = ap.add_mutually_exclusive_group()
        actions.add_argument(
            "--rebuild-indices",
            "-r",
            action="store_true",
            help="rebuild indices from fetched data",
        )
        actions.add_argument(
            "--clear-indices",
            "-X",
            action="store_true",
            help="clear all indices",
        )
        actions.add_argument(
            "--signal-summary",
            "-S",
            action="store_true",
            help="print summary in terms of signals",
        )
        actions.add_argument(
            "--print-records",
            "-P",
            action="store_true",
            help="print records to screen",
        )

        # Restrict by signal type directly, or indirectly via content type.
        type_selector = ap.add_mutually_exclusive_group()
        type_selector.add_argument(
            "--only-signals",
            "-s",
            nargs="+",
            default=[],
            type=common.argparse_choices_pre_type(
                choices=[s.get_name() for s in settings.get_all_signal_types()],
                type=settings.get_signal_type,
            ),
            help="only process these signals",
        )
        type_selector.add_argument(
            "--only-content",
            "-C",
            nargs="+",
            default=[],
            type=common.argparse_choices_pre_type(
                choices=[s.get_name() for s in settings.get_all_content_types()],
                type=settings.get_content_type,
            ),
            help="only process signals for these content types",
        )

        ap.add_argument(
            "--only-collabs",
            "-c",
            nargs="+",
            default=[],
            metavar="NAME",
            help="[-S|-P] only count items from these collaborations",
        )
        ap.add_argument(
            "--only-tags",
            "-t",
            # Without nargs this would be a single str, which set() in
            # __init__ would split into individual characters.
            nargs="+",
            default=[],
            metavar="STR",
            help="[-S|-P] only count items with these tags",
        )
        ap.add_argument(
            "--signals-only",
            "-i",
            action="store_true",
            help="[-P] only print signals",
        )
        ap.add_argument(
            "--limit",
            "-l",
            # A count, matching `limit: t.Optional[int]` in __init__
            # (previously a store_true flag, which could never carry a number).
            type=int,
            metavar="N",
            help="[-P] only print this many records",
        )

    def __init__(
        # These all have defaults to make it easier to call
        # only for rebuild
        self,
        # Mode
        clear_indices: bool = False,
        rebuild_indices: bool = False,
        signal_summary: bool = False,
        print_records: bool = False,
        # Signal selectors
        only_collabs: t.Sequence[str] = (),
        only_signals: t.Sequence[t.Type[SignalType]] = (),
        only_content: t.Sequence[t.Type[ContentType]] = (),
        only_tags: t.Sequence[str] = (),
        # Print stuff
        signals_only: bool = False,
        limit: t.Optional[int] = None,
    ) -> None:
        self.clear_indices = clear_indices
        self.rebuild_indices = rebuild_indices
        self.print_records = print_records
        self.signal_summary = signal_summary

        # Stored as sets for membership tests; an empty set means "no filter".
        self.only_collabs = set(only_collabs)
        self.only_signals = set(only_signals)
        self.only_content = set(only_content)
        self.only_tags = set(only_tags)

        self.signals_only = signals_only
        self.limit = limit

    def execute(self, settings: CLISettings) -> None:
        """Dispatch to the selected mode; the plain summary is the fallback."""
        # Maybe consider subcommands?
        if self.clear_indices:
            self.execute_clear_indices(settings)
        elif self.rebuild_indices:
            self.execute_generate_indices(settings)
        elif self.print_records:
            self.execute_print_records(settings)
        elif self.signal_summary:
            self.execute_print_signal_summary(settings)
        else:
            self.execute_print_summary(settings)

    def get_signal_types(self, settings: CLISettings) -> t.Set[t.Type[SignalType]]:
        """
        Return the signal types selected by --only-signals / --only-content.

        With neither filter, every signal type known to `settings` is returned.
        With --only-content, only signal types that apply to at least one of
        the requested content types are kept.
        """
        signal_types = self.only_signals or settings.get_all_signal_types()
        if self.only_content:
            signal_types = [
                s
                for s in signal_types
                if any(c in self.only_content for c in s.get_content_types())
            ]
        return set(signal_types)

    def get_collabs(self, settings: CLISettings) -> t.List[CollaborationConfigBase]:
        """
        Return enabled collaborations, optionally narrowed by --only-collabs.

        Falls back to sample collaborations when none are configured
        (`default_to_sample=True`).
        """
        collabs = [
            c for c in settings.get_all_collabs(default_to_sample=True) if c.enabled
        ]
        if self.only_collabs:
            collabs = [c for c in collabs if c.name in self.only_collabs]
        return collabs

    def _record_matches_tags(self, record: FetchedSignalMetadata) -> bool:
        """True if no tag filter is set, or any opinion on `record` carries a requested tag."""
        if not self.only_tags:
            return True
        return any(
            tag in self.only_tags
            for opinion in record.get_as_opinions()
            for tag in opinion.tags
        )

    def get_signals(
        self, settings: CLISettings, signal_types: t.Iterable[t.Type[SignalType]]
    ) -> t.Dict[
        t.Type[SignalType], t.Dict[str, t.List[t.Tuple[str, FetchedSignalMetadata]]]
    ]:
        """
        Load fetched signals from the fetch stores, grouped by type and value.

        Returns {signal_type: {signal_value: [(collab, metadata), ...]}}, where
        `collab` is the key produced by the fetch store's
        `get_for_signal_type`. Only records passing the --only-tags filter
        are included.
        """
        collabs = self.get_collabs(settings)

        # Group collabs by API so each fetch store is queried once with all
        # of its collabs (the store is looked up from the first collab in
        # each group).
        collab_by_api: t.Dict[str, t.List[CollaborationConfigBase]] = {}
        for collab_config in collabs:
            collab_by_api.setdefault(collab_config.api, []).append(collab_config)

        by_type: t.Dict[
            t.Type[SignalType],
            t.Dict[str, t.List[t.Tuple[str, FetchedSignalMetadata]]],
        ] = {}
        for s_type in signal_types:
            by_signal: t.Dict[
                str,
                t.List[t.Tuple[str, FetchedSignalMetadata]],
            ] = {}
            for collabs_for_store in collab_by_api.values():
                store = settings.get_fetch_store_for_collab(collabs_for_store[0])
                by_collab = store.get_for_signal_type(collabs_for_store, s_type)
                for collab, signals in by_collab.items():
                    for signal, record in signals.items():
                        if not self._record_matches_tags(record):
                            continue
                        by_signal.setdefault(signal, []).append((collab, record))
            by_type[s_type] = by_signal
        return by_type

    def execute_print_summary(self, settings: CLISettings) -> None:
        """Print the number of distinct signals per signal type, largest first."""
        signals = self.get_signals(settings, self.get_signal_types(settings))
        by_type: t.Dict[str, int] = collections.Counter()
        for s_type, type_signals in signals.items():
            by_type[s_type.get_name()] += len(type_signals)
        # most_common() sorts by count descending; ties keep insertion order.
        for s_name, count in by_type.most_common():
            self.stderr(f"{s_name}: {count}")

    def execute_print_signal_summary(self, settings: CLISettings) -> None:
        """Summarize records per signal type (-S). Not yet implemented."""
        # TODO: port to the fetch-store API. The legacy implementation
        # counted, per signal type, how many fetched indicators that type
        # applied to, and printed "<name>: <count>" largest first.
        raise NotImplementedError("--signal-summary is not implemented yet")

    def execute_print_records(self, settings: CLISettings) -> None:
        """Print fetched records to stdout (-P). Not yet implemented."""
        # TODO: port to the fetch-store API. The legacy implementation wrote
        # one CSV row per record to stdout (or only the signal value with
        # --signals-only); --limit is intended to cap the number of rows.
        raise NotImplementedError("--print-records is not implemented yet")

    def execute_clear_indices(self, settings: CLISettings) -> None:
        """Clear stored indices: all of them, or only the selected signal types."""
        only_signals: t.Optional[t.Set[t.Type[SignalType]]] = None
        if self.only_signals or self.only_content:
            only_signals = self.get_signal_types(settings)
        settings.index_store.clear(only_signals)

    def execute_generate_indices(self, settings: CLISettings) -> None:
        """
        Rebuild and store one index per selected signal type.

        Signals are loaded one type at a time. A type with no matching
        signals has its stored index cleared instead of rebuilt.
        """
        signal_types = self.get_signal_types(settings)

        for s_type in signal_types:
            index_cls = s_type.get_index_cls()
            signal_by_type = self.get_signals(settings, [s_type])
            signals = signal_by_type.get(s_type, {})
            if not signals:
                logging.info("No signals for %s", s_type.__name__)
                settings.index_store.clear([s_type])
                continue
            self.stderr(
                "Building index for",
                s_type.get_name(),
                f"with {len(signals)} signals...",
            )
            index = index_cls.build(signals.items())
            settings.index_store.store_index(s_type, index)
            self.stderr(f"Index for {s_type.get_name()} ready")