hasher-matcher-actioner/hmalib/lambdas/actions/action_evaluator.py (156 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import boto3
import json
import os
import typing as t
from dataclasses import dataclass
from functools import lru_cache
from threatexchange.content_type.photo import PhotoContent
from hmalib.common.logging import get_logger
from hmalib.common.classification_models import (
BankIDClassificationLabel,
BankSourceClassificationLabel,
BankedContentIDClassificationLabel,
ClassificationLabel,
SubmittedContentClassificationLabel,
WritebackTypes,
Label,
)
from hmalib.common.config import HMAConfig
from hmalib.common.configs.evaluator import (
ActionLabel,
ActionRule,
)
from hmalib.common.messages.action import ActionMessage
from hmalib.common.messages.match import BankedSignal, MatchMessage
from hmalib.common.messages.writeback import WritebackMessage
from hmalib.common.models.content import ContentObject
from mypy_boto3_sqs import SQSClient
from mypy_boto3_dynamodb.service_resource import Table, DynamoDBServiceResource
logger = get_logger(__name__)
@dataclass
class ActionEvaluatorConfig:
    """
    Typed bundle of the queue URLs and AWS handles used by this lambda,
    populated from environment variables.
    """

    # URL of the SQS queue that receives ActionMessages
    actions_queue_url: str
    # SQS client used for sending messages
    sqs_client: SQSClient
    # DynamoDB table handle used to look up submitted content
    dynamo_db_table: Table
    # URL of the SQS queue that receives WritebackMessages
    writeback_queue_url: str

    @classmethod
    @lru_cache(maxsize=None)
    def get(cls):
        """
        Build the config from the environment.

        Reads CONFIG_TABLE_NAME, DYNAMODB_TABLE, WRITEBACKS_QUEUE_URL and
        ACTIONS_QUEUE_URL (a missing variable raises KeyError), initializes
        HMAConfig, and creates the boto3 handles. The result is memoized by
        lru_cache, so repeated calls return the same instance.
        """
        config_table_name = os.environ["CONFIG_TABLE_NAME"]
        logger.info("Initializing configs using table name %s", config_table_name)

        dynamo_db_table_name = os.environ["DYNAMODB_TABLE"]
        logger.info(
            "Initializing dynamo table using table name %s",
            dynamo_db_table_name,
        )

        HMAConfig.initialize(config_table_name)

        dynamodb: DynamoDBServiceResource = boto3.resource("dynamodb")
        writeback_queue_url = os.environ["WRITEBACKS_QUEUE_URL"]
        actions_queue_url = os.environ["ACTIONS_QUEUE_URL"]
        sqs_client = boto3.client("sqs")
        table = dynamodb.Table(dynamo_db_table_name)

        return cls(
            actions_queue_url=actions_queue_url,
            sqs_client=sqs_client,
            dynamo_db_table=table,
            writeback_queue_url=writeback_queue_url,
        )
def lambda_handler(event, context):
    """
    Entry point invoked when one or more matches are found. If a single hash
    matches multiple datasets, this is still called only once.

    For every SQS record in the event, the embedded match message is decoded,
    the action rules are evaluated against it, and one ActionMessage is sent
    to the actions queue per resulting action label. A SawThisToo writeback
    message is then sent for the match message.
    """
    config = ActionEvaluatorConfig.get()

    for sqs_record in event["Records"]:
        # TODO research max # sqs records / lambda_handler invocation
        record_body = json.loads(sqs_record["body"])
        match_message = MatchMessage.from_aws_json(record_body["Message"])
        logger.info("Evaluating match_message: %s", match_message)

        action_rules = get_action_rules()
        logger.info("Evaluating against action_rules: %s", action_rules)

        # The submitted content's additional fields take part in rule evaluation.
        submitted_content = ContentObject.get_from_content_id(
            config.dynamo_db_table, match_message.content_key
        )

        rules_by_label = get_actions_to_take(
            match_message,
            action_rules,
            submitted_content.additional_fields,
        )

        for action_label, matched_rules in rules_by_label.items():
            action_message = ActionMessage.from_match_message_action_label_action_rules_and_additional_fields(
                match_message,
                action_label,
                matched_rules,
                list(submitted_content.additional_fields),
            )
            logger.info("Sending Action message: %s", action_message)
            config.sqs_client.send_message(
                QueueUrl=config.actions_queue_url,
                MessageBody=action_message.to_aws_json(),
            )

        # One writeback per match message, independent of how many actions fired.
        writeback_message = WritebackMessage.from_match_message_and_type(
            match_message, WritebackTypes.SawThisToo
        )
        writeback_message.send_to_queue(config.sqs_client, config.writeback_queue_url)

    return {"evaluation_completed": "true"}
def get_actions_to_take(
    match_message: MatchMessage,
    action_rules: t.List[ActionRule],
    additional_fields_on_content: t.Set[str],
) -> t.Dict[ActionLabel, t.List[ActionRule]]:
    """
    Returns action labels for each action rule that applies to a match message.

    Each additional field on the submitted content is wrapped in a
    SubmittedContentClassificationLabel. For every banked signal in the match
    message, those labels are combined with the signal's own classifications
    and every action rule is evaluated against the combined set.

    The result maps each triggered action label to the rules that triggered
    it. A rule is listed once per banked signal it applies to, so it can
    appear more than once under the same label. The mapping is passed through
    remove_superseded_actions before being returned.
    """
    action_label_to_action_rules: t.Dict[ActionLabel, t.List[ActionRule]] = {}

    content_classifications = {
        SubmittedContentClassificationLabel(field)
        for field in additional_fields_on_content
    }
    logger.info(
        "Adding SubmittedContentClassificationLabel(s): %s", content_classifications
    )

    for banked_signal in match_message.matching_banked_signals:
        # The combined label set depends only on the signal, not on the rule,
        # so build it once per signal instead of once per (signal, rule) pair.
        classifications = banked_signal.classifications.union(
            content_classifications
        )
        for action_rule in action_rules:
            if action_rule_applies_to_classifications(action_rule, classifications):
                action_label_to_action_rules.setdefault(
                    action_rule.action_label, []
                ).append(action_rule)

    return remove_superseded_actions(action_label_to_action_rules)
def get_action_rules() -> t.List[ActionRule]:
    """
    Fetch every ActionRule via ActionRule.get_all().

    The rest of this module reads must_have_labels, must_not_have_labels and
    action_label from each returned rule.

    TODO Research caching rules for a short bit of time (1 min? 5 min?) use @lru_cache to implement
    """
    all_rules = ActionRule.get_all()
    return all_rules
def action_rule_applies_to_classifications(
    action_rule: ActionRule, classifications: t.Set[Label]
) -> bool:
    """
    Decide whether an action rule fires for a set of classifications.

    The rule applies only when every one of its "must have" labels is present
    in the classifications and none of its "must not have" labels are.
    Returns True when the rule applies, otherwise False.
    """
    # A single missing required label rules the action out immediately.
    if not action_rule.must_have_labels.issubset(classifications):
        return False
    # All required labels are present; any forbidden label vetoes the rule.
    return action_rule.must_not_have_labels.isdisjoint(classifications)
def remove_superseded_actions(
    action_label_to_action_rules: t.Dict[ActionLabel, t.List[ActionRule]],
) -> t.Dict[ActionLabel, t.List[ActionRule]]:
    """
    TODO implement

    Intended to take the action labels (and the rules that produced them)
    generated for a match message and drop any label that is superseded by
    another one.

    Currently a placeholder: the mapping that was passed in is returned as-is.
    """
    # No supersession logic exists yet, so every label survives.
    surviving_actions = action_label_to_action_rules
    return surviving_actions
if __name__ == "__main__":
    # For basic debugging: push one hand-built match message through
    # lambda_handler, wrapped the way the handler expects its SQS records.
    HMAConfig.initialize(os.environ["CONFIG_TABLE_NAME"])
    action_rules = get_action_rules()

    # NOTE(review): the BankedContentID label carries the signal's bank_id value
    # and the BankID label carries an id that appears nowhere else in this
    # fixture -- confirm whether the sample data is intentional.
    debug_classifications = {
        BankedContentIDClassificationLabel(value="258601789084078"),
        ClassificationLabel(value="true_positive"),
        BankSourceClassificationLabel(value="te"),
        BankIDClassificationLabel(value="3534976909868947"),
    }
    debug_signal = BankedSignal(
        banked_content_id="3070359009741438",
        bank_id="258601789084078",
        bank_source="te",
        classifications=debug_classifications,
    )
    match_message = MatchMessage(
        content_key="m2",
        content_hash="361da9e6cf1b72f5cea0344e5bb6e70939f4c70328ace762529cac704297354a",
        matching_banked_signals=[debug_signal],
    )

    # The handler reads record["body"] as JSON and then decodes its "Message" field.
    record_body = json.dumps({"Message": match_message.to_aws_json()})
    event = {"Records": [{"body": record_body}]}
    lambda_handler(event, None)