hasher-matcher-actioner/hmalib/scripts/common/utils.py (409 lines of code) (raw):

#! /usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Utilities for hma scripts
"""

import os
import sys
import json
import base64
import requests
import boto3
import subprocess
import functools
import typing as t

from requests.adapters import HTTPAdapter
from urllib.parse import urljoin
from botocore.exceptions import ClientError

# Defaults to override values in tf outputs or ENV
API_URL = ""
TOKEN = ""


def _fields_or_empty(additional_fields: t.Optional[t.List[str]]) -> t.List[str]:
    """
    Return `additional_fields` unchanged, or a fresh empty list when it is None.

    Lets the submit methods default `additional_fields` to None instead of a
    shared mutable `[]` while still sending a JSON list in the payload.
    """
    return [] if additional_fields is None else additional_fields


class HasherMatcherActionerAPI:
    """
    Class for interfacing with a subset HMA API endpoints in python scripts.

    Supports future and existing storm, soak, and smoke test as well as being
    useful for debugging outside of the HMA UI and AWS Console.

    See hmalib/lambdas/api/
    """

    def __init__(
        self,
        api_url: str,
        api_token: str,
        transport_adapter: t.Optional[HTTPAdapter] = None,
    ) -> None:
        """
        Create a session that sends JSON requests with `api_token` as the
        Authorization header.

        `api_url` is the base every `api_path` is joined onto (via urljoin, so
        it should end with "/" for relative paths to be appended rather than
        replace the last path segment). If `transport_adapter` is given it is
        mounted on the session for `api_url`.
        """
        self.api_url = api_url
        self.session = requests.Session()
        self.session.headers.update(
            {
                "content-type": "application/json",
                "Authorization": api_token,
            }
        )
        if transport_adapter:
            self.add_transport_adapter(transport_adapter)

    def add_transport_adapter(self, transport_adapter: HTTPAdapter):
        """Mount `transport_adapter` on the session for URLs under `api_url`."""
        self.session.mount(self.api_url, transport_adapter)

    def _get_request_url(self, api_path: str) -> str:
        """Join `api_path` onto the configured `api_url`."""
        return urljoin(self.api_url, api_path)

    def _post_json(self, api_path: str, payload: t.Dict[str, t.Any]):
        """
        POST `payload` as UTF-8 encoded JSON to `api_path`.

        Returns the response without checking its status; callers decide
        whether to call raise_for_status().
        """
        return self.session.post(
            self._get_request_url(api_path),
            data=json.dumps(payload).encode("utf-8"),
        )

    def get(self, api_path: str = "root/"):
        """GET `api_path` and return the decoded JSON body."""
        response = self.session.get(self._get_request_url(api_path))
        return response.json()

    def get_all_matches(self, api_path: str = "matches/"):
        """
        GET `api_path` and return the "match_summaries" entry of the JSON
        body ([] when absent). Raises requests.HTTPError on an error status.
        """
        response = self.session.get(self._get_request_url(api_path))
        response.raise_for_status()
        return response.json().get("match_summaries", [])

    def submit_via_encoded_bytes(
        self,
        content_id: str,
        b64_file_contents: str,
        additional_fields: t.Optional[t.List[str]] = None,
    ):
        """
        Submit photo content to "submit/bytes/" with the file contents passed
        inline as `b64_file_contents`. Raises requests.HTTPError on an error
        status.
        """
        payload = {
            "content_id": content_id,
            "content_type": "photo",
            "additional_fields": _fields_or_empty(additional_fields),
            "content_bytes": b64_file_contents,
        }
        response = self._post_json("submit/bytes/", payload)
        response.raise_for_status()

    def submit_via_upload_put_url(
        self,
        content_id: str,
        file: t.BinaryIO,
        additional_fields: t.Optional[t.List[str]] = None,
    ):
        """
        Submit photo content by requesting a presigned URL from
        "submit/put-url/" and then PUT-ing `file` to it as image/jpeg.

        Raises requests.HTTPError if either request returns an error status.
        """
        payload = {
            "content_id": content_id,
            "content_type": "photo",
            "additional_fields": _fields_or_empty(additional_fields),
            "file_type": "image/jpeg",
        }
        response = self._post_json("submit/put-url/", payload)
        response.raise_for_status()

        # The upload uses plain requests.put rather than self.session, so the
        # session's default headers (including Authorization) are not sent to
        # the presigned URL.
        put_response = requests.put(
            response.json()["presigned_url"],
            data=file,
            headers={"content-type": "image/jpeg"},
        )
        put_response.raise_for_status()

    def submit_via_external_url(
        self,
        content_id: str,
        url: str,
        additional_fields: t.Optional[t.List[str]] = None,
    ):
        """
        Distinct from submit_via_upload_put_url(), It uses the URL only path
        and bypasses s3 completely.

        Raises requests.HTTPError on an error status.
        """
        payload = {
            "content_id": content_id,
            "content_type": "photo",
            "additional_fields": _fields_or_empty(additional_fields),
            "content_url": url,
        }
        response = self._post_json("submit/url/", payload)
        response.raise_for_status()

    def submit_via_s3_object(
        self,
        content_id: str,
        bucket_name: str,
        object_key: str,
        content_type: str = "photo",
        additional_fields: t.Optional[t.List[str]] = None,
    ):
        """
        Submit to the API using a s3 object that api_root is authorized to
        read from.

        Raises requests.HTTPError on an error status.
        """
        payload = {
            "content_id": content_id,
            "content_type": content_type,
            "additional_fields": _fields_or_empty(additional_fields),
            "bucket_name": bucket_name,
            "object_key": object_key,
        }
        response = self._post_json("submit/s3/", payload)
        response.raise_for_status()

    def submit_via_hash(
        self,
        content_id: str,
        signal_value: str,
        signal_type: str = "pdq",
        content_type: str = "photo",
        url="",
        additional_fields: t.Optional[t.List[str]] = None,
    ):
        """
        Submit hash directly.

        Raises requests.HTTPError on an error status.
        """
        payload = {
            "content_id": content_id,
            "content_type": content_type,
            "additional_fields": _fields_or_empty(additional_fields),
            "signal_value": signal_value,
            "signal_type": signal_type,
            "content_url": url,
        }
        response = self._post_json("submit/hash/", payload)
        response.raise_for_status()

    @classmethod
    def sns_submit_via_external_url(
        cls,
        submit_topic_arn: str,
        content_id: str,
        url: str,
        additional_fields: t.Optional[t.List[str]] = None,
    ):
        """
        Distinct from submit_via_upload_put_url(), It uses the URL only path
        and bypasses s3 completely and submits via an SNS topic.

        A new boto3 SNS client is created on every call.
        """
        payload = {
            "content_id": content_id,
            "content_type": "photo",
            "additional_fields": _fields_or_empty(additional_fields),
            "content_url": url,
        }
        boto3.client("sns").publish(
            TopicArn=submit_topic_arn,
            Message=json.dumps(payload),
        )

    @classmethod
    def sns_submit_via_s3_object(
        cls,
        submit_topic_arn: str,
        content_id: str,
        bucket_name: str,
        object_key: str,
        content_type: str = "photo",
        additional_fields: t.Optional[t.List[str]] = None,
    ):
        """
        Submit to the SNS topic using a s3 object that the submit event
        handler lambda is authorized to read from.

        A new boto3 SNS client is created on every call.
        """
        payload = {
            "content_id": content_id,
            "content_type": content_type,
            "additional_fields": _fields_or_empty(additional_fields),
            "bucket_name": bucket_name,
            "object_key": object_key,
        }
        boto3.client("sns").publish(
            TopicArn=submit_topic_arn,
            Message=json.dumps(payload),
        )

    def get_content_hash_details(
        self,
        content_id: str,
        api_path: str = "content/hash/",
    ):
        """GET `api_path` for `content_id` and return the decoded JSON body."""
        return self.session.get(
            self._get_request_url(api_path),
            params={"content_id": content_id},
        ).json()

    def get_content_action_history(
        self,
        content_id: str,
        api_path: str = "content/action-history/",
    ):
        """Return the "action_history" entry for `content_id` ([] when absent)."""
        response = self.session.get(
            self._get_request_url(api_path),
            params={"content_id": content_id},
        )
        return response.json().get("action_history", [])

    def get_content_matches(
        self,
        content_id: str,
        api_path: str = "matches/match/",
    ):
        """Return the "match_details" entry for `content_id` ([] when absent)."""
        response = self.session.get(
            self._get_request_url(api_path),
            params={"content_id": content_id},
        )
        return response.json().get("match_details", [])

    def get_dataset_configs(
        self,
        api_path: str = "datasets/",
    ):
        """Return the "threat_exchange_datasets" entry of the response ([] when absent)."""
        response = self.session.get(self._get_request_url(api_path))
        return response.json().get("threat_exchange_datasets", [])

    def create_dataset_config(
        self,
        privacy_group_id: str,
        privacy_group_name: str,
        description: str = "",
        matcher_active: bool = True,
        fetcher_active: bool = False,
        write_back: bool = False,
        api_path: str = "datasets/create",
    ):
        """
        POST a dataset config to `api_path` and return the raw response
        (status is not checked here).
        """
        payload = {
            "privacy_group_id": privacy_group_id,
            "privacy_group_name": privacy_group_name,
            "description": description,
            "fetcher_active": fetcher_active,
            "matcher_active": matcher_active,
            "write_back": write_back,
        }
        return self._post_json(api_path, payload)

    def get_actions(
        self,
        api_path: str = "actions/",
    ):
        """Return the "actions_response" entry of the response ([] when absent)."""
        response = self.session.get(self._get_request_url(api_path))
        return response.json().get("actions_response", [])

    def create_action(
        self,
        name: str,
        config_subtype: str,
        fields: t.Dict[str, t.Any],
        api_path: str = "actions/",
    ):
        """
        POST an action definition to `api_path` and return the raw response
        (status is not checked here).
        """
        payload = {
            "name": name,
            "config_subtype": config_subtype,
            "fields": fields,
        }
        return self._post_json(api_path, payload)

    def delete_action(
        self,
        action_name: str,
        api_path: str = "actions/",
    ):
        """
        Send DELETE to `api_path` + `action_name` and return the raw response.

        `action_name` is appended to the path as-is (no URL escaping).
        """
        return self.session.delete(self._get_request_url(api_path + action_name))

    def get_action_rules(
        self,
        api_path: str = "action-rules/",
    ):
        """Return the "action_rules" entry of the response ([] when absent)."""
        response = self.session.get(self._get_request_url(api_path))
        return response.json().get("action_rules", [])

    def create_action_rule(
        self,
        action_rule: t.Any,
        api_path: str = "action-rules/",
    ):
        """
        POST `action_rule` (must be JSON serializable) to `api_path` and
        return the raw response (status is not checked here).
        """
        payload = {
            "action_rule": action_rule,
        }
        return self._post_json(api_path, payload)

    def delete_action_rule(
        self,
        action_rule_name: str,
        api_path: str = "action-rules/",
    ):
        """
        Send DELETE to `api_path` + `action_rule_name` and return the raw
        response.

        `action_rule_name` is appended to the path as-is (no URL escaping).
        """
        return self.session.delete(self._get_request_url(api_path + action_rule_name))

    def get_matches_for_hash(
        self,
        signal_type: str,
        signal_value: str,
        api_path: str = "matches/for-hash/",
    ):
        """Return the "matches" entry for the given signal ([] when absent)."""
        params = {
            "signal_type": signal_type,
            "signal_value": signal_value,
        }
        response = self.session.get(
            self._get_request_url(api_path),
            params=params,
        )
        return response.json().get("matches", [])


def get_terraform_outputs(
    directory: str = "terraform",
) -> t.Dict[str, str]:
    """
    Converts from the super verbose JSON output to a more natural
    string -> string map.

    Runs `terraform output -json` with `directory` as the working directory
    and keeps only each output's "value". Raises
    subprocess.CalledProcessError if the command exits non-zero.
    """
    cmd = ["terraform", "output", "-json"]
    out = json.loads(subprocess.check_output(cmd, cwd=directory))
    return {name: details["value"] for name, details in out.items()}


def get_cached_terraform_outputs(
    path: str = "tmp.out",
) -> t.Dict[str, str]:
    """
    Gets output if not already present at path.

    When `path` does not exist the terraform outputs are fetched, written to
    `path` as JSON, and returned; otherwise the JSON stored at `path` is
    returned without invoking terraform.
    """
    if not os.path.exists(path):
        # Fetch before opening the file: if terraform fails we must not leave
        # an empty cache file behind, which would break every later call.
        outputs = get_terraform_outputs()
        with open(path, "w") as f:
            json.dump(outputs, f)
        return outputs

    with open(path) as f:
        return json.load(f)


@functools.lru_cache(maxsize=None)
def _get_cognito_client():
    """Return a boto3 "cognito-idp" client, created once and then reused."""
    return boto3.client("cognito-idp")


def get_token(
    username: str,
    pwd: str,
    pool_id: str,
    client_id: str,
):
    """
    Authenticate `username` via the ADMIN_USER_PASSWORD_AUTH flow and return
    the full admin_initiate_auth response.
    """
    return _get_cognito_client().admin_initiate_auth(
        AuthFlow="ADMIN_USER_PASSWORD_AUTH",
        AuthParameters={"USERNAME": username, "PASSWORD": pwd},
        UserPoolId=pool_id,
        ClientId=client_id,
    )


def create_user(
    username: str,
    email: str,
    pwd: str,
    pool_id: str,
    client_id: str,
):
    """
    Create `username` in the user pool with a pre-verified email and no
    invitation message, then set `pwd` as its permanent password.

    `client_id` is accepted but not used.
    """
    cognito = _get_cognito_client()
    cognito.admin_create_user(
        UserPoolId=pool_id,
        Username=username,
        UserAttributes=[
            {"Name": "email_verified", "Value": "True"},
            {"Name": "email", "Value": email},
        ],
        ForceAliasCreation=False,
        MessageAction="SUPPRESS",
    )
    cognito.admin_set_user_password(
        UserPoolId=pool_id,
        Username=username,
        Password=pwd,
        Permanent=True,
    )


def delete_user(
    username: str,
    pwd: str,
    pool_id: str,
    client_id: str,
):
    """
    Delete `username` from the user pool; a missing user is not an error.

    `pwd` and `client_id` are accepted but not used. Any ClientError other
    than "user not found" is re-raised.
    """
    try:
        _get_cognito_client().admin_delete_user(
            UserPoolId=pool_id,
            Username=username,
        )
    except ClientError as err:
        # Only "user not found" is expected here; surface anything else
        # (bad pool id, missing permissions, throttling, ...).
        if err.response.get("Error", {}).get("Code") != "UserNotFoundException":
            raise


def get_default_user_name(prefix: str):
    """Return the default test user name for `prefix`."""
    return f"{prefix}testuser"


def get_auth_from_env(
    token_default: str = TOKEN,
    prompt_for_token: bool = False,
):
    """
    Return the access token from the HMA_TOKEN environment variable, falling
    back to `token_default`.

    When no token is available: prompt on stdin if `prompt_for_token` is
    set, otherwise print a message and exit with a non-zero status.
    """
    token = os.environ.get(
        "HMA_TOKEN",
        token_default,
    )

    if not token:
        if prompt_for_token:
            print("Need an access token to authenticate.")
            token = input("Enter token: ")
        else:
            print("Authentication requires HMA_TOKEN be present in ENV.")
            # sys.exit rather than the site-provided exit(), and non-zero so
            # calling scripts can detect the failure.
            sys.exit(1)

    return token


if __name__ == "__main__":
    # If you want hard code tests for methods you can do so here:

    # Testing HasherMatcherActionerAPI
    # Since the API requires a deployed instance the majority of
    # testing needs to be done manually. See below.

    print("Manual Test of API Request Methods:")

    tf_outputs = get_terraform_outputs()
    api_url = tf_outputs["api_url"]

    token = get_auth_from_env()

    api = HasherMatcherActionerAPI(
        api_url,
        token,
    )

    # e.g. if auth is correct the following command should print:
    # "{'message': 'Hello World, HMA'}"
    print(api.get())