hasher-matcher-actioner/hmalib/common/s3_adapters.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import collections
import csv
import functools
import warnings
import io
import json
import boto3
import codecs
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass, field
import typing as t
from botocore.errorfactory import ClientError
from mypy_boto3_s3 import Client as S3Client
from mypy_boto3_s3.service_resource import Bucket
from threatexchange import threat_updates as tu
from threatexchange.cli.dataset.simple_serialization import HMASerialization
from threatexchange.descriptor import SimpleDescriptorRollup, ThreatDescriptor
from threatexchange.signal_type.signal_base import SignalType
from threatexchange.signal_type.md5 import VideoMD5Signal
from threatexchange.signal_type.pdq import PdqSignal
from hmalib.common.models.signal import (
ThreatExchangeSignalMetadata,
PendingThreatExchangeOpinionChange,
)
from hmalib import metrics
from hmalib.common.logging import get_logger
from hmalib.indexers.metadata import (
BaseIndexMetadata,
ThreatExchangeIndicatorIndexMetadata,
THREAT_EXCHANGE_SOURCE_SHORT_CODE,
)
logger = get_logger(__name__)
@functools.lru_cache(maxsize=None)
def get_dynamodb():
return boto3.resource("dynamodb")
@functools.lru_cache(maxsize=None)
def get_s3_client():
return boto3.client("s3")
# A hash row is a tuple of the hash_value and a sequence of metadata object.
# Remember, a single hash may have multiple metadata objects. One from
# threatexchange, one from banks, another one from threatexchange, but a
# different privacy group.
#
# Use mutable sequence instead of list here because we need subclasses of
# BaseIndexMetadata and t.List is not co-variant in mypy's grammar.
HashRowT = t.Tuple[str, t.MutableSequence[BaseIndexMetadata]]
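# Illustrative sketch only (the hash and indicator_id below are hypothetical
# placeholders): a HashRowT for a hash seen in a single ThreatExchange privacy
# group could look like
#
#   (
#       "facd8b...c410",  # truncated placeholder for the hash value
#       [
#           ThreatExchangeIndicatorIndexMetadata(
#               indicator_id="123456789",
#               signal_value="facd8b...c410",
#               privacy_group="258601789084078",
#           )
#       ],
#   )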
# Signal types that s3_adapters should support
KNOWN_SIGNAL_TYPES: t.List[t.Type[SignalType]] = [VideoMD5Signal, PdqSignal]
@dataclass
class S3ThreatDataConfig:
SOURCE_STR = THREAT_EXCHANGE_SOURCE_SHORT_CODE
threat_exchange_data_bucket_name: str
threat_exchange_data_folder: str
@dataclass
class ThreatExchangeS3Adapter:
"""
Adapter for reading ThreatExchange data stored in S3. Concrete implementations
are for a specific indicator type such as PDQ
Assumes CSV file format
Should probably refactor and merge with ThreatUpdateS3Store for writes
"""
metrics_logger: metrics.lambda_with_datafiles
S3FileT = t.Dict[str, t.Any]
config: S3ThreatDataConfig
last_modified: t.Dict[str, str] = field(default_factory=dict)
def load_data(self) -> t.Dict[str, t.List[HashRowT]]:
"""
loads all data from all files in TE that are of the concrete implementations indicator type
returns a mapping from file name to list of rows
"""
logger.info("Retreiving %s Data from S3", self.file_type_str_name)
with metrics.timer(self.metrics_logger.download_datafiles):
            # S3 doesn't have a built-in concept of folders, but the AWS UI
            # implements folder-like functionality using prefixes. We follow
            # the same convention here by using the folder name in a prefix search.
s3_bucket_files = get_s3_client().list_objects_v2(
Bucket=self.config.threat_exchange_data_bucket_name,
Prefix=self.config.threat_exchange_data_folder,
)["Contents"]
logger.info("Found %d Files", len(s3_bucket_files))
typed_data_files = {
file["Key"]: self._get_file(file["Key"])
for file in s3_bucket_files
if file["Key"].endswith(self.indicator_type_file_extension)
}
logger.info(
"Found %d %s Files", len(typed_data_files), self.file_type_str_name
)
with metrics.timer(self.metrics_logger.parse_datafiles):
logger.info("Parsing %s Hash files", self.file_type_str_name)
typed_data = {
file_name: self._parse_file(**typed_data_file)
for file_name, typed_data_file in typed_data_files.items()
}
return typed_data
@property
def indicator_type_file_extension(self):
"""
What is the extension for files of this indicator type
eg. hash_pdq.te indicates PDQ files
"""
raise NotImplementedError()
@property
def file_type_str_name(self):
"""
What types of files does the concrete implementation correspond to
for logging only
"""
raise NotImplementedError()
@property
def indicator_type_file_columns(self):
"""
What are the csv columns when this type of data is stored in S3
"""
raise NotImplementedError()
def _get_file(self, file_name: str) -> t.Dict[str, t.Any]:
return {
"file_name": file_name,
"data_file": get_s3_client().get_object(
Bucket=self.config.threat_exchange_data_bucket_name, Key=file_name
),
}
def _parse_file(self, file_name: str, data_file: S3FileT) -> t.List[HashRowT]:
data_reader = csv.DictReader(
codecs.getreader("utf-8")(data_file["Body"]),
fieldnames=self.indicator_type_file_columns,
)
self.last_modified[file_name] = data_file["LastModified"].isoformat()
privacy_group = file_name.split("/")[-1].split(".")[0]
result: t.List[HashRowT] = []
for row in data_reader:
metadata = ThreatExchangeIndicatorIndexMetadata(
indicator_id=row["indicator_id"],
signal_value=row["hash"],
privacy_group=privacy_group,
)
if row["tags"]:
# note: these are the labels assigned by pytx in descriptor.py (NOT a 1-1 with tags on TE)
metadata.tags.update(row["tags"].split(" "))
result.append((row["hash"], [metadata]))
return result
class ThreatExchangeS3PDQAdapter(ThreatExchangeS3Adapter):
"""
    Adapter for reading ThreatExchange PDQ data stored in CSV files in S3.
"""
@property
def indicator_type_file_extension(self):
return f"{PdqSignal.INDICATOR_TYPE.lower()}.te"
@property
def indicator_type_file_columns(self):
return ["hash", "indicator_id", "descriptor_id", "timestamp", "tags"]
@property
def file_type_str_name(self):
return "PDQ"
class ThreatExchangeS3VideoMD5Adapter(ThreatExchangeS3Adapter):
"""
Read ThreatExchange Video MD5 files in CSV from S3.
"""
@property
def indicator_type_file_extension(self):
        # Hardcode because of the indicator_type migration. This is extra weird
        # because adapters do not write data, they only read data. Once datafile
        # read and write are both done via s3_adapters, this should no longer be
        # necessary.
        return "hash_video_md5.te"
@property
def indicator_type_file_columns(self):
return ["hash", "indicator_id", "descriptor_id", "timestamp", "tags"]
@property
def file_type_str_name(self):
return "MD5"
class ThreatUpdateS3Store(tu.ThreatUpdatesStore):
"""
ThreatUpdatesStore, but stores files in S3 instead of local filesystem.
"""
CHECKPOINT_SUFFIX = ".checkpoint"
def __init__(
self,
privacy_group: int,
app_id: int,
s3_client: S3Client,
s3_bucket_name: str,
s3_te_data_folder: str,
data_store_table: str,
        supported_signal_types: t.List[t.Type[SignalType]],
) -> None:
super().__init__(privacy_group)
self.app_id = app_id
self._cached_state: t.Optional[t.Dict] = None
self.s3_te_data_folder = s3_te_data_folder
self.data_store_table = data_store_table
self.supported_indicator_types = self._get_supported_indicator_types(
supported_signal_types
)
self.s3_client = s3_client
self.s3_bucket_name = s3_bucket_name
@classmethod
def indicator_type_str_from_signal_type(
cls, signal_type: t.Type[SignalType]
    ) -> t.Optional[str]:
"""
This mapping is only necessary for types that are in the process of
        being migrated, e.g. VideoMD5.
"""
if signal_type == VideoMD5Signal:
return "HASH_VIDEO_MD5"
return getattr(signal_type, "INDICATOR_TYPE", None)
def _get_supported_indicator_types(
self, supported_signal_types: t.List[t.Type[SignalType]]
):
"""
For supported self.signal_types, get their corresponding indicator_types.
"""
indicator_types = []
for signal_type in supported_signal_types:
indicator_type = self.indicator_type_str_from_signal_type(signal_type)
if indicator_type:
indicator_types.append(indicator_type)
else:
warnings.warn(
f"SignalType: {signal_type} does not provide an indicator type."
)
return indicator_types
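    # For example, assuming PdqSignal.INDICATOR_TYPE is "HASH_PDQ" (which the
    # "hash_pdq.te" extension used by ThreatExchangeS3PDQAdapter implies), the
    # mapping is roughly:
    #   _get_supported_indicator_types([PdqSignal, VideoMD5Signal])
    #   -> ["HASH_PDQ", "HASH_VIDEO_MD5"]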
@property
def checkpoint_s3_key(self) -> str:
return f"{self.s3_te_data_folder}{self.privacy_group}{self.CHECKPOINT_SUFFIX}"
def get_privacy_group_prefix(self) -> str:
"""
Gets the prefix for all data files for self.privacy_group. Note that the
        '.' is necessary: otherwise, when privacy group ids are like 123 and
        1234, a list_objects() call for 123 would also return the objects for
        1234.
"""
return f"{self.s3_te_data_folder}{self.privacy_group}."
    def get_s3_object_key(self, indicator_type: str) -> str:
"""
For self.privacy_group, creates an s3_key that stores data for
`indicator_type`. If changing, be mindful to change
get_signal_type_from_object_key() as well.
"""
extension = f"{indicator_type.lower()}.te"
return f"{self.get_privacy_group_prefix()}{extension}"
@classmethod
def get_signal_type_from_object_key(
cls, key: str
) -> t.Optional[t.Type[SignalType]]:
"""
        Inverse of get_s3_object_key. Given an object key (potentially
        generated by this class), extracts the extension, compares it against
        known signal_types to see if any of them have the same indicator_type,
        and returns that signal_type.
"""
        # Given s3://<foo_bucket>/threat_exchange_data/258601789084078.hash_pdq.te,
        # ".te" and everything other than "hash_pdq" can be ignored.
try:
_, extension, _ = key.rsplit(".", 2)
except ValueError:
# key does not meet the structure necessary. Impossible to determine
# signal_type
return None
for signal_type in KNOWN_SIGNAL_TYPES:
if signal_type.INDICATOR_TYPE.lower() == extension:
return signal_type
# Hardcode for HASH_VIDEO_MD5 because threatexchange's VideoMD5 still
# has HASH_MD5 as indicator_type
if extension == "hash_video_md5":
return VideoMD5Signal
return None
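    # Round-trip sketch, assuming a store with privacy_group=258601789084078
    # and s3_te_data_folder="threat_exchange_data/" (values borrowed from the
    # comment above):
    #   store.get_s3_object_key("HASH_PDQ")
    #     -> "threat_exchange_data/258601789084078.hash_pdq.te"
    #   ThreatUpdateS3Store.get_signal_type_from_object_key(
    #       "threat_exchange_data/258601789084078.hash_pdq.te"
    #   )
    #     -> PdqSignal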
@property
def next_delta(self) -> tu.ThreatUpdatesDelta:
"""
IF YOU CHANGE SUPPORTED_SIGNALS, OLD CHECKPOINTS NEED TO BE INVALIDATED
TO GET THE NON-PDQ DATA!
"""
delta = super().next_delta
delta.types = self.supported_indicator_types
return delta
def reset(self):
super().reset()
self._cached_state = None
def _load_checkpoint(self) -> tu.ThreatUpdateCheckpoint:
"""Load the state of the threat_updates checkpoints from state directory"""
txt_content = read_s3_text(
self.s3_client, self.s3_bucket_name, self.checkpoint_s3_key
)
if txt_content is None:
logger.warning("No s3 checkpoint for %d. First run?", self.privacy_group)
return tu.ThreatUpdateCheckpoint()
checkpoint_json = json.load(txt_content)
ret = tu.ThreatUpdateCheckpoint(
checkpoint_json["last_fetch_time"],
checkpoint_json["fetch_checkpoint"],
)
logger.info(
"Loaded checkpoint for privacy group %d. last_fetch_time=%d fetch_checkpoint=%d",
self.privacy_group,
ret.last_fetch_time,
ret.fetch_checkpoint,
)
return ret
def _store_checkpoint(self, checkpoint: tu.ThreatUpdateCheckpoint) -> None:
txt_content = io.StringIO()
json.dump(
{
"last_fetch_time": checkpoint.last_fetch_time,
"fetch_checkpoint": checkpoint.fetch_checkpoint,
},
txt_content,
indent=2,
)
write_s3_text(
s3_client=self.s3_client,
bucket_name=self.s3_bucket_name,
key=self.checkpoint_s3_key,
txt_content=txt_content,
)
logger.info(
"Stored checkpoint for privacy group %d. last_fetch_time=%d fetch_checkpoint=%d",
self.privacy_group,
checkpoint.last_fetch_time,
checkpoint.fetch_checkpoint,
)
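    # The stored checkpoint is a small JSON document. For a hypothetical run it
    # might look like:
    #   {
    #     "last_fetch_time": 1627757265,
    #     "fetch_checkpoint": 1627757000
    #   }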
def _get_datafile_object_keys(self) -> t.Iterable[str]:
"""
Returns all non-checkpoint datafile objects for the current privacy group.
"""
return [
item["Key"]
for item in self.s3_client.list_objects_v2(
Bucket=self.s3_bucket_name, Prefix=self.get_privacy_group_prefix()
)["Contents"]
if not item["Key"].endswith(self.CHECKPOINT_SUFFIX)
]
def load_state(self, allow_cached=True):
if not allow_cached or self._cached_state is None:
# First, get a list of all files
all_datafile_keys = self._get_datafile_object_keys()
items = []
# Then for each datafile, append to items
for datafile in all_datafile_keys:
txt_content = read_s3_text(
self.s3_client, self.s3_bucket_name, datafile
)
                signal_type = self.get_signal_type_from_object_key(datafile)
                indicator_type = (
                    self.indicator_type_str_from_signal_type(signal_type)
                    if signal_type
                    else None
                )
                if txt_content is None:
                    logger.warning("No TE state for %d. First run?", self.privacy_group)
                elif indicator_type is None:
                    logger.warning(
                        "Could not identify indicator type for datafile: %s. Will not process.",
                        datafile,
                    )
else:
csv.field_size_limit(65535) # dodge field size problems
for row in csv.reader(txt_content):
items.append(
HMASerialization(
row[0],
indicator_type,
row[1],
SimpleDescriptorRollup.from_row(row[2:]),
)
)
logger.info("%d rows loaded for %d", len(items), self.privacy_group)
# Do all in one assignment just in case of threads
self._cached_state = {item.key: item for item in items}
return self._cached_state
def _store_state(self, contents: t.Iterable["HMASerialization"]):
"""
Stores indicator data in CSV format with one file per indicator type.
"""
row_by_type: t.DefaultDict = collections.defaultdict(list)
for item in contents:
row_by_type[item.indicator_type].append(item)
        # Write one file per indicator type.
        for indicator_type in row_by_type:
items = row_by_type.get(indicator_type, [])
with io.StringIO(newline="") as txt_content:
writer = csv.writer(txt_content)
writer.writerows(item.as_csv_row() for item in items)
write_s3_text(
s3_client=self.s3_client,
bucket_name=self.s3_bucket_name,
key=self.get_s3_object_key(indicator_type),
txt_content=txt_content,
)
logger.info(
"IndicatorType:%s, %d rows stored in PrivacyGroup %d",
indicator_type,
len(items),
self.privacy_group,
)
def _apply_updates_impl(
self,
delta: tu.ThreatUpdatesDelta,
post_apply_fn=lambda x: None,
) -> None:
state: t.Dict = {}
updated: t.Dict = {}
if delta.start > 0:
state = self.load_state()
for update in delta:
item = HMASerialization.from_threat_updates_json(
self.app_id, update.raw_json
)
if update.should_delete:
state.pop(item.key, None)
else:
state[item.key] = item
updated[item.key] = item
self._store_state(state.values())
self._cached_state = state
post_apply_fn(updated)
def get_new_pending_opinion_change(
self, metadata: ThreatExchangeSignalMetadata, new_tags: t.List[str]
):
# Figure out if we have a new opinion about this indicator and clear out a pending change if so
        # python-threatexchange.descriptor.ThreatDescriptor.from_te_json guarantees there is either
# 0 or 1 opinion tags on a descriptor
opinion_tags = ThreatDescriptor.SPECIAL_TAGS
old_opinion = [tag for tag in metadata.tags if tag in opinion_tags]
new_opinion = [tag for tag in new_tags if tag in opinion_tags]
        # If our opinion changed or if our pending change has already happened,
# set the pending opinion change to None, otherwise keep it unchanged
if old_opinion != new_opinion:
return PendingThreatExchangeOpinionChange.NONE
elif (
(
new_opinion == [ThreatDescriptor.TRUE_POSITIVE]
and metadata.pending_opinion_change
== PendingThreatExchangeOpinionChange.MARK_TRUE_POSITIVE
)
or (
new_opinion == [ThreatDescriptor.FALSE_POSITIVE]
and metadata.pending_opinion_change
== PendingThreatExchangeOpinionChange.MARK_FALSE_POSITIVE
)
or (
new_opinion == []
and metadata.pending_opinion_change
== PendingThreatExchangeOpinionChange.REMOVE_OPINION
)
):
return PendingThreatExchangeOpinionChange.NONE
else:
return metadata.pending_opinion_change
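    # Worked example (hypothetical state): if metadata.tags carried
    # ThreatDescriptor.TRUE_POSITIVE and new_tags carries
    # ThreatDescriptor.FALSE_POSITIVE, the opinion changed, so NONE is returned
    # and any pending change is cleared. If the opinion is unchanged and
    # matches what the pending change was asking for (e.g. TRUE_POSITIVE with a
    # pending MARK_TRUE_POSITIVE), the pending change has already happened and
    # NONE is returned as well. Otherwise the existing
    # metadata.pending_opinion_change is passed through untouched.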
def post_apply(self, updated: t.Dict = {}):
"""
        After the fetcher applies an update, check for matches to any of the
        signals in data_store_table and, if found, update their tags.

        TODO: Additionally, if writebacks are enabled for this privacy group,
        write back INGESTED to ThreatExchange.
"""
table = get_dynamodb().Table(self.data_store_table)
for update in updated.values():
row: t.List[str] = update.as_csv_row()
# example row format: ('<raw_indicator>', '<indicator-id>', '<descriptor-id>', '<time added>', '<space-separated-tags>')
            # e.g. ('096a6f9...064f', '10736405276340', '1234567890', '2020-07-31T18:47:45+0000', 'true_positive hma_test')
new_tags = row[4].split(" ") if row[4] else []
metadata = ThreatExchangeSignalMetadata.get_from_signal_and_privacy_group(
table,
int(row[1]), # indicator-id or signal-id
str(self.privacy_group),
)
if metadata:
new_pending_opinion_change = self.get_new_pending_opinion_change(
metadata, new_tags
)
            else:
                # If this is a new indicator without metadata, there is nothing
                # for us to update; skip to the next update.
                continue
metadata.tags = new_tags
metadata.pending_opinion_change = new_pending_opinion_change
# TODO: Combine 2 update functions into single function
if metadata.update_tags_in_table_if_exists(table):
logger.info(
"Updated Signal Tags in DB for indicator id: %s source: %s for privacy group: %d",
row[1],
S3ThreatDataConfig.SOURCE_STR,
self.privacy_group,
)
if metadata.update_pending_opinion_change_in_table_if_exists(table):
logger.info(
"Updated Pending Opinion in DB for indicator id: %s source: %s for privacy group: %d",
row[1],
S3ThreatDataConfig.SOURCE_STR,
self.privacy_group,
)
def read_s3_text(
s3_client: S3Client, bucket_name: str, key: str
) -> t.Optional[io.StringIO]:
byte_content = io.BytesIO()
try:
s3_client.download_fileobj(bucket_name, key, byte_content)
except ClientError as ce:
if ce.response["Error"]["Code"] != "404":
raise
return None
return io.StringIO(byte_content.getvalue().decode())
def write_s3_text(
s3_client: S3Client, bucket_name: str, key: str, txt_content: io.StringIO
) -> None:
byte_content = io.BytesIO(txt_content.getvalue().encode())
s3_client.upload_fileobj(byte_content, bucket_name, key)
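# A minimal round-trip sketch for the helpers above (commented out; the bucket
# and key names are placeholders):
#
#   s3_client = get_s3_client()
#   write_s3_text(
#       s3_client=s3_client,
#       bucket_name="my-hma-te-bucket",
#       key="threat_exchange_data/123.checkpoint",
#       txt_content=io.StringIO('{"last_fetch_time": 0, "fetch_checkpoint": 0}'),
#   )
#   restored = read_s3_text(
#       s3_client, "my-hma-te-bucket", "threat_exchange_data/123.checkpoint"
#   )
#   # read_s3_text returns None if the key does not exist (404); otherwise a StringIO.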