hasher-matcher-actioner/hmalib/writebacker/writebacker_base.py (244 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import typing as t
import os
from functools import lru_cache
from dataclasses import dataclass
from hmalib.common.logging import get_logger
from hmalib.common.classification_models import WritebackTypes
from hmalib.common.messages.match import BankedSignal
from hmalib.common.messages.writeback import WritebackMessage
from hmalib.common.configs.fetcher import ThreatExchangeConfig
from hmalib.common.mocks import MockedThreatExchangeAPI
from hmalib.aws_secrets import AWSSecrets
from threatexchange.api import ThreatExchangeAPI
# Module-level logger, named after this module.
logger = get_logger(__name__)
# Value sent as the "tags" field when a new ThreatDescriptor is uploaded from
# this module.
TE_UPLOAD_TAG = "uploaded_by_hma"
class Writebacker:
    """
    Base class for writing back to an HMA data source (eg. ThreatExchange).

    Every source that enables writebacks should have an implementation of
    this class (eg. ThreatExchangeWritebacker) and, optionally, sub
    implementations for each kind of writeback
    (eg. ThreatExchangeFalsePositiveWritebacker).

    Direct subclasses are discovered by `get_writebacker_for_source`; the sub
    implementations must be listed in the mapping returned by
    `writeback_options` so that `perform_writeback` can find them.
    """

    @property
    def source(self) -> str:
        """
        The source that this writebacker corresponds to (eg. "te")
        """
        raise NotImplementedError

    @staticmethod
    def writeback_options() -> t.Dict[
        WritebackTypes.WritebackType, t.Type["Writebacker"]
    ]:
        """
        For a given source that performs writebacks, specify which types of
        writebacks can be taken, as a mapping from writeback type to the
        writebacker class that performs it. The key should be the same
        WritebackType that is passed to the writebacker.
        """
        raise NotImplementedError

    @classmethod
    @lru_cache(maxsize=None)
    def get_writebacker_for_source(cls, source: str) -> t.Optional["Writebacker"]:
        """
        Return a new instance of the direct subclass whose `source` equals
        `source`, or None when no direct subclass claims it.

        Raises ValueError when invoked on anything other than the class
        named "Writebacker". Results are memoized per (cls, source).
        """
        if cls.__name__ != "Writebacker":
            raise ValueError(
                "get_writebacker_for_source can only be called from the Writebacker class directly. eg Writebacker().get_writebacker_for_source"
            )

        # Every direct subclass is instantiated once to read its source; when
        # two subclasses share a source, the later one wins.
        writebacker_cls_by_source = {}
        for candidate_cls in cls.__subclasses__():
            writebacker_cls_by_source[candidate_cls().source] = candidate_cls

        matched_cls = writebacker_cls_by_source.get(source)
        if matched_cls is None:
            return None
        return matched_cls()

    def writeback_is_enabled(self, writeback_signal: BankedSignal) -> bool:
        """
        Users can switch writebacks on/off either globally for individual
        sources, or based on the matched signal.
        """
        raise NotImplementedError

    @property
    def writeback_type(self) -> WritebackTypes.WritebackType:
        """
        The writeback label for when this action should be performed
        (eg WritebackType.SawThisToo)
        """
        raise NotImplementedError

    def _writeback_impl(self, writeback_signal: BankedSignal) -> t.List[str]:
        """
        Perform the writeback for one banked signal and return log lines
        describing what happened. Implemented by subclasses.
        """
        raise NotImplementedError

    def perform_writeback(self, writeback_message: WritebackMessage) -> t.List[str]:
        """
        Perform the writeback requested by `writeback_message` for every
        banked signal that belongs to this writebacker's source.

        Returns one string per handled signal (that signal's log lines joined
        with newlines). When no writebacker is registered for the requested
        type, a single error string is logged and returned instead.
        """
        requested_type = writeback_message.writeback_type

        if requested_type not in self.writeback_options():
            not_found_message = (
                "Could not find writebacker for source "
                + self.source
                + " that can perform writeback "
                + requested_type.value
            )
            logger.error(not_found_message)
            return [not_found_message]

        delegate = self.writeback_options()[requested_type]()
        outcomes = []
        for signal in writeback_message.banked_signals:
            # Skip matches that come from other sources.
            if signal.bank_source != self.source:
                continue

            if delegate.writeback_is_enabled(signal):
                log_lines = delegate._writeback_impl(signal)
            else:
                log_lines = [
                    "No writeback performed for banked content id "
                    + signal.banked_content_id
                    + " becuase writebacks were disabled"
                ]

            # Lines that mention "Error" go to the error log, the rest to info.
            for line in log_lines:
                emit = logger.error if "Error" in line else logger.info
                emit(line)

            outcomes.append("\n".join(log_lines))
        return outcomes
@dataclass
class ThreatExchangeWritebacker(Writebacker):
    """
    Writebacker parent object for all writebacks to ThreatExchange.

    Provides what the ThreatExchange writebackers have in common: the source
    name, the mapping of writeback types to writebacker classes, the
    per-privacy-group on/off check, access to the ThreatExchange API client
    and a helper for locating the descriptor owned by this app.
    """

    source = "te"

    @staticmethod
    @lru_cache(maxsize=None)
    def writeback_options() -> t.Dict[
        WritebackTypes.WritebackType, t.Type["Writebacker"]
    ]:
        """
        Map each writeback type supported for ThreatExchange to the
        writebacker class that performs it.
        """
        return {
            WritebackTypes.FalsePositive: ThreatExchangeFalsePositiveWritebacker,
            WritebackTypes.TruePositive: ThreatExchangeTruePositiveWritebacker,
            WritebackTypes.SawThisToo: ThreatExchangeSawThisTooWritebacker,
            WritebackTypes.RemoveOpinion: ThreatExchangeRemoveOpinionWritebacker,
        }

    def writeback_is_enabled(self, writeback_signal: BankedSignal) -> bool:
        """
        Return the `write_back` setting of the ThreatExchangeConfig looked up
        by the signal's bank id (used here as the privacy group id).

        When no ThreatExchangeConfig is found, a warning is logged and
        writebacks are treated as disabled.
        """
        privacy_group_id = writeback_signal.bank_id
        privacy_group_config = ThreatExchangeConfig.cached_get(privacy_group_id)
        if isinstance(privacy_group_config, ThreatExchangeConfig):
            return privacy_group_config.write_back

        # If no config, dont write back
        # `warning` replaces the deprecated `warn` alias.
        logger.warning("No config found for privacy group " + str(privacy_group_id))
        return False

    @property
    def te_api(self) -> ThreatExchangeAPI:
        """
        Build a ThreatExchange API client.

        Returns a MockedThreatExchangeAPI when the MOCK_TE_API environment
        variable is exactly "True"; otherwise a ThreatExchangeAPI built with
        the token from AWSSecrets. A new client is built on every access, so
        callers that need it repeatedly should keep a local reference.
        """
        mock_te_api = os.environ.get("MOCK_TE_API")
        if mock_te_api == "True":
            return MockedThreatExchangeAPI()
        api_token = AWSSecrets().te_api_token()
        return ThreatExchangeAPI(api_token)

    def my_descriptor_from_all_descriptors(
        self, all_descriptors: t.List[t.Dict[str, t.Any]]
    ) -> t.Optional[t.Dict[str, t.Any]]:
        """
        Given all descriptors for an indicator, find the one my app owns
        if it exists.

        The matching descriptor dict is updated in place with its
        `privacy_members` field before being returned. Returns None when no
        descriptor is owned by this app.
        """
        if not all_descriptors:
            return None

        # Resolve the client once: the `te_api` property builds a new client
        # (and may fetch the API token) on each access, which previously
        # happened for every descriptor in the loop.
        te_api = self.te_api
        my_app_id = str(te_api.app_id)

        for descriptor in all_descriptors:
            if descriptor["owner"]["id"] != my_app_id:
                continue
            # some fields such as privacy_members can only be loaded for
            # descriptors we own. We make another api call to load these
            # fields and then merge with the existing data
            descriptor_with_private_data = te_api.get_threat_descriptors(
                [descriptor["id"]], fields=["privacy_members"]
            )[0]
            descriptor.update(descriptor_with_private_data)
            return descriptor
        return None
class ThreatExchangeTruePositiveWritebacker(ThreatExchangeWritebacker):
    """
    For writing back to ThreatExchange that the user believes the match was
    correct.

    Executing perform_writeback on this class will read the
    (indicator, privacy_group) pairs for the signal and upsert a descriptor
    for that indicator with the privacy group for this collaboration.
    """

    def _writeback_impl(self, writeback_signal: BankedSignal) -> t.List[str]:
        """
        Upsert this app's descriptor for the signal's indicator so that it is
        shared with the signal's privacy group.

        If this app already owns a descriptor for the indicator, the privacy
        group is added to that descriptor's privacy members. Otherwise a new
        descriptor is uploaded, reusing the indicator value and type of the
        first existing descriptor.

        Returns log lines; failures are reported as lines that begin with
        "Error" rather than raised.
        """
        privacy_group_id = str(writeback_signal.bank_id)
        indicator_id = writeback_signal.banked_content_id

        # The `te_api` property builds a new client on each access; keep one.
        te_api = self.te_api
        descriptors = te_api.get_threat_descriptors_from_indicator(indicator_id)
        my_descriptor = self.my_descriptor_from_all_descriptors(descriptors)

        post_params = {
            "privacy_type": "HAS_PRIVACY_GROUP",
            "expire_time": 0,
            "privacy_members": privacy_group_id,
            "review_status": "REVIEWED_MANUALLY",
            "status": "MALICIOUS",
        }

        if my_descriptor:
            # If we already have a descriptor we can re-upload it to ensure it
            # never expires, keeping its existing privacy groups.
            members = set(my_descriptor.get("privacy_members") or [])
            members.add(privacy_group_id)
            # Sorted so the request and the log line are deterministic.
            post_params["privacy_members"] = ",".join(sorted(members))
            post_params["descriptor_id"] = my_descriptor["id"]
            # This doesnt actually copy to a new descriptor but acts like an
            # upsert for the properties specified
            response = te_api.copy_threat_descriptor(post_params, False, False)
            action = "Updated"
        elif descriptors:
            template = descriptors[0]
            post_params["indicator"] = template["indicator"]["indicator"]
            post_params["type"] = template["type"]
            post_params["description"] = "A ThreatDescriptor uploaded via HMA"
            post_params["share_level"] = "RED"
            post_params["tags"] = TE_UPLOAD_TAG
            response = te_api.upload_threat_descriptor(post_params, False, False)
            action = "Built"
        else:
            # Without any existing descriptor there is nothing to take the
            # indicator value and type from (previously an IndexError).
            return [
                f"Error writing back TruePositive for indicator {indicator_id}\n"
                "Error: No descriptors found for indicator"
            ]

        error = response[1] or response[2].get("error", {}).get("message")
        if error:
            return [
                f"Error writing back TruePositive for indicator {indicator_id}\n"
                f"Error: {error}"
            ]
        return [
            f"Wrote back TruePositive for indicator {indicator_id}",
            # "Updated" for an existing descriptor, "Built" for a new upload
            # (the two labels used to be swapped).
            f"{action} descriptor {response[2]['id']} with privacy groups {post_params['privacy_members']}",
        ]
class ThreatExchangeRemoveOpinionWritebacker(ThreatExchangeWritebacker):
    """
    For writing back to ThreatExchange that the user no longer holds an
    opinion (TruePositive or FalsePositive) about the match.

    Executing perform_writeback on this class will try to remove both
    TruePositive and FalsePositive opinions if they exist.

    To remove a FalsePositive opinion we load the indicator and find all
    associated descriptors. Then, for each descriptor owned by another app,
    we remove the DISAGREE_WITH_TAGS reaction.

    To remove a TruePositive opinion we need to remove the app's descriptor
    from the collaboration. To do this, we load the (indicator, privacy_group)
    and find a ThreatDescriptor that this app has created for that indicator.
    We then remove the privacy group from that descriptor if it exists,
    thereby removing it from the collaboration. If there are no more privacy
    groups we delete the descriptor.
    """

    def _writeback_impl(self, writeback_signal: BankedSignal) -> t.List[str]:
        """
        Remove this app's opinions about the signal's indicator and return
        log lines describing each step. Failures are reported as lines that
        begin with "Error".
        """
        privacy_group_id = writeback_signal.bank_id
        indicator_id = writeback_signal.banked_content_id

        # The `te_api` property builds a new client on each access; keep one
        # instead of rebuilding it for every descriptor in the filter below.
        te_api = self.te_api
        descriptors = te_api.get_threat_descriptors_from_indicator(indicator_id)
        my_descriptor = self.my_descriptor_from_all_descriptors(descriptors)
        my_app_id = str(te_api.app_id)
        other_descriptors = [d for d in descriptors if d["owner"]["id"] != my_app_id]

        logs = []
        if not my_descriptor:
            logs.append(f"No descriptor to remove for indicator {indicator_id}")
        else:
            # ensure property exists; it is read when building log messages
            my_descriptor["indicator_id"] = indicator_id
            if my_descriptor["privacy_type"] != "HAS_PRIVACY_GROUP":
                # We currently can't add/remove a true positive opinion if the
                # descriptor has a privacy type other than HAS_PRIVACY_GROUP.
                # We will still try to remove false positive opinions.
                logs.append(
                    f"Error writing back RemoveOpinion for indicator {my_descriptor['indicator_id']}\n Error: Cannot remove/edit descriptor {my_descriptor['id']} because the privacy type is not HAS_PRIVACY_GROUP"
                )
            else:
                logs.append(
                    self.remove_descriptor_from_privacy_group(
                        my_descriptor, privacy_group_id
                    )
                )

        logs.extend(self.remove_false_positive_from_descriptors(other_descriptors))
        return logs

    def remove_descriptor_from_privacy_group(
        self, my_descriptor: t.Dict[str, t.Any], privacy_group_id: str
    ) -> str:
        """
        Remove `privacy_group_id` from the privacy members of `my_descriptor`.

        When no privacy group remains the descriptor is deleted; otherwise it
        is re-posted with the remaining groups. `my_descriptor` must contain
        "id", "indicator_id" and "privacy_members".

        Returns a single log line, beginning with "Error" on failure.
        """
        # Compare as strings so the group is still matched if the id is not
        # already a str (other writebackers in this module also normalise
        # with str()).
        group_to_remove = str(privacy_group_id)
        remaining_groups = [
            pg
            for pg in my_descriptor["privacy_members"]
            if isinstance(pg, str) and pg != group_to_remove
        ]

        te_api = self.te_api
        deleted = not remaining_groups
        if deleted:
            response = te_api.delete_threat_descriptor(
                my_descriptor["id"], True, False
            )
        else:
            post_params = {
                "privacy_members": ",".join(remaining_groups),
                "privacy_type": "HAS_PRIVACY_GROUP",
                "descriptor_id": my_descriptor["id"],
            }
            response = te_api.copy_threat_descriptor(post_params, True, False)

        error = response[1] or response[2].get("error", {}).get("message")
        if error:
            return f"Error writing back RemoveOpinion for indicator {my_descriptor['indicator_id']} Error: {error}"
        if deleted:
            return f"Deleted decriptor {my_descriptor['id']} for indicator {my_descriptor['indicator_id']}"
        # The descriptor still exists in its other privacy groups, so do not
        # report it as deleted.
        return f"Removed privacy group {group_to_remove} from descriptor {my_descriptor['id']} for indicator {my_descriptor['indicator_id']}"

    def remove_false_positive_from_descriptors(self, descriptors) -> t.List[str]:
        """
        Remove the FalsePositive reaction from each of `descriptors` and
        return one log line per descriptor.
        """
        if not descriptors:
            return []

        reaction = ThreatExchangeFalsePositiveWritebacker.reaction
        te_api = self.te_api
        logs = []
        for descriptor in descriptors:
            descriptor_id = descriptor["id"]
            # NOTE(review): the API response is not inspected here, so a
            # failed removal is still logged as removed - confirm the return
            # shape of remove_reaction_from_threat_descriptor before adding
            # error handling.
            te_api.remove_reaction_from_threat_descriptor(descriptor_id, reaction)
            logs.append(f"Removed reaction {reaction} from descriptor {descriptor_id}")
        return logs
@dataclass
class ThreatExchangeReactionWritebacker(ThreatExchangeWritebacker):
    """
    For all writebacks to ThreatExchange that are implemented as adding
    reactions.

    Executing perform_writeback on this class will read the indicators
    from the match, load all related descriptors, and write the given
    reaction to every descriptor owned by another app.
    """

    @property
    def reaction(self) -> str:
        """
        The reaction to write back (eg. "SAW_THIS_TOO"). Provided by
        subclasses.
        """
        raise NotImplementedError

    def _writeback_impl(self, writeback_signal: BankedSignal) -> t.List[str]:
        """
        React with `self.reaction` to each descriptor of the signal's
        indicator that is not owned by this app.

        Returns one log line per descriptor; failed reactions produce lines
        that begin with "Error".
        """
        indicator_id = writeback_signal.banked_content_id

        # The `te_api` property builds a new client on each access; resolve
        # it once rather than once per descriptor.
        te_api = self.te_api
        descriptors = te_api.get_threat_descriptors_from_indicator(indicator_id)
        my_app_id = str(te_api.app_id)
        other_descriptors = [d for d in descriptors if d["owner"]["id"] != my_app_id]

        logs = []
        for descriptor in other_descriptors:
            descriptor_id = descriptor["id"]
            result = te_api.react_to_threat_descriptor(descriptor_id, self.reaction)
            error = result[1] or result[2].get("error", {}).get("message")
            if error:
                logs.append(
                    f"Error writing back Reacting {self.reaction} to descriptor {descriptor_id} Error: {error}"
                )
            else:
                logs.append(f"Reacted {self.reaction} to descriptor {descriptor_id}")
        return logs
class ThreatExchangeFalsePositiveWritebacker(ThreatExchangeReactionWritebacker):
    """
    Writes back to ThreatExchange that the user believes the match was a
    false positive, by reacting DISAGREE_WITH_TAGS.
    """

    reaction = "DISAGREE_WITH_TAGS"
# TODO: Currently writing back INGESTED fails because of API limits. This needs
# to be solved before sending the reaction. Possible solution: create a new batch react endpoint
# class ThreatExchangeIngestedWritebacker(ThreatExchangeReactionWritebacker):
# reaction = "INGESTED"
class ThreatExchangeSawThisTooWritebacker(ThreatExchangeReactionWritebacker):
    """
    Writes back to ThreatExchange that a match has occurred, by reacting
    SAW_THIS_TOO.
    """

    reaction = "SAW_THIS_TOO"
# Nothing is done when this module is executed directly.
if __name__ == "__main__":
    pass