python-threatexchange/threatexchange/fb_threatexchange/api.py (407 lines of code) (raw):

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
This is an entire copy of a file from ThreatExchange/hashing

TODO: Slim down to only what we need
"""

import copy
import json
import typing as t
import os
import pathlib
import re
import urllib.parse
import urllib.error

import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry

from .api_representations import ThreatPrivacyGroup


def is_valid_app_token(token: str) -> bool:
    """
    Returns true if the string looks like a valid token.

    A token is at least 8 digits, a separator (either a literal "|" or its
    URL-encoded form "%7C"), then at least 20 characters from
    [a-zA-Z0-9_-]. Only the start of the string is checked (re.match), so
    trailing characters are not rejected.
    """
    return bool(re.match(r"[0-9]{8,}(?:%7C|\|)[a-zA-Z0-9_\-]{20,}", token))


class TimeoutHTTPAdapter(HTTPAdapter):
    """
    Plug into requests to get a well-behaved session that does not wait for
    eternity.

    H/T: https://findwork.dev/blog/advanced-usage-python-requests-timeouts-retries-hooks/#setting-default-timeouts
    """

    def __init__(self, *args, timeout=5, **kwargs):
        # Default timeout (seconds) applied when a request does not set one.
        self.timeout = timeout
        super().__init__(*args, **kwargs)

    def send(self, request, *, timeout=None, **kwargs):
        """Send `request`, falling back to the adapter's default timeout."""
        if timeout is None:
            timeout = self.timeout
        return super().send(request, timeout=timeout, **kwargs)


class _CursoredResponse:
    """
    Wrapper around paginated responses from Graph API.

    Iterating yields one page (a list of decoded records) at a time until the
    API stops returning a `paging.next` URL.
    """

    def __init__(self, api: "ThreatExchangeAPI", url, params, decode_fn=None) -> None:
        self.api = api
        self.response = None
        self.next_url = url
        # Keep a private copy: next() clears the params after the first
        # request, and that must not wipe the caller's dict.
        self.params = dict(params or {})
        self.data: t.List = []
        self.decode_fn = decode_fn

    @property
    def done(self):
        """True once there is no further page to fetch."""
        return self.next_url is None

    def next(self):
        """Fetch the next page and return it; returns [] when exhausted."""
        if self.done:
            return []
        response = self.api.get_json_from_url(self.next_url, self.params)
        next_url = response.get("paging", {}).get("next")
        data = response.get("data", [])
        if self.decode_fn:
            data = [self.decode_fn(x) for x in data]
        self.next_url = next_url
        self.data = data
        # The "next" URL returned by the API is used as-is from here on, so
        # the initial query parameters are only sent with the first request.
        self.params.clear()
        return self.data

    def __iter__(self):
        while not self.done:
            self.next()
            if self.data is not None:
                yield self.data


class ThreatExchangeAPI:
    """Thin client for the ThreatExchange endpoints of the Graph API."""

    _TE_BASE_URL = "https://graph.facebook.com/v9.0"

    # This is just a keystroke-saver / error-avoider for passing around
    # post-parameter field names.
    _POST_PARAM_NAMES = {
        "indicator": "indicator",  # For submit
        "type": "type",  # For submit
        "descriptor_id": "descriptor_id",  # For update
        "description": "description",
        "share_level": "share_level",
        "status": "status",
        "privacy_type": "privacy_type",
        "privacy_members": "privacy_members",
        "tags": "tags",
        "add_tags": "add_tags",
        "remove_tags": "remove_tags",
        "confidence": "confidence",
        "precision": "precision",
        "review_status": "review_status",
        "severity": "severity",
        "expired_on": "expired_on",
        "first_active": "first_active",
        "last_active": "last_active",
        "related_ids_for_upload": "related_ids_for_upload",
        "related_triples_for_upload_as_json": "related_triples_for_upload_as_json",
        # Legacy : should have been named reactions_to_add, but isn't. :(
        "reactions": "reactions",
        "reactions_to_remove": "reactions_to_remove",
    }

    # Fields requested when listing privacy groups (member or owner edge).
    _PRIVACY_GROUP_FIELDS = (
        "id",
        "members_can_see",
        "members_can_use",
        "name",
        "description",
        "last_updated",
        "added_on",
        "threat_updates_enabled",
    )

    def __init__(
        self, api_token: str, *, endpoint_override: t.Optional[str] = None
    ) -> None:
        self.api_token = api_token
        self._base_url = endpoint_override or self._TE_BASE_URL

    @property
    def app_id(self):
        """
        The numeric app ID, i.e. the part of the token before the separator.

        The token is unquoted first so that the URL-encoded separator
        ("%7C"), which is_valid_app_token() accepts, works here too.
        """
        return int(urllib.parse.unquote(self.api_token).partition("|")[0])

    def get_json_from_url(
        self, url, params=None, *, json_obj_hook: t.Optional[t.Callable] = None
    ):
        """
        Perform an HTTP GET request, and return the JSON response payload.

        Uses the session from `_get_session`, so its timeout and retry
        strategy apply. Raises requests.exceptions.HTTPError on a 4xx/5xx
        response.
        """
        with self._get_session() as session:
            # Must go through the session (not requests.get), otherwise the
            # mounted timeout/retry adapter is bypassed.
            response = session.get(url, params=params or {})
            response.raise_for_status()
            return response.json(object_hook=json_obj_hook)

    def _get_session(self):
        """
        Custom requests session.

        Ideally, should be used within a context manager:
        ```
        with self._get_session() as session:
            session.get()...
        ```

        If using without a context manager, ensure you end up calling close()
        on the returned value.
        """
        session = requests.Session()
        session.mount(
            self._base_url,
            adapter=TimeoutHTTPAdapter(
                timeout=60,
                max_retries=Retry(
                    total=4,
                    status_forcelist=[429, 500, 502, 503, 504],
                    allowed_methods=["HEAD", "GET", "OPTIONS"],
                    backoff_factor=0.2,  # ~1.5 seconds of retries
                ),
            ),
        )
        return session

    def get_tag_id(self, tagName, showURLs=False):
        """
        Looks up the "objective tag" ID for a given tag. This is suitable input
        for the /threat_tags endpoint.

        Returns the ID of the tag whose text equals `tagName` exactly, or None
        if there is no such tag.
        """
        url = (
            f"{self._base_url}/threat_tags/?access_token={self.api_token}"
            f"&text={urllib.parse.quote(tagName)}"
        )
        if showURLs:
            print("URL:")
            print(url)
        response = self.get_json_from_url(url)

        # The lookup will get everything that has this as a prefix.
        # So we need to filter the results. This loop also handles the
        # case when the results array is empty.
        #
        # Example: when querying for "media_type_video", we want the 2nd one:
        # { "data": [
        #     { "id": "9999338563303771", "text": "media_type_video_long_hash" },
        #     { "id": "9999474908560728", "text": "media_type_video" },
        #     { "id": "9889872714202918", "text": "media_type_video_hash_long" }
        #   ], ...
        # }
        for tag in response["data"]:
            if tag["text"] == tagName:
                return tag["id"]
        return None

    def get_threat_descriptors(self, ids, **kwargs):
        """
        Looks up all metadata for given IDs.

        Keyword arguments:
          verbose: print each descriptor as JSON (default False).
          showURLs: print the request URL (default False).
          includeIndicatorInOutput: keep "raw_indicator" in the results
            (default True).
          fields: list of descriptor fields to request.

        Returns a list of descriptor dicts, with "tags" flattened to a sorted
        list of tag texts and a missing "description" replaced by "".
        Raises ValueError if any ID is not an integer.
        """
        verbose = kwargs.get("verbose", False)
        showURLs = kwargs.get("showURLs", False)
        includeIndicatorInOutput = kwargs.get("includeIndicatorInOutput", True)

        default_fields = [
            "raw_indicator",
            "type",
            "added_on",
            "last_updated",
            "confidence",
            "owner",
            "privacy_type",
            "review_status",
            "status",
            "severity",
            "share_level",
            "tags",
            "description",
            "reactions",
            "my_reactions",
        ]
        fields = kwargs.get("fields", default_fields)

        # Materialize as strings once: `ids` may be a one-shot iterable and
        # may contain ints, both of which would break the join below.
        descriptor_ids = [str(descriptor_id) for descriptor_id in ids]

        # Check well-formattedness of descriptor IDs (which may have come from
        # arbitrary data on stdin).
        for descriptor_id in descriptor_ids:
            try:
                int(descriptor_id)
            except ValueError:
                raise ValueError(
                    'Malformed descriptor ID "%s"' % descriptor_id
                ) from None

        # See also
        # https://developers.facebook.com/docs/threat-exchange/reference/apis/threat-descriptor/
        # for available fields
        url = (
            f"{self._base_url}/?access_token={self.api_token}"
            f"&ids={','.join(descriptor_ids)}"
            f"&fields={','.join(fields)}"
        )
        if showURLs:
            print("URL:")
            print(url)
        response = self.get_json_from_url(url)

        descriptors = []
        for descriptor in response.values():
            if not includeIndicatorInOutput:
                # pop() rather than del: a custom `fields` list may not have
                # requested raw_indicator at all.
                descriptor.pop("raw_indicator", None)

            if verbose:
                print(json.dumps(descriptor))

            if "tags" in fields:
                # tags is returned as a dict structured like:
                # "tags": {
                #   "data": [
                #     {
                #       "id": "03465026013486502",
                #       "text": "uploaded_by_hma"
                #     }
                #   ]
                # },
                #
                # Canonicalize the tag ordering and simplify the
                # structure to simply an array of tag-texts
                tags = descriptor.get("tags", {"data": []})["data"]
                descriptor["tags"] = sorted(tag["text"] for tag in tags)

            if descriptor.get("description") is None and "description" in fields:
                descriptor["description"] = ""

            descriptors.append(descriptor)
        return descriptors

    def get_threat_updates(
        self,
        privacy_group: int,
        *,
        start_time: t.Optional[int] = None,
        stop_time: t.Optional[int] = None,
        types: t.Iterable[str] = (),
        page_size: t.Optional[int] = None,
        fields: t.Optional[t.Iterable[str]] = None,
        decode_fn: t.Optional[t.Callable[[t.Any], t.Any]] = None,
    ) -> _CursoredResponse:
        """
        Gets threat updates for the given privacy group.

        Returns a _CursoredResponse; no request is made until it is iterated
        (or next() is called).
        """
        if fields is None:
            fields = (
                "id",
                "indicator",
                "type",
                "creation_time",
                "last_updated",
                "should_delete",
                "tags",
                "status",
                "applications_with_opinions",
            )
        params = {
            "access_token": self.api_token,
            "start_time": start_time,
            "stop_time": stop_time,
            "limit": page_size,
            "fields": ",".join(fields),
        }
        if types:
            params["types"] = ",".join(types)
        url = f"{self._base_url}/{privacy_group}/threat_updates/"
        return _CursoredResponse(self, url, params, decode_fn=decode_fn)

    def _get_threat_privacy_groups(self, edge: str) -> t.List[ThreatPrivacyGroup]:
        """Fetch the privacy groups on the given edge of the current app."""
        url = self._get_graph_api_url(
            f"{self.app_id}/{edge}",
            {"fields": ",".join(self._PRIVACY_GROUP_FIELDS)},
        )
        response = self.get_json_from_url(url)
        return [ThreatPrivacyGroup.from_graph_api_dict(d) for d in response["data"]]

    def get_threat_privacy_groups_member(
        self,
    ) -> t.List[ThreatPrivacyGroup]:
        """
        Returns a non-paginated list of all privacy groups the current app is a
        member of.
        """
        return self._get_threat_privacy_groups("threat_privacy_groups_member")

    def get_threat_privacy_groups_owner(
        self,
    ) -> t.List[ThreatPrivacyGroup]:
        """
        Returns a non-paginated list of all privacy groups the current app is a
        owner of.
        """
        return self._get_threat_privacy_groups("threat_privacy_groups_owner")

    def _get_graph_api_url(
        self, sub_path: t.Optional[str], query_dict: t.Optional[t.Dict] = None
    ) -> str:
        """
        Returns a threatexchange URL for a sub-path and a dictionary of query
        parameters. Automatically adds access_token to the query dictionary.

        The caller's `query_dict` is not modified.
        """
        # Copy: never mutate the caller's dict (or a shared default).
        query_params = dict(query_dict or {})
        query_params.setdefault("access_token", self.api_token)

        base_url_parts = urllib.parse.urlparse(self._base_url)
        url_parts = base_url_parts._replace(
            path=f"{base_url_parts.path}/{sub_path}",
            query=urllib.parse.urlencode(query_params),
        )
        return urllib.parse.urlunparse(url_parts)

    def _validate_post_params_for_submit(self, postParams):
        """
        Returns error message or None.
        This simply checks to see (client-side) if required fields aren't provided.
        """
        if postParams.get(self._POST_PARAM_NAMES["descriptor_id"]) is not None:
            return "descriptor_id must not be specified for submit."

        requiredFields = [
            self._POST_PARAM_NAMES["indicator"],
            self._POST_PARAM_NAMES["type"],
            self._POST_PARAM_NAMES["description"],
            self._POST_PARAM_NAMES["share_level"],
            self._POST_PARAM_NAMES["status"],
            self._POST_PARAM_NAMES["privacy_type"],
        ]

        missingFields = [
            fieldName
            for fieldName in requiredFields
            if postParams.get(fieldName) is None
        ]

        if not missingFields:
            return None
        if len(missingFields) == 1:
            return "Missing field %s" % missingFields[0]
        return "Missing fields %s" % ",".join(missingFields)

    def _validate_post_pararms_for_copy(self, postParams):
        """
        Returns error message or None.
        This simply checks to see (client-side) if required fields aren't provided.
        """
        if postParams.get(self._POST_PARAM_NAMES["descriptor_id"]) is None:
            return "Source-descriptor ID must be specified for copy."
        if postParams.get(self._POST_PARAM_NAMES["privacy_type"]) is None:
            return "Privacy type must be specified for copy."
        if postParams.get(self._POST_PARAM_NAMES["privacy_members"]) is None:
            return "Privacy members must be specified for copy."
        return None

    def react_to_threat_descriptor(
        self, descriptor_id, reaction, *, showURLs=False, dryRun=False
    ):
        """
        Does a POST to the reactions API.

        See: https://developers.facebook.com/docs/threat-exchange/reference/reacting
        """
        return self._postThreatDescriptor(
            "/".join(
                (
                    self._base_url,
                    str(descriptor_id),
                    f"?access_token={self.api_token}",
                )
            ),
            {"reactions": reaction},
            showURLs=showURLs,
            dryRun=dryRun,
        )

    def remove_reaction_from_threat_descriptor(
        self, descriptor_id, reaction, *, showURLs=False, dryRun=False
    ) -> t.List:
        """
        Does a POST to the reactions API.

        See: https://developers.facebook.com/docs/threat-exchange/reference/reacting
        """
        return self._postThreatDescriptor(
            "/".join(
                (
                    self._base_url,
                    str(descriptor_id),
                    f"?access_token={self.api_token}",
                )
            ),
            {"reactions_to_remove": reaction},
            showURLs=showURLs,
            dryRun=dryRun,
        )

    def upload_threat_descriptor(self, postParams, showURLs, dryRun):
        """
        Does a single POST to the threat_descriptors endpoint. See also
        https://developers.facebook.com/docs/threat-exchange/reference/submitting

        Returns [errorMessage, exception, responseBody]; a client-side
        validation failure yields [errorMessage, None, None].
        """
        errorMessage = self._validate_post_params_for_submit(postParams)
        if errorMessage is not None:
            return [errorMessage, None, None]

        url = "/".join(
            (self._base_url, "threat_descriptors", f"?access_token={self.api_token}")
        )
        return self._postThreatDescriptor(url, postParams, showURLs, dryRun)

    def copy_threat_descriptor(self, postParams, showURLs, dryRun):
        """
        Copies an existing descriptor: fetches the descriptor named by
        postParams["descriptor_id"], overlays the remaining postParams on top
        of it, and uploads the result as a new descriptor.

        Returns [errorMessage, exception, responseBody]. The caller's
        `postParams` dict is not modified.
        """
        errorMessage = self._validate_post_pararms_for_copy(postParams)
        if errorMessage is not None:
            return [errorMessage, None, None]

        # Work on a copy so the caller's dict is left untouched.
        overrides = dict(postParams)
        # descriptor_id identifies the source; it is not valid for posting a
        # new descriptor.
        sourceID = overrides.pop("descriptor_id")

        # Get source descriptor
        sourceDescriptors = self.get_threat_descriptors([sourceID], showURLs=showURLs)
        if not sourceDescriptors:
            return ["Source descriptor %s not found." % sourceID, None, None]

        # Mutate necessary fields
        newDescriptor = copy.deepcopy(sourceDescriptors[0])
        newDescriptor["indicator"] = newDescriptor.pop("raw_indicator")
        if newDescriptor.get("tags") is None:
            newDescriptor.pop("tags", None)

        # The shape is different between the copy-from data (mapping app IDs to
        # reactions) and the post data (just a comma-delimited string of owner-app
        # reactions).
        newDescriptor.pop("reactions", None)

        # Start from the source-descriptor values and let any fields supplied
        # by the caller take precedence: dict.update() overwrites existing
        # keys with the values from its argument.
        #
        # Example:
        # * x = {'a': 1, 'b': 2, 'c': 3}
        # * y = {'a': 1, 'b': 9, 'd': 12}
        # * After x.update(y), y is unchanged and x is
        #   {'a': 1, 'b': 9, 'c': 3, 'd': 12}
        newDescriptor.update(overrides)

        # Get rid of fields like last_updated from the source descriptor which
        # aren't valid for post
        uploadParams = {
            key: value
            for key, value in newDescriptor.items()
            if key in self._POST_PARAM_NAMES
        }

        return self.upload_threat_descriptor(uploadParams, showURLs, dryRun)

    def delete_threat_descriptor(
        self, descriptor_id, showURLs, dryRun
    ) -> t.List[t.Any]:
        """
        Does a DELETE on the given descriptor.

        Returns [errorMessage, exception, responseBody]. On an HTTP error
        response, `exception` is the requests HTTPError and `responseBody` is
        the decoded error payload.
        """
        url = f"{self._base_url}/{descriptor_id}?access_token={self.api_token}"
        if showURLs:
            print()
            print("(DELETE) URL:")
            print(url)
        if dryRun:
            print("Not doing DELETE since --dry-run.")
            return [None, None, ""]
        return self._send_mutation("DELETE", url)

    def _postThreatDescriptor(self, url, postParams, showURLs, dryRun):
        """
        Code-reuse for submit and update.

        The post parameters are appended to the URL's query string and also
        sent as the form-encoded request body. Returns
        [errorMessage, exception, responseBody].
        """
        url += "".join(
            "&%s=%s" % (key, urllib.parse.quote(str(value)))
            for key, value in postParams.items()
        )
        if showURLs:
            print()
            print("(POST) URL:")
            print(url)
        if dryRun:
            print("Not doing POST since --dry-run.")
            return [None, None, ""]

        # Encode the inputs to the POST: form-encoded string, then bytes.
        data = urllib.parse.urlencode(postParams).encode("ascii")
        return self._send_mutation("POST", url, data)

    def _send_mutation(self, method, url, data=None) -> t.List[t.Any]:
        """
        Send a state-changing request and return [None, exception, body].

        requests does not raise on 4xx/5xx by itself, so the status is checked
        explicitly; an HTTP error is returned (not raised) together with the
        decoded error payload.
        """
        try:
            with self._get_session() as session:
                response = session.request(method, url, data=data)
                response.raise_for_status()
                return [None, None, response.json()]
        except requests.exceptions.HTTPError as e:
            return [None, e, self._decode_error_body(e.response)]

    @staticmethod
    def _decode_error_body(response):
        """Return the JSON payload of an error response, or its raw text."""
        if response is None:
            return None
        try:
            return response.json()
        except ValueError:
            # Not JSON (e.g. an HTML error page from a proxy).
            return response.text

    def get_threat_descriptors_from_indicator(
        self, indicator_id: int, showURLs: bool = False
    ) -> t.List[t.Dict[str, t.Any]]:
        """Returns the descriptors attached to the given indicator."""
        url = (
            f"{self._base_url}/{indicator_id}"
            f"?fields=descriptors&access_token={self.api_token}"
        )
        if showURLs:
            print("url =", url)
        response = self.get_json_from_url(url)
        return response["descriptors"]["data"]