# hasher-matcher-actioner/hmalib/lambdas/matcher.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import functools
import json
import os

import boto3
from mypy_boto3_dynamodb.service_resource import DynamoDBServiceResource, Table
from mypy_boto3_sns import SNSClient
from threatexchange.signal_type.md5 import VideoMD5Signal
from threatexchange.signal_type.pdq import PdqSignal

from hmalib import metrics
from hmalib.common.config import HMAConfig
from hmalib.common.logging import get_logger
from hmalib.common.models.bank import BanksTable
from hmalib.common.models.pipeline import PipelineHashRecord
from hmalib.matchers.matchers_base import Matcher

INDEXES_BUCKET_NAME = os.environ["INDEXES_BUCKET_NAME"]
BANKS_TABLE = os.environ["BANKS_TABLE"]
DYNAMODB_TABLE = os.environ["DYNAMODB_TABLE"]
HMA_CONFIG_TABLE = os.environ["HMA_CONFIG_TABLE"]
MATCHES_TOPIC_ARN = os.environ["MATCHES_TOPIC_ARN"]

HMAConfig.initialize(HMA_CONFIG_TABLE)


@functools.lru_cache(maxsize=None)
def get_dynamodb() -> DynamoDBServiceResource:
    # Cached so warm lambda invocations reuse one boto3 resource.
    return boto3.resource("dynamodb")


@functools.lru_cache(maxsize=None)
def get_sns_client() -> SNSClient:
    # Cached so warm lambda invocations reuse one SNS client.
    return boto3.client("sns")


_matcher = None


def get_matcher(banks_table: BanksTable) -> Matcher:
    """
    Lazily build, then reuse, the Matcher. Loading indexes from S3 is
    expensive, so warm lambda invocations share a single instance.
    """
    global _matcher
    if _matcher is None:
        _matcher = Matcher(
            index_bucket_name=INDEXES_BUCKET_NAME,
            supported_signal_types=[PdqSignal, VideoMD5Signal],
            banks_table=banks_table,
        )
    return _matcher


logger = get_logger(__name__)


def lambda_handler(event, context):
    """
    Listens to SQS events fired when a new hash is generated. Loads the index
    stored in an S3 bucket and looks for matches.

    When a match is found, publishes a notification to an SNS topic. Note that
    this is in contrast with the hasher and the indexer, which publish to SQS
    directly. Publishing to SQS implies there can be only one consumer; by
    publishing to SNS here in the matcher, we can plug any number of queues in
    behind it and profit! (See the fan-out sketch at the bottom of this file.)
    """
    table: Table = get_dynamodb().Table(DYNAMODB_TABLE)
    banks_table = BanksTable(get_dynamodb().Table(BANKS_TABLE))

    for sqs_record in event["Records"]:
        message = json.loads(sqs_record["body"])

        # Skip synthetic test events.
        if message.get("Event") == "TestEvent":
            logger.debug("Disregarding Test Event")
            continue
        if not PipelineHashRecord.could_be(message):
            logger.warning(
                "Could not deserialize message in matcher lambda. Message was %s",
                message,
            )
            continue

        hash_record = PipelineHashRecord.from_sqs_message(message)
        logger.info(
            "HashRecord for contentId: %s with contentHash: %s",
            hash_record.content_id,
            hash_record.content_hash,
        )
        matches = get_matcher(banks_table).match(
            hash_record.signal_type, hash_record.content_hash
        )
        logger.info("Found %d matches.", len(matches))

        # First, record each match against the content item.
        for match in matches:
            get_matcher(banks_table).write_match_record_for_result(
                table=table,
                signal_type=hash_record.signal_type,
                content_hash=hash_record.content_hash,
                content_id=hash_record.content_id,
                match=match,
            )

        # Then make sure each matched signal itself is persisted.
        for match in matches:
            get_matcher(banks_table).write_signal_if_not_found(
                table=table, signal_type=hash_record.signal_type, match=match
            )

        if matches:
            # Publish all of this record's matches together in one message.
            get_matcher(banks_table).publish_match_message(
                content_id=hash_record.content_id,
                content_hash=hash_record.content_hash,
                matches=matches,
                sns_client=get_sns_client(),
                topic_arn=MATCHES_TOPIC_ARN,
            )

    metrics.flush()
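

# A minimal sketch of the fan-out described in lambda_handler's docstring:
# because matches go to SNS rather than SQS, additional consumer queues can be
# subscribed behind MATCHES_TOPIC_ARN without touching this lambda. This
# helper is illustrative only and not part of the pipeline; the queue ARN is
# a hypothetical placeholder, and only the standard boto3 SNS `subscribe`
# call is used.
def _subscribe_consumer_queue_sketch(queue_arn: str) -> None:
    """Illustrative: attach one more SQS consumer to the matches topic."""
    get_sns_client().subscribe(
        TopicArn=MATCHES_TOPIC_ARN,
        Protocol="sqs",
        # Hypothetical ARN, e.g. "arn:aws:sqs:us-east-1:123456789012:my-queue".
        Endpoint=queue_arn,
    )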
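

# A local smoke-test sketch, assuming the environment variables read at import
# time are set and AWS credentials are available. The body mirrors the
# "TestEvent" guard at the top of the record loop, so the handler should log
# and return without loading any index.
if __name__ == "__main__":
    fake_event = {"Records": [{"body": json.dumps({"Event": "TestEvent"})}]}
    lambda_handler(fake_event, None)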