#! /usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Storm/stress test a deployed instance of HMA
Besides deploying you will likely want a place to set off the test (ec2 instance)
Run
- Submit content matching `media`, `count` times using `submit_method`
Example usage for each submission type
```
$ hmacli storm -m "<filepath>" -c 100 -v upload
$ hmacli storm -m "<filepath>" -c 100 -v bytes
$ hmacli storm -m "<signed-url>" -c 100 -v url
$ hmacli storm -m "<bucket>:<key>" -c 100 -v s3
$ hmacli storm -m "<pdq-hash>" -c 100 -v hash
```
"""
import os
import sys
import base64
import argparse
import uuid
import datetime
import requests
import urllib3
import concurrent.futures
import functools
import socket
import typing as t
from time import perf_counter
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
import hmalib.scripts.cli.command_base as base
import hmalib.scripts.common.utils as utils
# "/" does not work with react router preventing the content submission details from rendering.
# However it can be used here for easier clean up between storms that do not need the UI.
ID_SEPARATOR = "-"
# Value to incorperate into content_id as an attempt to get greater prefix optimizations (ymmv).
# bucket_prefix_known = int(socket.gethostname()[-2:]) % 5
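

# Cache the encoded contents so repeated submissions of the same file read and
# base64-encode it only once.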
@functools.lru_cache(maxsize=None)
def _get_b64_contents(filepath):
with open(filepath, "rb") as file:
return str(base64.b64encode(file.read()), "utf-8")


class StormCommand(base.Command, base.NeedsAPIAccess):
    """
    Start a storm/stress test on a deployed HMA instance. Will trigger lambdas
    and cost some $$.
    """

@classmethod
def init_argparse(cls, ap: argparse.ArgumentParser) -> None:
ap.add_argument(
"--media",
"-m",
help="Media to be submitted based on method. e.g. file path of photo to upload (upload/bytes), url, hash, or bucket:key (s3).",
required=True,
)
ap.add_argument(
"submit_method",
choices=["upload", "bytes", "url", "s3", "hash", "sns-s3", "sns-url"],
help="Method of submission to use in storm",
default="upload",
)
ap.add_argument(
"--count",
"-c",
help="Number of times to submit media.",
default=1,
)
ap.add_argument(
"--verbose",
"-v",
action="store_true",
help="Print progress and percentile times for submission sets.",
)
ap.add_argument(
"--retry",
"-r",
action="store_true",
help="Have request retry (with exponential backoff) if errors are encoutered.",
)
ap.add_argument(
"--sns_topic",
help="ARN of SNS Topic to submit to, required when submit_method=[sns-s3,sns-url]",
default="",
)

    @classmethod
    def get_name(cls) -> str:
        """The display name of the command"""
        return "storm"

    @classmethod
    def get_help(cls) -> str:
        """The short help of the command"""
        return "run a storm/stress test"

    def __init__(
self,
media: str,
submit_method: str,
count: str,
verbose: bool = False,
retry: bool = False,
sns_topic: str = "",
) -> None:
self.media = media
self.submit_method = submit_method
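        # argparse hands count through as a string (no type= given), so coerce.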
self.count = int(count)
self.verbose = verbose
self.retry = retry
self.sns_topic = sns_topic
# hostnames are especially useful when storming from more than one instance.
self.hostname = socket.gethostname()

    def execute(self, api: utils.HasherMatcherActionerAPI) -> None:
sent_message_count = 0
jobs = []
execution_times = []
if self.retry:
retry_strategy = Retry(
total=5,
backoff_factor=1,
status_forcelist=[429, 500, 502, 503, 504],
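                # NOTE: urllib3 >= 1.26 renames `method_whitelist` to
                # `allowed_methods`; use that name on newer versions.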
method_whitelist=[
"GET",
"PUT",
"POST",
], # Including POST as the API should not perform an insert if an error is returned
)
api.add_transport_adapter(HTTPAdapter(max_retries=retry_strategy))
if self.submit_method in ["sns-s3", "sns-url"] and not self.sns_topic:
print(
f"submit_method: {self.submit_method} require's the sns topic arn be provided to `--sns_topic`"
)
sys.exit(1)

        # Need to compute script-level requests per second (QPS) so that we can
        # estimate benchmark performance. For that, store the start time every
        # 200 requests and report the QPS between then and now.
with concurrent.futures.ThreadPoolExecutor(max_workers=48) as executor:
if self.verbose:
print("Started adding tasks to executor")
chunk_start_time_200 = perf_counter()
while sent_message_count < self.count:
jobs.append(
executor.submit(
self._submit,
api,
)
)
sent_message_count += 1
if self.verbose:
print(f"{sent_message_count} requests prepared", end="\r")
if self.verbose:
print("\nDone adding tasks to executor")
for i, completed_future in enumerate(concurrent.futures.as_completed(jobs)):
execution_times.append(completed_future.result())
if self.verbose:
                    progress_report_string = f"{i + 1} of {self.count} sent"
                    rps_report_string = f"Current 200 chunk has QPS of {(i % 200) / (perf_counter() - chunk_start_time_200):.2f}"
                    # Report progress and RPS
                    print(f"{progress_report_string}! {rps_report_string}", end="\r")
if i % 200 == 0:
# Reset chunk start time
chunk_start_time_200 = perf_counter()
if self.verbose:
print(f"\nSent all {self.count} submissions.")
        # Compute some basic percentile stats.
execution_times = sorted(execution_times)
print(
f"""Percentiles in ms:
p75: {execution_times[int(len(execution_times)*0.75)]}
p95: {execution_times[int(len(execution_times)*0.95)]}
p99: {execution_times[int(len(execution_times)*0.99)]}
"""
)

    def _get_submission_id(self, file_name: str = "", sep=ID_SEPARATOR):
        if file_name:
            file_name = f"{sep}{file_name}"
        return f"storm{sep}{self.submit_method}{sep}{self.hostname}{sep}{datetime.date.today().isoformat()}{sep}{uuid.uuid4()}{file_name}"

    def _submit(
self,
api: utils.HasherMatcherActionerAPI,
) -> int:
"""
Submit a single time (method depends on `submit_method`) and return the time it took in ms.
"""
file_name = ""
if self.submit_method in ["upload", "bytes"]:
# get filename from filepath
file_name = os.path.split(self.media)[-1]
content_id = self._get_submission_id(file_name)
additional_fields: t.List[str] = []
start_time = perf_counter()
try:
if self.submit_method == "upload":
with open(self.media, "rb") as file:
api.submit_via_upload_put_url(
content_id=content_id,
file=file,
additional_fields=additional_fields,
)
elif self.submit_method == "bytes":
api.submit_via_encoded_bytes(
content_id=content_id,
b64_file_contents=_get_b64_contents(self.media),
additional_fields=additional_fields,
)
elif self.submit_method == "url":
api.submit_via_external_url(
content_id=content_id,
url=self.media,
additional_fields=additional_fields,
)
elif self.submit_method == "hash":
api.submit_via_hash(
content_id=content_id,
signal_value=self.media,
additional_fields=additional_fields,
)
elif self.submit_method == "s3":
bucket, key = self.media.split(":", 1)
api.submit_via_s3_object(
content_id=content_id,
bucket_name=bucket,
object_key=key,
additional_fields=additional_fields,
)
elif self.submit_method == "sns-s3":
bucket, key = self.media.split(":", 1)
api.sns_submit_via_s3_object(
submit_topic_arn=self.sns_topic,
content_id=content_id,
bucket_name=bucket,
object_key=key,
additional_fields=additional_fields,
)
elif self.submit_method == "sns-url":
api.sns_submit_via_external_url(
submit_topic_arn=self.sns_topic,
content_id=content_id,
url=self.media,
additional_fields=additional_fields,
)
        # API methods call raise_for_status, so errors surface as exceptions.
except (
urllib3.exceptions.MaxRetryError,
requests.exceptions.HTTPError,
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
requests.exceptions.RequestException,
) as err:
print("Error:", err)
        # Convert seconds to milliseconds.
return int((perf_counter() - start_time) * 1000)