hasher-matcher-actioner/hmalib/lambdas/api/matches.py (408 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import boto3
import bottle
import functools
import dataclasses
from dataclasses import dataclass, asdict
from mypy_boto3_dynamodb.service_resource import Table
import typing as t
from enum import Enum
from urllib.error import HTTPError
from mypy_boto3_sqs.client import SQSClient
from threatexchange.descriptor import ThreatDescriptor
from threatexchange.signal_type.signal_base import SignalType
from threatexchange.signal_type.md5 import VideoMD5Signal
from threatexchange.signal_type.pdq import PdqSignal
from threatexchange.content_type.content_base import ContentType
from threatexchange.content_type.photo import PhotoContent
from threatexchange.content_type.video import VideoContent
from threatexchange.content_type.meta import (
get_signal_types_by_name,
get_content_types_by_name,
)
from hmalib.common.models.pipeline import MatchRecord
from hmalib.common.models.signal import (
ThreatExchangeSignalMetadata,
PendingThreatExchangeOpinionChange,
)
from hmalib.common.logging import get_logger
from hmalib.common.messages.match import BankedSignal
from hmalib.common.messages.writeback import WritebackMessage
from hmalib.indexers.metadata import (
BANKS_SOURCE_SHORT_CODE,
BankedSignalIndexMetadata,
)
from hmalib.lambdas.api.middleware import (
jsoninator,
JSONifiable,
DictParseable,
SubApp,
)
from hmalib.common.config import HMAConfig
from hmalib.common.models.bank import BankMember, BanksTable
from hmalib.matchers.matchers_base import Matcher
from hmalib.hashing.unified_hasher import UnifiedHasher
from hmalib.common.content_sources import URLContentSource
logger = get_logger(__name__)
@functools.lru_cache(maxsize=None)
def _get_sqs_client() -> SQSClient:
    """
    Create the boto3 SQS client on first use.

    The function takes no arguments, so lru_cache hands back the same client
    object on every later call.
    """
    sqs_client = boto3.client("sqs")
    return sqs_client
# Module-level Matcher instance, built lazily by _get_matcher().
_matcher = None


def _get_matcher(index_bucket_name: str, banks_table: BanksTable) -> Matcher:
    """
    Return the module-wide Matcher, constructing it on the first call.

    Only the arguments of the call that builds the instance are used; once a
    Matcher exists it is returned as-is, whatever arguments are passed.
    """
    global _matcher
    if _matcher is not None:
        return _matcher

    _matcher = Matcher(
        index_bucket_name=index_bucket_name,
        supported_signal_types=[PdqSignal, VideoMD5Signal],
        banks_table=banks_table,
    )
    return _matcher
# Module-level UnifiedHasher instance, built lazily by _get_hasher().
_hasher = None


def _get_hasher() -> UnifiedHasher:
    """
    Return the module-wide UnifiedHasher, constructing it on the first call.

    The hasher is configured for photo and video content with the PDQ and
    video MD5 signal types, and is given an empty output queue URL.
    """
    global _hasher
    if _hasher is not None:
        return _hasher

    _hasher = UnifiedHasher(
        supported_content_types=[PhotoContent, VideoContent],
        supported_signal_types=[PdqSignal, VideoMD5Signal],
        output_queue_url="",
    )
    return _hasher
@dataclass
class MatchSummary(JSONifiable):
    """
    Condensed view of one match: which content matched which signal, and when
    the record was last updated.
    """

    content_id: str
    signal_id: t.Union[str, int]
    signal_source: str
    # String timestamp; the route in this file fills it with isoformat().
    updated_at: str

    def to_json(self) -> t.Dict:
        """Return every field of this summary as a plain dict."""
        return dataclasses.asdict(self)
@dataclass
class MatchSummariesResponse(JSONifiable):
    """Response body wrapping a list of MatchSummary objects."""

    match_summaries: t.List[MatchSummary]

    def to_json(self) -> t.Dict:
        """Serialize each summary and nest the list under "match_summaries"."""
        serialized = [entry.to_json() for entry in self.match_summaries]
        return {"match_summaries": serialized}
@dataclass
class BankedSignalDetailsMetadata(JSONifiable):
    """Identifies the bank and bank member that a banked signal belongs to."""

    bank_member_id: str
    bank_id: str

    def to_json(self) -> t.Dict:
        """Return both identifiers as a plain dict."""
        as_dict = asdict(self)
        return as_dict
@dataclass
class ThreatExchangeSignalDetailsMetadata(JSONifiable):
    """
    Per-privacy-group details of a ThreatExchange signal, in the string form
    returned by the API.
    """

    privacy_group_id: str
    tags: t.List[str]
    # String value of an OpinionString member (see get_te_signal_details).
    opinion: str
    # String value of the signal's pending opinion change.
    pending_opinion_change: str

    def to_json(self) -> t.Dict:
        """Return every field as a plain dict."""
        as_dict = asdict(self)
        return as_dict
@dataclass
class MatchDetail(JSONifiable):
    """
    Full detail of one match between a piece of content and a signal.

    Note: te_signal_details should eventually be folded into
    banked_signal_details once ThreatExchange signals function the same as
    locally banked ones.
    """

    content_id: str
    content_hash: str
    signal_id: t.Union[str, int]
    signal_hash: str
    signal_source: str
    signal_type: str
    updated_at: str
    match_distance: t.Optional[int]
    te_signal_details: t.List[ThreatExchangeSignalDetailsMetadata]
    banked_signal_details: t.List[BankedSignalDetailsMetadata]

    def to_json(self) -> t.Dict:
        """
        Return all fields as a dict, with the two nested detail lists
        serialized through each element's own to_json().
        """
        payload = asdict(self)
        # Overwrite the asdict() rendering of the nested lists with the
        # elements' own JSON form.
        payload["te_signal_details"] = [
            entry.to_json() for entry in self.te_signal_details
        ]
        payload["banked_signal_details"] = [
            entry.to_json() for entry in self.banked_signal_details
        ]
        return payload
@dataclass
class MatchDetailsResponse(JSONifiable):
    """Response body wrapping a list of MatchDetail objects."""

    match_details: t.List[MatchDetail]

    def to_json(self) -> t.Dict:
        """Serialize each detail and nest the list under "match_details"."""
        serialized = [entry.to_json() for entry in self.match_details]
        return {"match_details": serialized}
@dataclass
class ChangeSignalOpinionResponse(JSONifiable):
    """Outcome of a request to change the opinion on a signal."""

    success: bool

    def to_json(self) -> t.Dict:
        """Expose the success flag under the "change_requested" key."""
        change_requested = self.success
        return {"change_requested": change_requested}
def get_match_details(
    datastore_table: Table, banks_table: BanksTable, content_id: str
) -> t.List[MatchDetail]:
    """
    Build a MatchDetail for every match record stored for `content_id`.

    Returns an empty list when `content_id` is falsy. For each record the
    ThreatExchange and bank details are both looked up; each lookup helper
    returns an empty list when the record's signal source is not its own.
    """
    if not content_id:
        return []

    details: t.List[MatchDetail] = []
    for record in MatchRecord.get_from_content_id(datastore_table, f"{content_id}"):
        details.append(
            MatchDetail(
                content_id=record.content_id,
                content_hash=record.content_hash,
                signal_id=record.signal_id,
                signal_hash=record.signal_hash,
                signal_source=record.signal_source,
                signal_type=record.signal_type.get_name(),
                updated_at=record.updated_at.isoformat(),
                # A record may carry no distance; keep that as None.
                match_distance=(
                    None
                    if record.match_distance is None
                    else int(record.match_distance)
                ),
                te_signal_details=get_te_signal_details(
                    datastore_table=datastore_table,
                    signal_id=record.signal_id,
                    signal_source=record.signal_source,
                ),
                banked_signal_details=get_banked_signal_details(
                    banks_table=banks_table,
                    signal_id=record.signal_id,
                    signal_source=record.signal_source,
                ),
            )
        )
    return details
def get_te_signal_details(
    datastore_table: Table,
    signal_id: str,
    signal_source: str,
) -> t.List[ThreatExchangeSignalDetailsMetadata]:
    """
    Return the ThreatExchange details stored for a signal, one entry per
    metadata object found for `signal_id`.

    Returns an empty list when either argument is falsy or when the signal
    comes from a local bank (BANKS_SOURCE_SHORT_CODE).

    Note: te_signal_details should eventually be folded into
    banked_signal_details once ThreatExchange signals function the same as
    locally banked ones.
    """
    if not (signal_id and signal_source):
        return []
    if signal_source == BANKS_SOURCE_SHORT_CODE:
        return []

    details: t.List[ThreatExchangeSignalDetailsMetadata] = []
    for metadata in ThreatExchangeSignalMetadata.get_from_signal(
        datastore_table, signal_id
    ):
        # Special tags are reported through `opinion`, not listed as tags.
        visible_tags = [
            tag for tag in metadata.tags if tag not in ThreatDescriptor.SPECIAL_TAGS
        ]
        details.append(
            ThreatExchangeSignalDetailsMetadata(
                privacy_group_id=metadata.privacy_group_id,
                tags=visible_tags,
                opinion=get_opinion_from_tags(metadata.tags).value,
                pending_opinion_change=metadata.pending_opinion_change.value,
            )
        )
    return details
def get_banked_signal_details(
    banks_table: BanksTable,
    signal_id: str,
    signal_source: str,
) -> t.List[BankedSignalDetailsMetadata]:
    """
    Return the bank and bank-member ids for a locally banked signal.

    Returns an empty list when either argument is falsy or when the signal
    source is anything other than BANKS_SOURCE_SHORT_CODE.
    """
    if not (signal_id and signal_source):
        return []
    if signal_source != BANKS_SOURCE_SHORT_CODE:
        return []

    details: t.List[BankedSignalDetailsMetadata] = []
    for member_signal in banks_table.get_bank_member_signal_from_id(signal_id):
        details.append(
            BankedSignalDetailsMetadata(
                bank_member_id=member_signal.bank_member_id,
                bank_id=member_signal.bank_id,
            )
        )
    return details
class OpinionString(Enum):
    """
    Display strings for the opinion derived from a signal's tags.

    The member values are the strings returned to API clients.
    """

    TRUE_POSITIVE = "True Positive"
    FALSE_POSITIVE = "False Positive"
    DISPUTED = "Disputed"
    # Returned when none of the opinion tags is present.
    UNKNOWN = "Unknown"
def get_opinion_from_tags(tags: t.List[str]) -> OpinionString:
    """
    Map a signal's tags to an OpinionString.

    The tags are checked in a fixed precedence (true positive, then false
    positive, then disputed) and the first one present wins. UNKNOWN is
    returned when none of them is present.
    """
    # see python-threatexchange descriptor.py for origins
    precedence = (
        (ThreatDescriptor.TRUE_POSITIVE, OpinionString.TRUE_POSITIVE),
        (ThreatDescriptor.FALSE_POSITIVE, OpinionString.FALSE_POSITIVE),
        (ThreatDescriptor.DISPUTED, OpinionString.DISPUTED),
    )
    for opinion_tag, opinion in precedence:
        if opinion_tag in tags:
            return opinion
    return OpinionString.UNKNOWN
@dataclass
class MatchesForHashRequest(DictParseable):
    """Request asking for the matches of a single signal value (hash)."""

    signal_value: str
    signal_type: t.Type[SignalType]

    @classmethod
    def from_dict(cls, d):
        """
        Build the request from a dict of string values.

        Fields missing from `d` are set to None. The signal type arrives as a
        name and is swapped for the SignalType class registered under that
        name; an unregistered name raises KeyError.
        """
        raw_fields = {}
        for field in dataclasses.fields(cls):
            raw_fields[field.name] = d.get(field.name, None)
        request = cls(**raw_fields)  # type: ignore # tiny hack to convert string fields
        request.signal_type = get_signal_types_by_name()[request.signal_type]  # type: ignore # tiny hack to get SignalType from string in request
        return request
@dataclass
class MatchesForMediaRequest(DictParseable):
    """Request asking for the matches of the media found at a URL."""

    content_url: str
    content_type: t.Type[ContentType]

    @classmethod
    def from_dict(cls, d):
        """
        Build the request from a dict of string values.

        Fields missing from `d` are set to None. The content type arrives as
        a name and is swapped for the ContentType class registered under that
        name; an unregistered name raises KeyError.
        """
        raw_fields = {}
        for field in dataclasses.fields(cls):
            raw_fields[field.name] = d.get(field.name, None)
        request = cls(**raw_fields)  # type: ignore # tiny hack to convert string fields
        request.content_type = get_content_types_by_name()[request.content_type]  # type: ignore # tiny hack to get ContentType from string in request
        return request
@dataclass
class MatchesForHash(JSONifiable):
    """One index match for a queried hash: the distance and the matched signal."""

    match_distance: int

    # TODO: Once ThreatExchange data flows into Banks, we can use BankMember
    # alone.
    matched_signal: t.Union[
        ThreatExchangeSignalMetadata, BankMember
    ]  # or matches signal from other sources

    # Keys stripped from the matched signal's JSON before it is returned.
    UNSUPPORTED_FIELDS = ["updated_at", "pending_opinion_change"]

    def to_json(self) -> t.Dict:
        """
        Return the distance together with the matched signal's JSON, minus
        the keys listed in UNSUPPORTED_FIELDS.
        """
        signal_json = self.matched_signal.to_json()
        return {
            "match_distance": self.match_distance,
            "matched_signal": self._remove_unsupported_fields(signal_json),
        }

    @classmethod
    def _remove_unsupported_fields(cls, matched_signal: t.Dict) -> t.Dict:
        """
        Drop the UNSUPPORTED_FIELDS keys from `matched_signal` in place and
        return that same dict. Keys that are absent are ignored.

        ThreatExchangeSignalMetadata is used to store metadata in dynamodb
        and handle opinion changes on said signal. However the request this
        object responds to only handles directly accessing the index. Because
        of this not all fields of the object are relevant or accurate.
        """
        for unsupported in cls.UNSUPPORTED_FIELDS:
            matched_signal.pop(unsupported, None)
        return matched_signal
@dataclass
class MatchesForHashResponse(JSONifiable):
    """Response body listing every match found for one queried hash."""

    matches: t.List[MatchesForHash]

    def to_json(self) -> t.Dict:
        """Serialize each match and nest the list under "matches"."""
        serialized = [entry.to_json() for entry in self.matches]
        return {"matches": serialized}
@dataclass
class MatchesForMediaResponse(JSONifiable):
    """
    Response body listing the matches found for each signal hashed from one
    piece of media.
    """

    # signal type name -> signal value (hash) -> matches for that hash
    # example: { "pdq" : { "<hash>" : [<matches?>] } }
    signal_to_matches: t.Dict[str, t.Dict[str, t.List[MatchesForHash]]]

    def to_json(self) -> t.Dict:
        """
        Return {"signal_to_matches": {<signal type>: {<hash>: [<match>, ...]}}}.

        Every signal type and every hash in `signal_to_matches` is kept, and
        an empty mapping still yields the "signal_to_matches" key (with an
        empty dict) so the response shape is stable.
        """
        # Each `for` clause must live inside the dict it populates: the
        # outer dict has the single constant key, the middle one is keyed by
        # signal type, the inner one by hash. Attaching a loop to a dict
        # with a constant key makes every iteration overwrite the previous.
        return {
            "signal_to_matches": {
                signal_type: {
                    signal_val: [match.to_json() for match in matches]
                    for (signal_val, matches) in hash_to_match.items()
                }
                for (signal_type, hash_to_match) in self.signal_to_matches.items()
            }
        }
@dataclass
class MediaFetchError(JSONifiable):
    """Error body returned when the media behind a URL could not be fetched."""

    def to_json(self) -> t.Dict:
        """Return the fixed error message under the "message" key."""
        message = "Failed to fetch media from provided url"
        return {"message": message}
def get_matches_api(
    datastore_table: Table,
    hma_config_table: str,
    indexes_bucket_name: str,
    writeback_queue_url: str,
    bank_table: Table,
) -> bottle.Bottle:
    """
    A closure that includes all dependencies that MUST be provided by the root
    API that this API plugs into. Declare dependencies here, but initialize in
    the root API alone.

    Args:
        datastore_table: table queried for match records and ThreatExchange
            signal metadata.
        hma_config_table: name handed to HMAConfig.initialize().
        indexes_bucket_name: bucket name handed to the Matcher as its index
            bucket.
        writeback_queue_url: URL of the queue that opinion-change writeback
            messages are sent to.
        bank_table: table wrapped in a BanksTable for bank lookups.

    Returns:
        The SubApp with all of the routes below registered on it.
    """
    # A prefix to all routes must be provided by the api_root app
    # The documentation below expects prefix to be '/matches/'
    matches_api = SubApp()
    HMAConfig.initialize(hma_config_table)
    banks_table = BanksTable(table=bank_table)

    @matches_api.get("/", apply=[jsoninator])
    def matches() -> MatchSummariesResponse:
        """
        Return all, or a filtered list of matches based on query params.

        `content_q` takes precedence over `signal_q`; `signal_source` is only
        used together with `signal_q`. With neither filter, one page of
        recent items is returned.
        """
        # ToDo refactor to use `jsoninator(<requestObj>, from_query=True)`
        signal_q = bottle.request.query.signal_q or None  # type: ignore
        signal_source = bottle.request.query.signal_source or None  # type: ignore
        content_q = bottle.request.query.content_q or None  # type: ignore

        if content_q:
            records = MatchRecord.get_from_content_id(datastore_table, content_q)
        elif signal_q:
            records = MatchRecord.get_from_signal(
                datastore_table, signal_q, signal_source or ""
            )
        else:
            # TODO: Support pagination after implementing in UI.
            records = MatchRecord.get_recent_items_page(datastore_table).items

        return MatchSummariesResponse(
            match_summaries=[
                MatchSummary(
                    content_id=record.content_id,
                    signal_id=record.signal_id,
                    signal_source=record.signal_source,
                    updated_at=record.updated_at.isoformat(),
                )
                for record in records
            ]
        )

    @matches_api.get("/match/", apply=[jsoninator])
    def match_details() -> MatchDetailsResponse:
        """
        Return the match details for a given content id.

        An absent or empty `content_id` query param yields an empty list.
        """
        results = []
        # ToDo refactor to use `jsoninator(<requestObj>, from_query=True)`
        if content_id := bottle.request.query.content_id or None:  # type: ignore
            results = get_match_details(
                datastore_table=datastore_table,
                banks_table=banks_table,
                content_id=content_id,
            )
        return MatchDetailsResponse(match_details=results)

    @matches_api.post("/request-signal-opinion-change/", apply=[jsoninator])
    def request_signal_opinion_change() -> ChangeSignalOpinionResponse:
        """
        Request a change to the opinion for a signal in a given privacy_group.

        Responds with a failure (change_requested false) when a query param
        is missing, when `opinion_change` is not a recognised value, when the
        signal has no stored metadata for the privacy group, or when the
        stored metadata could not be updated.
        """
        # ToDo refactor to use `jsoninator(<requestObj>, from_query=True)`
        signal_id = bottle.request.query.signal_id or None  # type: ignore
        signal_source = bottle.request.query.signal_source or None  # type: ignore
        privacy_group_id = bottle.request.query.privacy_group_id or None  # type: ignore
        opinion_change = bottle.request.query.opinion_change or None  # type: ignore

        if not (signal_id and signal_source and privacy_group_id and opinion_change):
            return ChangeSignalOpinionResponse(False)

        signal_id = str(signal_id)
        try:
            pending_opinion_change = PendingThreatExchangeOpinionChange(opinion_change)
        except ValueError:
            # NOTE(review): assumes PendingThreatExchangeOpinionChange is an
            # Enum, whose by-value lookup raises ValueError for unknown
            # values. Treat a bad value like a missing one instead of
            # letting the request fail with an unhandled error.
            logger.error(f"Unrecognized opinion_change value: {opinion_change!r}")
            return ChangeSignalOpinionResponse(False)

        writeback_message = WritebackMessage.from_banked_signal_and_opinion_change(
            BankedSignal(signal_id, privacy_group_id, signal_source),
            pending_opinion_change,
        )
        writeback_message.send_to_queue(_get_sqs_client(), writeback_queue_url)
        logger.info(
            f"Opinion change enqueued for {signal_source}:{signal_id} in {privacy_group_id} change={opinion_change}"
        )

        signal = ThreatExchangeSignalMetadata.get_from_signal_and_privacy_group(
            datastore_table, signal_id=signal_id, privacy_group_id=privacy_group_id
        )
        if not signal:
            # There is nothing to update; report failure rather than
            # dereferencing the missing signal below.
            logger.error("Signal not found.")
            return ChangeSignalOpinionResponse(False)

        signal = t.cast(ThreatExchangeSignalMetadata, signal)
        signal.pending_opinion_change = pending_opinion_change
        success = signal.update_pending_opinion_change_in_table_if_exists(
            datastore_table
        )
        if not success:
            logger.error(f"Attempting to update {signal} in db failed")
        return ChangeSignalOpinionResponse(success)

    def _matches_for_hash(
        signal_type: t.Type[SignalType], signal_value: str
    ) -> t.List[MatchesForHash]:
        """
        Query the index for `signal_value` and return one MatchesForHash per
        matched signal: all ThreatExchange-sourced ones first, then all
        bank-sourced ones.
        """
        matches = _get_matcher(indexes_bucket_name, banks_table=banks_table).match(
            signal_type, signal_value
        )

        match_objects: t.List[MatchesForHash] = []

        # First get all threatexchange objects
        for match in matches:
            match_objects.extend(
                [
                    MatchesForHash(
                        match_distance=int(match.distance),
                        matched_signal=signal_metadata,
                    )
                    for signal_metadata in Matcher.get_te_metadata_objects_from_match(
                        signal_type, match
                    )
                ]
            )

        # now get all bank objects
        for match in matches:
            for metadata_obj in match.metadata:
                if metadata_obj.get_source() != BANKS_SOURCE_SHORT_CODE:
                    continue
                banked_metadata = t.cast(BankedSignalIndexMetadata, metadata_obj)
                match_objects.append(
                    MatchesForHash(
                        match_distance=int(match.distance),
                        matched_signal=banks_table.get_bank_member(
                            banked_metadata.bank_member_id
                        ),
                    )
                )

        return match_objects

    @matches_api.get(
        "/for-hash/", apply=[jsoninator(MatchesForHashRequest, from_query=True)]
    )
    def for_hash(request: MatchesForHashRequest) -> MatchesForHashResponse:
        """
        For a given hash/signal check the index(es) for matches and return the details.

        This does not change system state, metadata returned will not be written any tables
        unlike when matches are found for submissions.
        """
        return MatchesForHashResponse(
            matches=_matches_for_hash(request.signal_type, request.signal_value)
        )

    @matches_api.post("/for-media/", apply=[jsoninator(MatchesForMediaRequest)])
    def for_media(
        request: MatchesForMediaRequest,
    ) -> t.Union[MatchesForMediaResponse, MediaFetchError]:
        """
        For a given piece of media hash it, check the index(es) for matches, and return the details.

        This does not change system state, metadata returned will not be written any tables
        unlike when matches are found for submissions.

        Responds with status 400 and a MediaFetchError when the media cannot
        be fetched from `content_url`.
        """
        try:
            bytes_: bytes = URLContentSource().get_bytes(request.content_url)
        except Exception as e:
            # Any fetch failure is reported to the client as a bad request;
            # log the cause so it is not lost.
            logger.error(f"Failed to fetch media from provided url: {e!r}")
            bottle.response.status = 400
            return MediaFetchError()

        signal_to_matches: t.Dict[str, t.Dict[str, t.List[MatchesForHash]]] = {}
        for signal in _get_hasher().get_hashes(request.content_type, bytes_):
            # Accumulate per signal type so that several hashes of the same
            # type do not overwrite one another.
            signal_to_matches.setdefault(signal.signal_type.get_name(), {})[
                signal.signal_value
            ] = _matches_for_hash(
                signal_type=signal.signal_type, signal_value=signal.signal_value
            )

        return MatchesForMediaResponse(signal_to_matches=signal_to_matches)

    return matches_api