hasher-matcher-actioner/hmalib/writebacker/writebacker_base.py (244 lines of code) (raw):

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import typing as t
import os
from functools import lru_cache
from dataclasses import dataclass

from hmalib.common.logging import get_logger
from hmalib.common.classification_models import WritebackTypes
from hmalib.common.messages.match import BankedSignal
from hmalib.common.messages.writeback import WritebackMessage
from hmalib.common.configs.fetcher import ThreatExchangeConfig
from hmalib.common.mocks import MockedThreatExchangeAPI
from hmalib.aws_secrets import AWSSecrets

from threatexchange.api import ThreatExchangeAPI

logger = get_logger(__name__)

# Tag attached to every descriptor this module uploads.
TE_UPLOAD_TAG = "uploaded_by_hma"


def _te_response_error(response: t.Any) -> t.Any:
    """
    Return the error carried by a ThreatExchange API response, or a falsy
    value when there is none.

    The writebackers in this module index API responses positionally: item 1
    is checked first as an error value, otherwise item 2 is treated as a dict
    that may hold an ``{"error": {"message": ...}}`` entry.
    """
    return response[1] or response[2].get("error", {}).get("message")


class Writebacker:
    """
    For writing back to an HMA data source (eg. ThreatExchange). Every source
    that enables writebacks should have an implementation of this class
    (eg ThreatExchangeWritebacker) and optionally sub implementations
    (eg ThreatExchangeFalsePositiveWritebacker)

    get_writebacker_for_source discovers implementations through
    Writebacker.__subclasses__(), so a source-level writebacker must subclass
    Writebacker directly to be found.
    """

    @property
    def source(self) -> str:
        """
        The source that this writebacker corresponds to (eg. "te")
        """
        raise NotImplementedError

    @staticmethod
    def writeback_options() -> t.Dict[
        WritebackTypes.WritebackType, t.Type["Writebacker"]
    ]:
        """
        For a given source that performs writebacks, this function specifies
        what types of writebacks that can be taken as a mapping from writeback
        type to writebacker. The type should be same as WritebackType passed
        to the writebacker
        """
        raise NotImplementedError

    @classmethod
    @lru_cache(maxsize=None)
    def get_writebacker_for_source(cls, source: str) -> t.Optional["Writebacker"]:
        """
        Return an instance of the direct Writebacker subclass whose `source`
        equals the given one, or None if no subclass matches.

        Results are memoized per (cls, source), so repeated calls return the
        same instance.

        Raises ValueError when invoked on anything other than the Writebacker
        class itself.
        """
        if cls.__name__ != "Writebacker":
            raise ValueError(
                "get_writebacker_for_source can only be called from the "
                "Writebacker class directly. "
                "eg Writebacker().get_writebacker_for_source"
            )

        # Each subclass is instantiated once here just to read its `source`.
        sources_to_writebacker_cls = {
            writebacker_cls().source: writebacker_cls
            for writebacker_cls in cls.__subclasses__()
        }

        writebacker_cls = sources_to_writebacker_cls.get(source)
        if writebacker_cls is None:
            return None
        return writebacker_cls()

    def writeback_is_enabled(self, writeback_signal: BankedSignal) -> bool:
        """
        Users can switch on/off writebacks either globally for individual
        sources, or based on the matched signal
        """
        raise NotImplementedError

    @property
    def writeback_type(self) -> WritebackTypes.WritebackType:
        """
        The writeback label for when this action should be performed
        (eg WritebackType.SawThisToo)
        """
        raise NotImplementedError

    def _writeback_impl(self, writeback_signal: BankedSignal) -> t.List[str]:
        """
        Perform the writeback for a single banked signal and return the log
        messages describing what happened. Implemented by subclasses.
        """
        raise NotImplementedError

    def perform_writeback(self, writeback_message: WritebackMessage) -> t.List[str]:
        """
        Run the writeback requested by `writeback_message` against every
        banked signal in it that belongs to this writebacker's source.

        Returns one string per processed signal (that signal's log messages
        joined with newlines). If this source has no writebacker for the
        requested writeback type, a single-element list holding the error
        message is returned instead.
        """
        writeback_to_perform = writeback_message.writeback_type
        options = self.writeback_options()

        if writeback_to_perform not in options:
            error = (
                f"Could not find writebacker for source {self.source}"
                f" that can perform writeback {writeback_to_perform.value}"
            )
            logger.error(error)
            return [error]

        writebacker = options[writeback_to_perform]()
        results = []
        for writeback_signal in writeback_message.banked_signals:
            # filter out matches from other sources
            if writeback_signal.bank_source != self.source:
                continue

            if writebacker.writeback_is_enabled(writeback_signal):
                result = writebacker._writeback_impl(writeback_signal)
            else:
                result = [
                    "No writeback performed for banked content id "
                    f"{writeback_signal.banked_content_id}"
                    " becuase writebacks were disabled"
                ]

            # Messages containing "Error" are surfaced at error level.
            for log_message in result:
                if "Error" in log_message:
                    logger.error(log_message)
                else:
                    logger.info(log_message)

            results.append("\n".join(result))

        return results


@dataclass
class ThreatExchangeWritebacker(Writebacker):
    """
    Writebacker parent object for all writebacks to ThreatExchange
    """

    source = "te"

    @staticmethod
    @lru_cache(maxsize=None)
    def writeback_options() -> t.Dict[
        WritebackTypes.WritebackType, t.Type["Writebacker"]
    ]:
        """
        Mapping from each supported writeback type to the ThreatExchange
        writebacker class that performs it.
        """
        return {
            WritebackTypes.FalsePositive: ThreatExchangeFalsePositiveWritebacker,
            WritebackTypes.TruePositive: ThreatExchangeTruePositiveWritebacker,
            WritebackTypes.SawThisToo: ThreatExchangeSawThisTooWritebacker,
            WritebackTypes.RemoveOpinion: ThreatExchangeRemoveOpinionWritebacker,
        }

    def writeback_is_enabled(self, writeback_signal: BankedSignal) -> bool:
        """
        Writebacks are enabled when the config for the signal's privacy group
        (its bank_id) exists and has write_back set.
        """
        privacy_group_id = writeback_signal.bank_id
        privacy_group_config = ThreatExchangeConfig.cached_get(privacy_group_id)
        if isinstance(privacy_group_config, ThreatExchangeConfig):
            return privacy_group_config.write_back

        # If no config, dont write back
        logger.warning("No config found for privacy group " + str(privacy_group_id))
        return False

    @property
    def te_api(self) -> ThreatExchangeAPI:
        """
        A ThreatExchange API client, or the mocked client when the
        MOCK_TE_API environment variable is "True".

        A new client is constructed on every access (and, outside mock mode,
        the token is read through AWSSecrets each time), so callers should
        read this property once and reuse the result.
        """
        mock_te_api = os.environ.get("MOCK_TE_API")
        if mock_te_api == "True":
            return MockedThreatExchangeAPI()
        api_token = AWSSecrets().te_api_token()
        return ThreatExchangeAPI(api_token)

    def my_descriptor_from_all_descriptors(
        self, all_descriptors: t.List[t.Dict[str, t.Any]]
    ) -> t.Optional[t.Dict[str, t.Any]]:
        """
        Given all descriptors for an indicator, find the one my app owns
        if it exists. The matching descriptor dict is updated in place with
        its privacy_members before being returned.
        """
        te_api = self.te_api
        my_app_id = str(te_api.app_id)

        for descriptor in all_descriptors:
            if descriptor["owner"]["id"] != my_app_id:
                continue

            # some fields such as privacy_members can only be loaded for
            # descriptors we own. We make another api call to load these
            # fields and then merge with the existing data
            fields = ["privacy_members"]
            descriptor_with_private_data = te_api.get_threat_descriptors(
                [descriptor["id"]], fields=fields
            )[0]
            descriptor.update(descriptor_with_private_data)
            return descriptor

        return None


class ThreatExchangeTruePositiveWritebacker(ThreatExchangeWritebacker):
    """
    For writing back to ThreatExchange that the user believes the match was
    correct.

    Executing perform_writeback on this class will read the (indicator,
    privacy_group) pairs for the signal and upsert a new descriptor for that
    indicator with the privacy group for this collaboration
    """

    def _writeback_impl(self, writeback_signal: BankedSignal) -> t.List[str]:
        """
        Update our existing descriptor for the indicator so it includes the
        signal's privacy group, or upload a new descriptor when we do not
        own one yet. Returns the resulting log messages.
        """
        privacy_group_id = writeback_signal.bank_id
        indicator_id = writeback_signal.banked_content_id

        te_api = self.te_api
        descriptors = te_api.get_threat_descriptors_from_indicator(indicator_id)
        my_descriptor = self.my_descriptor_from_all_descriptors(descriptors)

        postParams = {
            "privacy_type": "HAS_PRIVACY_GROUP",
            "expire_time": 0,
            "privacy_members": str(privacy_group_id),
            "review_status": "REVIEWED_MANUALLY",
            "status": "MALICIOUS",
        }

        # If we already have a descriptor we can copy it and re-upload to
        # ensure it never expires
        if my_descriptor:
            members = set(my_descriptor.get("privacy_members", []))
            members.add(str(privacy_group_id))
            # sorted so the request and the log line are deterministic
            postParams["privacy_members"] = ",".join(sorted(members))
            postParams["descriptor_id"] = my_descriptor["id"]

            # This doesnt actually copy to a new descriptor but acts like an
            # upsert for the properties specified
            response = te_api.copy_threat_descriptor(postParams, False, False)
        else:
            if not descriptors:
                # Nothing to take the indicator value and type from; report
                # it rather than failing on descriptors[0].
                return [
                    f"Error writing back TruePositive for indicator {indicator_id}"
                    " Error: No descriptors found for indicator"
                ]

            postParams["indicator"] = descriptors[0]["indicator"]["indicator"]
            postParams["type"] = descriptors[0]["type"]
            postParams["description"] = "A ThreatDescriptor uploaded via HMA"
            postParams["share_level"] = "RED"
            postParams["tags"] = TE_UPLOAD_TAG
            response = te_api.upload_threat_descriptor(postParams, False, False)

        error = _te_response_error(response)
        if error:
            return [
                f"Error writing back TruePositive for indicator {indicator_id}"
                f" Error: {error}"
            ]

        # An existing descriptor was updated; otherwise a new one was built.
        action = "Updated" if my_descriptor else "Built"
        return [
            f"Wrote back TruePositive for indicator {indicator_id}",
            f"{action} descriptor {response[2]['id']} with privacy groups {postParams['privacy_members']}",
        ]


class ThreatExchangeRemoveOpinionWritebacker(ThreatExchangeWritebacker):
    """
    For writing back to ThreatExchange that the user no longer holds an
    opinion on the match.

    Executing perform_writeback on this class will try to remove both
    TruePositive and FalsePositive opinions if they exist.

    To remove a FalsePositive opinion we load the indicator and find all
    associated descriptors owned by others, then remove the
    DISAGREE_WITH_TAGS reaction from each of them.

    To remove a TruePositive opinion we need to remove the apps descriptor
    from the collaboration. To do this, we load the (indicator,
    privacy_group) and find a ThreatDescriptor that the user has created for
    that indicator. We then remove the privacy group from that descriptor if
    it exists thereby removing it from the collaboration. If there are no
    more privacy groups we delete the descriptor.
    """

    def _writeback_impl(self, writeback_signal: BankedSignal) -> t.List[str]:
        """
        Remove our descriptor from the signal's privacy group (when we own
        one) and remove false positive reactions from everyone else's
        descriptors. Returns the resulting log messages.
        """
        privacy_group_id = writeback_signal.bank_id
        indicator_id = writeback_signal.banked_content_id

        te_api = self.te_api
        my_app_id = str(te_api.app_id)
        descriptors = te_api.get_threat_descriptors_from_indicator(indicator_id)
        my_descriptor = self.my_descriptor_from_all_descriptors(descriptors)
        other_descriptors = [d for d in descriptors if d["owner"]["id"] != my_app_id]

        logs = []
        if my_descriptor:
            # ensure property exists
            my_descriptor["indicator_id"] = indicator_id

            if my_descriptor["privacy_type"] != "HAS_PRIVACY_GROUP":
                # We currently can't add/remove a true positive opinion if
                # the descriptor has a privacy type other than
                # HAS_PRIVACY_GROUP. We will still try to remove false
                # positive opinions
                logs.append(
                    f"Error writing back RemoveOpinion for indicator {my_descriptor['indicator_id']}"
                    f"\n Error: Cannot remove/edit descriptor {my_descriptor['id']}"
                    " because the privacy type is not HAS_PRIVACY_GROUP"
                )
            else:
                logs.append(
                    self.remove_descriptor_from_privacy_group(
                        my_descriptor, privacy_group_id
                    )
                )
        else:
            logs.append(
                f"No descriptor to remove for indicator {writeback_signal.banked_content_id}"
            )

        logs.extend(self.remove_false_positive_from_descriptors(other_descriptors))
        return logs

    def remove_descriptor_from_privacy_group(
        self, my_descriptor: t.Dict[str, t.Any], privacy_group_id: str
    ) -> str:
        """
        Drop `privacy_group_id` from the descriptor's privacy members. If no
        privacy groups remain the descriptor is deleted, otherwise it is
        re-submitted with the remaining groups.

        `my_descriptor` must carry "id", "indicator_id" and
        "privacy_members". Returns a single log message.
        """
        new_privacy_groups = [
            pg
            for pg in my_descriptor["privacy_members"]
            if isinstance(pg, str) and pg != privacy_group_id
        ]

        te_api = self.te_api
        if not new_privacy_groups:
            response = te_api.delete_threat_descriptor(
                my_descriptor["id"], True, False
            )
        else:
            postParams = {
                "privacy_members": ",".join(new_privacy_groups),
                "privacy_type": "HAS_PRIVACY_GROUP",
                "descriptor_id": my_descriptor["id"],
            }
            response = te_api.copy_threat_descriptor(postParams, True, False)

        error = _te_response_error(response)
        if error:
            return f"Error writing back RemoveOpinion for indicator {my_descriptor['indicator_id']} Error: {error}"

        # NOTE(review): the same message is emitted whether the descriptor
        # was deleted or only had a privacy group removed.
        return f"Deleted decriptor {my_descriptor['id']} for indicator {my_descriptor['indicator_id']}"

    def remove_false_positive_from_descriptors(self, descriptors) -> t.List[str]:
        """
        Remove the false positive reaction from each of the given descriptors
        and return one log message per descriptor.
        """
        te_api = self.te_api
        reaction = ThreatExchangeFalsePositiveWritebacker.reaction

        ret = []
        for descriptor in descriptors:
            descriptor_id = descriptor["id"]
            # NOTE(review): the API response is not inspected here, so a
            # failed removal is still logged as removed.
            te_api.remove_reaction_from_threat_descriptor(descriptor_id, reaction)
            ret.append(f"Removed reaction {reaction} from descriptor {descriptor_id}")
        return ret


@dataclass
class ThreatExchangeReactionWritebacker(ThreatExchangeWritebacker):
    """
    For all writebacks to ThreatExchange that are implemented as adding
    reactions.

    Executing perform_writeback on this class will read the indicators from
    the match, load all related descriptors, and write the given reaction to
    them
    """

    @property
    def reaction(self) -> str:
        """
        The reaction to write; subclasses override this with a class
        attribute.
        """
        raise NotImplementedError

    def _writeback_impl(self, writeback_signal: BankedSignal) -> t.List[str]:
        """
        React with `self.reaction` to every descriptor of the signal's
        indicator that is not owned by our app. Returns one log message per
        descriptor.
        """
        indicator_id = writeback_signal.banked_content_id

        te_api = self.te_api
        my_app_id = str(te_api.app_id)
        descriptors = te_api.get_threat_descriptors_from_indicator(indicator_id)
        other_descriptors = [d for d in descriptors if d["owner"]["id"] != my_app_id]

        logs = []
        for descriptor in other_descriptors:
            descriptor_id = descriptor["id"]
            result = te_api.react_to_threat_descriptor(descriptor_id, self.reaction)
            error = _te_response_error(result)
            if error:
                logs.append(
                    f"Error writing back Reacting {self.reaction} to descriptor {descriptor_id} Error: {error}"
                )
            else:
                logs.append(f"Reacted {self.reaction} to descriptor {descriptor_id}")
        return logs


class ThreatExchangeFalsePositiveWritebacker(ThreatExchangeReactionWritebacker):
    """
    For writing back to ThreatExchange that the user believes the match was
    a false positive.
    """

    reaction = "DISAGREE_WITH_TAGS"


# TODO: Currently writing back INGESTED fails becuase of API limits. Need to
#       solve before sending reaction. Possible solution to create new batch react endpoint
# class ThreatExchangeIngestedWritebacker(ThreatExchangeReactionWritebacker):
#     reaction = "INGESTED"


class ThreatExchangeSawThisTooWritebacker(ThreatExchangeReactionWritebacker):
    """
    For writing back to ThreatExchange that a Match has occurred
    """

    reaction = "SAW_THIS_TOO"


if __name__ == "__main__":
    pass