hasher-matcher-actioner/hmalib/lambdas/unified_indexer.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import typing as t
import functools

import boto3

from threatexchange.signal_type.md5 import VideoMD5Signal
from threatexchange.signal_type.pdq import PdqSignal
from threatexchange.signal_type.signal_base import SignalType

from hmalib.common.models.bank import BanksTable
from hmalib.common.s3_adapters import (
    HashRowT,
    ThreatExchangeS3Adapter,
    ThreatExchangeS3PDQAdapter,
    ThreatExchangeS3VideoMD5Adapter,
    S3ThreatDataConfig,
)
from hmalib.common.logging import get_logger
from hmalib import metrics
from hmalib.common.mappings import INDEX_MAPPING
from hmalib.indexers.metadata import BankedSignalIndexMetadata
from hmalib.indexers.s3_indexers import (
    S3BackedInstrumentedIndexMixin,
)

logger = get_logger(__name__)

dynamodb = boto3.resource("dynamodb")

THREAT_EXCHANGE_DATA_BUCKET_NAME = os.environ["THREAT_EXCHANGE_DATA_BUCKET_NAME"]
THREAT_EXCHANGE_DATA_FOLDER = os.environ["THREAT_EXCHANGE_DATA_FOLDER"]
INDEXES_BUCKET_NAME = os.environ["INDEXES_BUCKET_NAME"]
BANKS_TABLE = os.environ["BANKS_TABLE"]
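# All four environment variables above are required at import time; a missing
# variable raises KeyError before the handler is ever invoked.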


def get_all_bank_hash_rows(
signal_type: t.Type[SignalType], banks_table: BanksTable
) -> t.Iterable[HashRowT]:
"""
Make repeated calls to banks table to get all hashes for a signal type.
Returns list[HashRowT]. HashRowT is a tuple of hash_value and some metadata
about the signal.
"""
exclusive_start_key = None
hash_rows: t.List[HashRowT] = []
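    # Page through the table DynamoDB-style: feed the previous page's
    # last_evaluated_key back in as exclusive_start_key until no pages remain.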
while True:
page = banks_table.get_bank_member_signals_to_process_page(
signal_type=signal_type, exclusive_start_key=exclusive_start_key
)
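        # Each appended row is (hash value, [metadata]); the metadata records
        # which bank member and signal id the hash came from.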
for bank_member_signal in page.items:
hash_rows.append(
(
bank_member_signal.signal_value,
[
BankedSignalIndexMetadata(
bank_member_signal.signal_id,
bank_member_signal.signal_value,
bank_member_signal.bank_member_id,
),
],
)
)
exclusive_start_key = page.last_evaluated_key
if not page.has_next_page():
break
logger.info(
f"Obtained {len(hash_rows)} hash records from banks for signal_type:{signal_type.get_name()}"
)
return hash_rows


def merge_hash_rows_on_hash_value(
accumulator: t.Dict[str, HashRowT], hash_row: HashRowT
) -> t.Dict[str, HashRowT]:
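    """
    Reducer for functools.reduce: folds hash rows into a dict keyed by hash
    value, concatenating metadata lists when the same hash appears more than
    once.

    e.g. functools.reduce(
             merge_hash_rows_on_hash_value, [("abc", [m1]), ("abc", [m2])], {}
         )
         == {"abc": ("abc", [m2, m1])}
    """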
hash, metadata = hash_row
    if hash not in accumulator:
        # First time we have seen this hash; add it as a new row.
        accumulator[hash] = hash_row
    else:
        # Tuples are immutable, so replace the existing row with a new one
        # whose metadata list combines the new and existing metadata objects.
        accumulator[hash] = (hash, list(metadata) + list(accumulator[hash][1]))
return accumulator


# Maps a signal type to the ThreatExchangeS3Adapter subclass used to fetch all
# of the data corresponding to that signal type. At some point, we should
# support incremental _updates_ to indexes rather than rebuilding them from
# scratch every time.
_ADAPTER_MAPPING: t.Dict[t.Type[SignalType], t.Type[ThreatExchangeS3Adapter]] = {
PdqSignal: ThreatExchangeS3PDQAdapter,
VideoMD5Signal: ThreatExchangeS3VideoMD5Adapter,
}

# Which signal types must be processed into an index?
ALL_INDEXABLE_SIGNAL_TYPES = [PdqSignal, VideoMD5Signal]
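# Note: every signal type listed here must also have an entry in both
# _ADAPTER_MAPPING above and INDEX_MAPPING; lambda_handler looks up both per
# signal type.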


def lambda_handler(event, context):
    """
    Runs on a schedule. On each run, gets all data files for
    ALL_INDEXABLE_SIGNAL_TYPES from S3, converts the raw data files into
    indexes, and writes them to an output S3 bucket.

    As per the default configuration:
    - the bucket must be the hashing data bucket, eg. dipanjanm-hashing-<...>
    - the key name must be in the ThreatExchange folder (eg.
      threat_exchange_data/)
    - the key name must return a signal_type in
      ThreatUpdateS3Store.get_signal_type_from_object_key
    """
    # Note: even though we know which files were updated, threatexchange
    # indexes do not yet support adding new entries, so we must do a full
    # rebuild. As a result, we only use the signal types that were updated,
    # not the specific files that changed.
s3_config = S3ThreatDataConfig(
threat_exchange_data_bucket_name=THREAT_EXCHANGE_DATA_BUCKET_NAME,
threat_exchange_data_folder=THREAT_EXCHANGE_DATA_FOLDER,
)
banks_table = BanksTable(dynamodb.Table(BANKS_TABLE))
for signal_type in ALL_INDEXABLE_SIGNAL_TYPES:
adapter_class = _ADAPTER_MAPPING[signal_type]
data_files = adapter_class(
config=s3_config, metrics_logger=metrics.names.indexer
).load_data()
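        # data_files maps each ThreatExchange data file's name to the list of
        # HashRowT rows parsed from it (see the flattening step below).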
with metrics.timer(metrics.names.indexer.get_bank_data):
bank_data = get_all_bank_hash_rows(signal_type, banks_table)
with metrics.timer(metrics.names.indexer.merge_datafiles):
logger.info(f"Merging {signal_type} Hash files")
# go from dict[filename, list<hash rows>] → list<hash rows>
flattened_data = [
hash_row for file_ in data_files.values() for hash_row in file_
]
merged_data = functools.reduce(
merge_hash_rows_on_hash_value, flattened_data + bank_data, {}
).values()
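        # merged_data now holds one HashRowT per unique hash value, with
        # metadata from ThreatExchange files and banks combined.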
with metrics.timer(metrics.names.indexer.build_index):
logger.info(f"Rebuilding {signal_type} Index")
for index_class in INDEX_MAPPING[signal_type]:
index: S3BackedInstrumentedIndexMixin = index_class.build(merged_data)
logger.info(
f"Putting {signal_type} index in S3 for index {index.get_index_class_name()}"
)
index.save(bucket_name=INDEXES_BUCKET_NAME)
metrics.flush()
logger.info("Index updates complete")