hasher-matcher-actioner/hmalib/matchers/matchers_base.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Implements a unified matcher class. The unified matcher is capable of matching
against any index defined in python-threatexchange.
"""

import datetime
import typing as t

from mypy_boto3_sns.client import SNSClient
from mypy_boto3_dynamodb.service_resource import Table
from threatexchange.signal_type.pdq import PdqSignal
from threatexchange.signal_type.index import IndexMatch, SignalTypeIndex
from threatexchange.signal_type.signal_base import SignalType

from hmalib import metrics
from hmalib.common.logging import get_logger
from hmalib.common.mappings import INDEX_MAPPING
from hmalib.common.messages.match import BankedSignal, MatchMessage
from hmalib.common.models.bank import BanksTable
from hmalib.common.models.pipeline import MatchRecord
from hmalib.common.models.signal import ThreatExchangeSignalMetadata
from hmalib.indexers.metadata import (
    BANKS_SOURCE_SHORT_CODE,
    THREAT_EXCHANGE_SOURCE_SHORT_CODE,
    BaseIndexMetadata,
    ThreatExchangeIndicatorIndexMetadata,
    BankedSignalIndexMetadata,
)
from hmalib.matchers.filters import (
    BankActiveFilter,
    BaseMatchFilter,
    ThreatExchangePdqMatchDistanceFilter,
    ThreatExchangePrivacyGroupMatcherActiveFilter,
    get_max_threshold_of_active_privacy_groups_for_signal_type,
)

logger = get_logger(__name__)


class Matcher:
    """
    Match against any signal type defined on threatexchange and stored in S3.

    Once created, indexes used by this class are cached on the instance. Do
    not create multiple Matcher instances in the same python runtime for the
    same signal_type; that would take up more RAM than necessary.

    Indexes are pulled from S3 on the first call for a signal_type.
    """

    def __init__(
        self,
        index_bucket_name: str,
        supported_signal_types: t.List[t.Type[SignalType]],
        banks_table: BanksTable,
    ):
        self.index_bucket_name = index_bucket_name
        self.supported_signal_types = supported_signal_types
        self._cached_indexes: t.Dict[t.Type[SignalType], SignalTypeIndex] = {}
        self.banks_table = banks_table

        self.match_filters: t.Sequence[BaseMatchFilter] = [
            ThreatExchangePrivacyGroupMatcherActiveFilter(),
            ThreatExchangePdqMatchDistanceFilter(),
            BankActiveFilter(banks_table=banks_table),
        ]

    def match(
        self, signal_type: t.Type[SignalType], signal_value: str
    ) -> t.List[IndexMatch[t.List[BaseIndexMetadata]]]:
        """
        Query the index for signal_value and return the filtered list of
        IndexMatch objects.

        Note, this also filters out matches that are from datasets that have
        been de-activated.
        """
        index = self.get_index(signal_type)

        with metrics.timer(metrics.names.indexer.search_index):
            match_results: t.List[IndexMatch] = index.query(signal_value)

        if not match_results:
            # No matches found in the index
            return []

        return self.filter_match_results(match_results, signal_type)
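
    # Illustrative only -- not part of the original module. A minimal sketch
    # of how a caller might construct a Matcher and run a match; the bucket
    # name, table object, and hash value below are assumptions:
    #
    #   matcher = Matcher(
    #       index_bucket_name="hma-index-bucket",   # hypothetical bucket
    #       supported_signal_types=[PdqSignal],
    #       banks_table=banks_table,                # an existing BanksTable
    #   )
    #   matches = matcher.match(PdqSignal, pdq_hash)  # pdq_hash: hex string
    #   for m in matches:
    #       logger.info(
    #           "distance=%s sources=%s",
    #           m.distance,
    #           [md.get_source() for md in m.metadata],
    #       )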
""" # results is a list of match object references that live in any index # this method should not edit those objects directly as they could effect # subsequent calls made while the index is still in memory matches = results.copy() for match_filter in self.match_filters: matches = match_filter.filter_matches(matches, signal_type) return matches def write_match_record_for_result( self, table: Table, signal_type: t.Type[SignalType], content_hash: str, content_id: str, match: IndexMatch[t.List[BaseIndexMetadata]], ): """ Write a match record to dynamodb. The content_id is not important to the matcher. So, the calling lambda is expected to pass on the content_id for match record calls. """ for metadata_obj in match.metadata: match_record_attributes = { "content_id": content_id, "signal_type": signal_type, "content_hash": content_hash, "updated_at": datetime.datetime.now(), "signal_source": metadata_obj.get_source(), "match_distance": int(match.distance), } if metadata_obj.get_source() == THREAT_EXCHANGE_SOURCE_SHORT_CODE: metadata_obj = t.cast( ThreatExchangeIndicatorIndexMetadata, metadata_obj ) match_record_attributes.update( signal_id=metadata_obj.indicator_id, signal_hash=metadata_obj.signal_value, ) elif metadata_obj.get_source() == BANKS_SOURCE_SHORT_CODE: metadata_obj = t.cast(BankedSignalIndexMetadata, metadata_obj) match_record_attributes.update( signal_id=metadata_obj.signal_id, signal_hash=metadata_obj.signal_value, ) MatchRecord(**match_record_attributes).write_to_table(table) @classmethod def write_signal_if_not_found( cls, table: Table, signal_type: t.Type[SignalType], match: IndexMatch, ): """ Write the signal to the datastore. Only signals that have matched are written to the DB. The fetcher takes care of updating the signal with opinions or updates from the source. TODO: Move this out of matchers. This is not matcher specific functionality. Signals could benefit from their own store. Perhaps the API could be useful when building local banks. Who knows! :) """ for signal in cls.get_te_metadata_objects_from_match(signal_type, match): if hasattr(signal, "write_to_table_if_not_found"): # only ThreatExchangeSignalMetadata has this method. # mypy not smart enough to auto cast. signal.write_to_table_if_not_found(table) # type: ignore @classmethod def get_te_metadata_objects_from_match( cls, signal_type: t.Type[SignalType], match: IndexMatch[t.List[BaseIndexMetadata]], ) -> t.List[ThreatExchangeSignalMetadata]: """ See docstring of `write_signal_if_not_found` we will likely want to move this outside of Matcher. However while the MD5 expansion is still on going better to have it all in once place. Note: changes made here will have an effect on api.matches.get_match_for_hash """ metadata_objects = [] for metadata_obj in match.metadata: if metadata_obj.get_source() == THREAT_EXCHANGE_SOURCE_SHORT_CODE: metadata_obj = t.cast( ThreatExchangeIndicatorIndexMetadata, metadata_obj ) metadata_objects.append( ThreatExchangeSignalMetadata( signal_id=str(metadata_obj.indicator_id), privacy_group_id=metadata_obj.privacy_group, updated_at=datetime.datetime.now(), signal_type=signal_type, signal_hash=metadata_obj.signal_value, tags=list(metadata_obj.tags), ) ) return metadata_objects def get_index(self, signal_type: t.Type[SignalType]) -> SignalTypeIndex: """ If cached, return an index instance for the signal_type. If not, build one, cache and return. 
""" max_custom_threshold = ( get_max_threshold_of_active_privacy_groups_for_signal_type(signal_type) ) index_cls = self._get_index_for_signal_type_matching( signal_type, max_custom_threshold ) # Check for signal_type in cache AND confirm said index class type is # still correct for the given [optional] max_custom_threshold if not signal_type in self._cached_indexes or not isinstance( self._cached_indexes[signal_type], index_cls ): with metrics.timer(metrics.names.indexer.download_index): self._cached_indexes[signal_type] = index_cls.load( bucket_name=self.index_bucket_name ) return self._cached_indexes[signal_type] @classmethod def _get_index_for_signal_type_matching( cls, signal_type: t.Type[SignalType], max_custom_threshold: int ): indexes = INDEX_MAPPING[signal_type] # disallow empty list assert indexes if len(indexes) == 1: # if we only have one option just return return indexes[0] indexes.sort(key=lambda i: i.get_index_max_distance()) for index_cls in indexes: if max_custom_threshold <= index_cls.get_index_max_distance(): return index_cls # if we don't have an index that supports max threshold # just return the one if the highest possible max distance return indexes[-1] def publish_match_message( self, content_id: str, content_hash: str, matches: t.List[IndexMatch[t.List[BaseIndexMetadata]]], sns_client: SNSClient, topic_arn: str, ): """ Creates banked signal objects and publishes one message for a list of matches to SNS. """ banked_signals = [] for match in matches: for metadata_obj in match.metadata: if metadata_obj.get_source() == THREAT_EXCHANGE_SOURCE_SHORT_CODE: metadata_obj = t.cast( ThreatExchangeIndicatorIndexMetadata, metadata_obj ) banked_signal = BankedSignal( str(metadata_obj.indicator_id), str(metadata_obj.privacy_group), str(metadata_obj.get_source()), ) for tag in metadata_obj.tags: banked_signal.add_classification(tag) banked_signals.append(banked_signal) elif metadata_obj.get_source() == BANKS_SOURCE_SHORT_CODE: metadata_obj = t.cast(BankedSignalIndexMetadata, metadata_obj) bank_member = self.banks_table.get_bank_member( bank_member_id=metadata_obj.bank_member_id ) banked_signal = BankedSignal( metadata_obj.bank_member_id, bank_member.bank_id, metadata_obj.get_source(), ) # TODO: This would do good with caching. bank = self.banks_table.get_bank(bank_id=bank_member.bank_id) for tag in set.union(bank_member.bank_member_tags, bank.bank_tags): banked_signal.add_classification(tag) banked_signals.append(banked_signal) match_message = MatchMessage( content_key=content_id, content_hash=content_hash, matching_banked_signals=banked_signals, ) sns_client.publish(TopicArn=topic_arn, Message=match_message.to_aws_json())