# hasher-matcher-actioner/hmalib/common/models/pipeline.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import datetime
from decimal import Decimal
import typing as t
from dataclasses import dataclass, field

from mypy_boto3_dynamodb.service_resource import Table
from boto3.dynamodb.conditions import Key

from threatexchange.content_type.meta import get_signal_types_by_name
from threatexchange.signal_type.signal_base import SignalType

from hmalib.common.timebucketizer import CSViable
from hmalib.common.models.models_base import (
    DynamoDBItem,
    DynamoDBCursorKey,
    PaginatedResponse,
)

"""
Data transfer object classes to be used with dynamodbstore

Classes in this module should implement methods `to_dynamodb_item(self)` and
`to_sqs_message(self)`
"""


@dataclass
class PipelineRecordBase(DynamoDBItem):
    """
    Base Class for records of pieces of content going through the
    hashing/matching pipeline.

    Subclasses must implement to_dynamodb_item(), to_sqs_message() and
    get_recent_items_page().
    """

    content_id: str

    signal_type: t.Type[SignalType]
    content_hash: str

    updated_at: datetime.datetime

    def to_dynamodb_item(self) -> dict:
        raise NotImplementedError

    def to_sqs_message(self) -> dict:
        raise NotImplementedError

    @classmethod
    def get_recent_items_page(
        cls, table: Table, exclusive_start_key: t.Optional[DynamoDBCursorKey] = None
    ) -> PaginatedResponse:
        """
        Get a paginated list of recent items. Subsequent calls must use
        `return_value.last_evaluated_key` as the exclusive_start_key.
        """
        raise NotImplementedError


@dataclass
class PipelineRecordDefaultsBase:
    """
    Hash and match records may have signal_type specific attributes that are
    not universal. eg. PDQ hashes have quality and PDQ matches have distance
    while MD5 has neither.

    Assuming such signal_type specific attributes will not be indexed, we are
    choosing to put them into a bag of variables. See
    PipelineRecordBase.[de]serialize_signal_specific_attributes() to understand
    storage.

    Ideally, this would be an attribute with defaults, but that would make
    inheritance complicated because default_values would precede non-default
    values in the sub class.
    """

    signal_specific_attributes: t.Dict[str, t.Union[int, float, str]] = field(
        default_factory=dict
    )

    def serialize_signal_specific_attributes(self) -> dict:
        """
        Converts signal_specific_attributes into a dict. Uses the signal_type
        as a prefix. So for PDQ hash records,
        `item.signal_specific_attributes.quality` will become
        `item.pdq_quality`.

        Storing as top-level item attributes allows indexing if we need it
        later. You can't do that with nested elements.
        """
        # Note on Typing: PipelineRecordDefaultsBase is meant to be used with
        # PipelineRecordBase. So it will have access to all fields from
        # PipelineRecordBase. This is (impossible?) to express using mypy. So
        # ignore self.signal_type
        return {
            f"{self.signal_type.get_name()}_{key}": value  # type:ignore
            for key, value in self.signal_specific_attributes.items()
        }

    @staticmethod
    def _signal_specific_attribute_remove_prefix(prefix: str, k: str) -> str:
        return k[len(prefix) :]

    @classmethod
    def deserialize_signal_specific_attributes(
        cls, d: t.Dict[str, t.Any]
    ) -> t.Dict[str, t.Union[int, float, str]]:
        """
        Reverses serialize_signal_specific_attributes.
        """
        signal_type = d["SignalType"]
        signal_type_prefix = f"{signal_type}_"

        return {
            cls._signal_specific_attribute_remove_prefix(signal_type_prefix, key): value
            for key, value in d.items()
            if key.startswith(signal_type_prefix)
        }


@dataclass
class PipelineHashRecord(PipelineRecordDefaultsBase, PipelineRecordBase):
    """
    Successful execution at the hasher produces this record.
    """

    def to_dynamodb_item(self) -> dict:
        top_level_overrides = self.serialize_signal_specific_attributes()
        return dict(
            **top_level_overrides,
            **{
                "PK": self.get_dynamodb_content_key(self.content_id),
                "SK": self.get_dynamodb_type_key(self.signal_type.get_name()),
                "ContentHash": self.content_hash,
                "SignalType": self.signal_type.get_name(),
                "GSI2-PK": self.get_dynamodb_type_key(self.__class__.__name__),
                "UpdatedAt": self.updated_at.isoformat(),
            },
        )

    def to_legacy_sqs_message(self) -> dict:
        """
        Prior to supporting MD5, the hash message was simplistic and did not
        support all fields in the PipelineHashRecord. This is inconsistent with
        almost all other message models. We can remove this once pdq_hasher and
        pdq_matcher are removed.
        """
        return {
            "hash": self.content_hash,
            "type": self.signal_type.get_name(),
            "key": self.content_id,
        }

    def to_sqs_message(self) -> dict:
        return {
            "ContentId": self.content_id,
            "SignalType": self.signal_type.get_name(),
            "ContentHash": self.content_hash,
            "SignalSpecificAttributes": self.signal_specific_attributes,
            "UpdatedAt": self.updated_at.isoformat(),
        }

    @classmethod
    def from_sqs_message(cls, d: dict) -> "PipelineHashRecord":
        return cls(
            content_id=d["ContentId"],
            signal_type=get_signal_types_by_name()[d["SignalType"]],
            content_hash=d["ContentHash"],
            signal_specific_attributes=d["SignalSpecificAttributes"],
            updated_at=datetime.datetime.fromisoformat(d["UpdatedAt"]),
        )

    @classmethod
    def could_be(cls, d: dict) -> bool:
        """
        Return True if this dict can be converted to a PipelineHashRecord
        """
        return "ContentId" in d and "SignalType" in d and "ContentHash" in d

    @classmethod
    def get_from_content_id(
        cls,
        table: Table,
        content_id: str,
        signal_type: t.Optional[t.Type[SignalType]] = None,
    ) -> t.List["PipelineHashRecord"]:
        """
        Returns all available PipelineHashRecords for a content_id. If
        signal_type is provided, restricts results to that signal_type's
        record.
        """
        expected_pk = cls.get_dynamodb_content_key(content_id)

        if signal_type is None:
            condition_expression = Key("PK").eq(expected_pk) & Key("SK").begins_with(
                DynamoDBItem.TYPE_PREFIX
            )
        else:
            condition_expression = Key("PK").eq(expected_pk) & Key("SK").eq(
                DynamoDBItem.get_dynamodb_type_key(signal_type.get_name())
            )

        return cls._result_items_to_records(
            table.query(
                KeyConditionExpression=condition_expression,
            ).get("Items", [])
        )

    @classmethod
    def get_recent_items_page(
        cls, table: Table, exclusive_start_key: t.Optional[DynamoDBCursorKey] = None
    ) -> PaginatedResponse["PipelineHashRecord"]:
        """
        Get a paginated list of recent hash records. Subsequent calls must use
        `return_value.last_evaluated_key`.
        """
        # Evidently, https://github.com/boto/boto3/issues/2813 boto is able to
        # distinguish fun(Parameter=None) from fun(). So we can't rely on
        # exclusive_start_key's optionality; instead, only add
        # ExclusiveStartKey to the kwargs when we actually have one.
        query_kwargs: t.Dict[str, t.Any] = dict(
            IndexName="GSI-2",
            ScanIndexForward=False,
            Limit=100,
            KeyConditionExpression=Key("GSI2-PK").eq(
                DynamoDBItem.get_dynamodb_type_key(cls.__name__)
            ),
        )
        if exclusive_start_key:
            query_kwargs["ExclusiveStartKey"] = exclusive_start_key

        result = table.query(**query_kwargs)

        return PaginatedResponse(
            t.cast(DynamoDBCursorKey, result.get("LastEvaluatedKey", None)),
            cls._result_items_to_records(result["Items"]),
        )

    @classmethod
    def _result_items_to_records(
        cls,
        items: t.List[t.Dict],
    ) -> t.List["PipelineHashRecord"]:
        """
        Convert raw DynamoDB items into PipelineHashRecords.
        """
        return [
            PipelineHashRecord(
                content_id=item["PK"][len(cls.CONTENT_KEY_PREFIX) :],
                signal_type=get_signal_types_by_name()[item["SignalType"]],
                content_hash=item["ContentHash"],
                updated_at=datetime.datetime.fromisoformat(item["UpdatedAt"]),
                signal_specific_attributes=cls.deserialize_signal_specific_attributes(
                    item
                ),
            )
            for item in items
        ]


@dataclass
class _MatchRecord(PipelineRecordBase):
    """
    Successful execution at the matcher produces this record.
    """

    signal_id: str
    signal_source: str
    signal_hash: str
    match_distance: t.Optional[int] = None


@dataclass
class MatchRecord(PipelineRecordDefaultsBase, _MatchRecord):
    """
    Weird, innit? You can't introduce non-default fields after default fields.
    All default fields in PipelineRecordBase are actually in
    PipelineRecordDefaultsBase and this complex inheritance chain allows you to
    create an MRO that is legal.

    H/T: https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
    """

    def to_dynamodb_item(self) -> dict:
        top_level_overrides = self.serialize_signal_specific_attributes()
        return dict(
            **top_level_overrides,
            **{
                "PK": self.get_dynamodb_content_key(self.content_id),
                "SK": self.get_dynamodb_signal_key(self.signal_source, self.signal_id),
                "ContentHash": self.content_hash,
                "UpdatedAt": self.updated_at.isoformat(),
                "SignalHash": self.signal_hash,
                "SignalSource": self.signal_source,
                "SignalType": self.signal_type.get_name(),
                "GSI1-PK": self.get_dynamodb_signal_key(
                    self.signal_source, self.signal_id
                ),
                "GSI1-SK": self.get_dynamodb_content_key(self.content_id),
                "HashType": self.signal_type.get_name(),
                "GSI2-PK": self.get_dynamodb_type_key(self.__class__.__name__),
                "MatchDistance": self.match_distance,
            },
        )

    def to_sqs_message(self) -> dict:
        # TODO add method for when matches are added to a sqs
        raise NotImplementedError

    @classmethod
    def get_from_content_id(
        cls, table: Table, content_id: str
    ) -> t.List["MatchRecord"]:
        """
        Return all matches for a content_id.
        """
        content_key = DynamoDBItem.get_dynamodb_content_key(content_id)
        source_prefix = DynamoDBItem.SIGNAL_KEY_PREFIX

        return cls._result_items_to_records(
            table.query(
                KeyConditionExpression=Key("PK").eq(content_key)
                & Key("SK").begins_with(source_prefix),
            ).get("Items", [])
        )

    @classmethod
    def get_from_signal(
        cls, table: Table, signal_id: t.Union[str, int], signal_source: str
    ) -> t.List["MatchRecord"]:
        """
        Return all matches for a signal. Needs source and id to uniquely
        identify a signal.
        """
        signal_key = DynamoDBItem.get_dynamodb_signal_key(signal_source, signal_id)

        return cls._result_items_to_records(
            table.query(
                IndexName="GSI-1",
                KeyConditionExpression=Key("GSI1-PK").eq(signal_key),
            ).get("Items", [])
        )

    @classmethod
    def get_recent_items_page(
        cls, table: Table, exclusive_start_key: t.Optional[DynamoDBCursorKey] = None
    ) -> PaginatedResponse["MatchRecord"]:
        """
        Get a paginated list of recent match records. Subsequent calls must use
        `return_value.last_evaluated_key`.
        """
        # Evidently, https://github.com/boto/boto3/issues/2813 boto is able to
        # distinguish fun(Parameter=None) from fun(). So we can't rely on
        # exclusive_start_key's optionality; instead, only add
        # ExclusiveStartKey to the kwargs when we actually have one.
        query_kwargs: t.Dict[str, t.Any] = dict(
            IndexName="GSI-2",
            Limit=100,
            ScanIndexForward=False,
            KeyConditionExpression=Key("GSI2-PK").eq(
                DynamoDBItem.get_dynamodb_type_key(cls.__name__)
            ),
        )
        if exclusive_start_key:
            query_kwargs["ExclusiveStartKey"] = exclusive_start_key

        result = table.query(**query_kwargs)

        return PaginatedResponse(
            t.cast(DynamoDBCursorKey, result.get("LastEvaluatedKey", None)),
            cls._result_items_to_records(result["Items"]),
        )

    @classmethod
    def _result_items_to_records(
        cls,
        items: t.List[t.Dict],
    ) -> t.List["MatchRecord"]:
        """
        Convert raw DynamoDB items into MatchRecords.
        """
        return [
            MatchRecord(
                content_id=cls.remove_content_key_prefix(item["PK"]),
                content_hash=item["ContentHash"],
                updated_at=datetime.datetime.fromisoformat(item["UpdatedAt"]),
                signal_type=get_signal_types_by_name()[item["SignalType"]],
                signal_id=cls.remove_signal_key_prefix(
                    item["SK"], item["SignalSource"]
                ),
                signal_source=item["SignalSource"],
                signal_hash=item["SignalHash"],
                signal_specific_attributes=cls.deserialize_signal_specific_attributes(
                    item
                ),
                match_distance=item.get("MatchDistance"),
            )
            for item in items
        ]


@dataclass(eq=True)
class HashRecord(CSViable):
    """
    We are getting these records, with content_hashes and content_ids from the
    hashing process with intent to build a PDQIndex.
    """

    content_hash: str
    content_id: str

    def to_csv(self) -> t.List[t.Union[str, int]]:
        return [self.content_hash, self.content_id]

    @classmethod
    def from_csv(cls: t.Type["HashRecord"], value: t.List[str]) -> "HashRecord":
        return HashRecord(value[0], value[1])