python-threatexchange/threatexchange/cli/dataset_cmd.py (205 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from argparse import ArgumentParser
import collections
import csv
import sys
import typing as t
import logging
from threatexchange.signal_type.signal_base import SignalType
from threatexchange import signal_type
from threatexchange.cli.cli_config import CLISettings
from threatexchange.content_type.content_base import ContentType
from threatexchange.fetcher.collab_config import CollaborationConfigBase
from threatexchange.fetcher.fetch_state import FetchedSignalMetadata
from threatexchange.cli import command_base
from threatexchange import common
class DatasetCommand(command_base.Command):
    """
    Introspect fetched data.

    Can print out contents in simple formats
    (ideal for sending to another system), or regenerate
    index files (ideal if distributing indices for some reason)
    """

    @classmethod
    def init_argparse(cls, settings: CLISettings, ap: ArgumentParser) -> None:
        """
        Register this command's flags on `ap`.

        The four mode flags are mutually exclusive; if none is given,
        `execute` falls back to printing the per-type summary. The signal
        and content type selectors are likewise mutually exclusive with
        each other.
        """
        # Mode: at most one of these may be passed.
        actions = ap.add_mutually_exclusive_group()
        actions.add_argument(
            "--rebuild-indices",
            "-r",
            action="store_true",
            help="rebuild indices from fetched data",
        )
        actions.add_argument(
            "--clear-indices",
            "-X",
            action="store_true",
            help="clear all indices",
        )
        actions.add_argument(
            "--signal-summary",
            "-S",
            action="store_true",
            help="print summary in terms of signals",
        )
        actions.add_argument(
            "--print-records",
            "-P",
            action="store_true",
            help="print records to screen",
        )

        # Type selection: either name signal types directly, or pick them
        # indirectly through the content types they apply to.
        type_selector = ap.add_mutually_exclusive_group()
        type_selector.add_argument(
            "--only-signals",
            "-s",
            nargs="+",
            default=[],
            type=common.argparse_choices_pre_type(
                choices=[s.get_name() for s in settings.get_all_signal_types()],
                type=settings.get_signal_type,
            ),
            help="only process these signals",
        )
        type_selector.add_argument(
            "--only-content",
            "-C",
            nargs="+",
            default=[],
            type=common.argparse_choices_pre_type(
                choices=[s.get_name() for s in settings.get_all_content_types()],
                type=settings.get_content_type,
            ),
            help="only process signals for these content types",
        )

        ap.add_argument(
            "--only-collabs",
            "-c",
            nargs="+",
            default=[],
            metavar="NAME",
            help="[-S|-P] only count items from these collaborations",
        )
        ap.add_argument(
            "--only-tags",
            "-t",
            # Without nargs argparse yields a single str, which __init__
            # would then split into a set of individual characters.
            nargs="+",
            default=[],
            metavar="STR",
            help="[-S|-P] only count items with these tags",
        )
        ap.add_argument(
            "--signals-only",
            "-i",
            action="store_true",
            help="[-P] only print signals",
        )
        ap.add_argument(
            "--limit",
            "-l",
            # A record count, not a boolean switch: matches the
            # Optional[int] `limit` parameter of __init__.
            type=int,
            metavar="NUM",
            help="[-P] only print this many records",
        )

    def __init__(
        # These all have defaults to make it easier to call
        # only for rebuild
        self,
        # Mode
        clear_indices: bool = False,
        rebuild_indices: bool = False,
        signal_summary: bool = False,
        print_records: bool = False,
        # Signal selectors
        only_collabs: t.Sequence[str] = (),
        only_signals: t.Sequence[t.Type[SignalType]] = (),
        only_content: t.Sequence[t.Type[ContentType]] = (),
        only_tags: t.Sequence[str] = (),
        # Print stuff
        signals_only: bool = False,
        limit: t.Optional[int] = None,
    ) -> None:
        """
        Store the selected mode, the filters and the print options.

        The `only_*` selectors are stored as sets; an empty selector
        means "no filtering" everywhere it is consulted in this class.
        """
        self.clear_indices = clear_indices
        self.rebuild_indices = rebuild_indices
        self.print_records = print_records
        self.signal_summary = signal_summary
        self.only_collabs = set(only_collabs)
        self.only_signals = set(only_signals)
        self.only_content = set(only_content)
        self.only_tags = set(only_tags)
        self.signals_only = signals_only
        self.limit = limit

    def execute(self, settings: CLISettings) -> None:
        """
        Run the selected mode.

        Modes are checked in a fixed order (clear, rebuild, print records,
        signal summary); with no mode set, the per-type summary is printed.
        """
        # Maybe consider subcommands?
        if self.clear_indices:
            self.execute_clear_indices(settings)
        elif self.rebuild_indices:
            self.execute_generate_indices(settings)
        elif self.print_records:
            self.execute_print_records(settings)
        elif self.signal_summary:
            self.execute_print_signal_summary(settings)
        else:
            self.execute_print_summary(settings)

    def get_signal_types(self, settings: CLISettings) -> t.Set[t.Type[SignalType]]:
        """
        Return the signal types selected by the type filters.

        Starts from `only_signals` if given, otherwise from every signal
        type known to `settings`. When `only_content` is set, keeps only
        the types that apply to at least one of those content types.
        """
        signal_types: t.Iterable[t.Type[SignalType]] = (
            self.only_signals or settings.get_all_signal_types()
        )
        if self.only_content:
            signal_types = [
                s
                for s in signal_types
                if any(c in self.only_content for c in s.get_content_types())
            ]
        return set(signal_types)

    def get_collabs(self, settings: CLISettings) -> t.List[CollaborationConfigBase]:
        """
        Return the enabled collaborations, narrowed by `only_collabs`.

        When `only_collabs` is non-empty, only collaborations whose name
        is in it are returned.
        """
        collabs = [
            c for c in settings.get_all_collabs(default_to_sample=True) if c.enabled
        ]
        if self.only_collabs:
            collabs = [c for c in collabs if c.name in self.only_collabs]
        return collabs

    def _matches_tags(self, record: FetchedSignalMetadata) -> bool:
        """Return True if no tag filter is set or any opinion carries a wanted tag."""
        if not self.only_tags:
            return True
        return any(
            tag in self.only_tags
            for opinion in record.get_as_opinions()
            for tag in opinion.tags
        )

    def get_signals(
        self, settings: CLISettings, signal_types: t.Iterable[t.Type[SignalType]]
    ) -> t.Dict[
        t.Type[SignalType], t.Dict[str, t.List[t.Tuple[str, FetchedSignalMetadata]]]
    ]:
        """
        Collect fetched signals for `signal_types` across the selected collabs.

        Collaborations are grouped by their `api`; one store is obtained
        per group (via the group's first collab) and asked for all of that
        group's collabs at once. Records that fail the `only_tags` filter
        are dropped.

        Returns a mapping of signal type -> signal -> list of
        (collab, record) pairs. Every requested type gets an entry, even
        if no signals were found for it.
        """
        collab_by_api: t.Dict[str, t.List[CollaborationConfigBase]] = {}
        for collab_config in self.get_collabs(settings):
            collab_by_api.setdefault(collab_config.api, []).append(collab_config)

        by_type: t.Dict[
            t.Type[SignalType],
            t.Dict[str, t.List[t.Tuple[str, FetchedSignalMetadata]]],
        ] = {}
        for s_type in signal_types:
            by_signal: t.Dict[
                str,
                t.List[t.Tuple[str, FetchedSignalMetadata]],
            ] = {}
            for collabs_for_store in collab_by_api.values():
                store = settings.get_fetch_store_for_collab(collabs_for_store[0])
                by_collab = store.get_for_signal_type(collabs_for_store, s_type)
                for collab, signals in by_collab.items():
                    for signal, record in signals.items():
                        if not self._matches_tags(record):
                            continue
                        by_signal.setdefault(signal, []).append((collab, record))
            by_type[s_type] = by_signal
        return by_type

    def execute_print_summary(self, settings: CLISettings) -> None:
        """Print `<signal type name>: <distinct signal count>`, largest count first."""
        signals = self.get_signals(settings, self.get_signal_types(settings))
        by_type: t.Dict[str, int] = collections.Counter()
        for s_type, type_signals in signals.items():
            by_type[s_type.get_name()] += len(type_signals)
        for s_name, count in sorted(by_type.items(), key=lambda i: -i[1]):
            self.stderr(f"{s_name}: {count}")

    def execute_print_signal_summary(self, settings: CLISettings) -> None:
        """Not implemented yet; always raises NotImplementedError."""
        raise NotImplementedError
        # NOTE(review): the commented-out sketch below references names
        # (`meta`, `indicators`) that are not defined in this file.
        # signal_types = meta.get_signal_types_by_name()
        # by_signal: t.Dict[str, int] = collections.Counter()
        # for indicator in indicators.values():
        #     for name, signal_type in signal_types.items():
        #         if signal_type.indicator_applies(
        #             indicator.indicator_type, list(indicator.rollup.labels)
        #         ):
        #             by_signal[name] += 1
        # for name, count in sorted(by_signal.items(), key=lambda i: -i[1]):
        #     self.stderr(f"{name}: {count}")

    def execute_print_records(self, settings: CLISettings) -> None:
        """Not implemented yet; always raises NotImplementedError."""
        raise NotImplementedError
        # NOTE(review): the commented-out sketch below references names
        # (`indicators`, `self.indicator_only`) not defined in this file.
        # csv_writer = csv.writer(sys.stdout)
        # for indicator in indicators.values():
        #     if self.indicator_only:
        #         print(indicator.indicator)
        #     else:
        #         csv_writer.writerow(indicator.as_csv_row())

    def execute_clear_indices(self, settings: CLISettings) -> None:
        """
        Clear stored indices.

        When a signal or content type filter is set, only the matching
        signal types are passed to the index store; otherwise `None` is
        passed.
        """
        only_signals = None
        if self.only_signals or self.only_content:
            only_signals = self.get_signal_types(settings)
        settings.index_store.clear(only_signals)

    def execute_generate_indices(self, settings: CLISettings) -> None:
        """
        Rebuild and store the index for every selected signal type.

        A type with no signals has its stored index cleared instead of
        being rebuilt.
        """
        for s_type in self.get_signal_types(settings):
            index_cls = s_type.get_index_cls()
            signals = self.get_signals(settings, [s_type]).get(s_type, {})
            if not signals:
                logging.info("No signals for %s", s_type.__name__)
                settings.index_store.clear([s_type])
                continue
            self.stderr(
                "Building index for",
                s_type.get_name(),
                f"with {len(signals)} signals...",
            )
            index = index_cls.build(signals.items())
            settings.index_store.store_index(s_type, index)
            self.stderr(f"Index for {s_type.get_name()} ready")