python-threatexchange/threatexchange/fb_threatexchange/threat_updates.py (258 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Helpers and wrappers around the /threat_updates endpoint.
"""
import json
import os
import pathlib
import time
import typing as t
from dataclasses import dataclass
from .api import ThreatExchangeAPI, _CursoredResponse
from .descriptor import SimpleDescriptorRollup
class ThreatUpdateSerialization:
    """
    A wrapper for converting records fetched from /threat_updates

    Subclasses convert a single raw /threat_updates record into a local
    representation. For use with ThreatUpdateFileStore they must also be able
    to persist a collection of those records to a state directory (store)
    and read them back (load).
    """

    @property
    def key(self):
        """This should either be the indicator type+string or id"""
        raise NotImplementedError

    @classmethod
    def from_threat_updates_json(cls, app_id: int, te_json):
        """Convert one raw /threat_updates record (te_json) into this serialization"""
        raise NotImplementedError

    @classmethod
    def te_threat_updates_fields(cls):
        """Which &fields= arguments need to be passed for this serialization"""
        raise NotImplementedError

    @classmethod
    def store(
        cls,
        state_dir: pathlib.Path,
        contents: t.Iterable["ThreatUpdateSerialization"],
    ):
        """Persist `contents` (the complete current state) under state_dir"""
        raise NotImplementedError

    @classmethod
    def load(cls, state_dir: pathlib.Path) -> t.Iterable["ThreatUpdateSerialization"]:
        """Load the records previously written by store() from state_dir"""
        raise NotImplementedError
@dataclass
class ThreatUpdateJSON(ThreatUpdateSerialization):
    """A thin wrapper around the /threat_updates API return"""

    raw_json: t.Dict[str, t.Any]

    @property
    def should_delete(self) -> bool:
        """This record is a tombstone, and we should delete our copy"""
        # This should just be should_delete only, but see
        # https://github.com/facebook/ThreatExchange/issues/834
        # .get(): a record that omits the flag is not a KeyError, just "not flagged"
        return bool(self.raw_json.get("should_delete")) or (
            "descriptors" not in self.raw_json
        )

    @property
    def key(self) -> int:
        """Records of this serialization are keyed by their id"""
        return self.id

    @property
    def id(self) -> int:
        return int(self.raw_json["id"])

    @property
    def indicator(self) -> str:
        return self.raw_json["indicator"]

    @property
    def threat_type(self) -> str:
        return self.raw_json["type"]

    @property
    def time(self) -> int:
        """The time of the update"""
        return int(self.raw_json["last_updated"])

    @classmethod
    def from_threat_updates_json(cls, app_id, te_json):
        """Wrap the raw record as-is; app_id is unused by this serialization"""
        return cls(te_json)

    @classmethod
    def te_threat_updates_fields(cls) -> t.Tuple[str, ...]:
        # Could also return empty here, but this set is useful for basically
        # any serialization, and so it makes sense to fetch it for future processing
        # (even though it's verbose)
        return SimpleDescriptorRollup.te_threat_updates_fields()
class ThreatUpdatesDelta:
    """
    A class for tracking a raw stream of /threat_updates

    Any integration with ThreatExchange involves the creation of a local copy
    of the data. /threat_updates sends changes to that data in either the form
    of an insert/update or a delete.

    A delta is a stream of updates and deletes, which when applied to an
    existing database will give you the current set of live records.

    As a parallelization trick, if you need to fetch between t1 and t3,
    you can pick a point between them, t2, and fetch [t1, t2) and [t2, t3)
    simultaneously, and the merging of the two is guaranteed to be the same
    as [t1, t3). The split() and merge() commands aid with this operation.
    """

    def __init__(
        self,
        privacy_group: int,
        start: int = 0,
        end: t.Optional[int] = None,
        types: t.Iterable[str] = (),
    ) -> None:
        self.privacy_group = privacy_group
        # Fetched records, in the order they were received
        self.updates: t.List = []
        # Furthest update time seen so far; the delta is done once it reaches end
        self.current = start
        self.start = start
        # A falsy end means open-ended: one_fetch() pins it to the fetch time
        # once the cursor is exhausted
        self.end = end
        self.types = list(types)
        self._cursor: t.Optional[_CursoredResponse] = None

    @property
    def done(self) -> bool:
        """Has this delta fetched its entire assigned range?"""
        return bool(self.end and self.end <= self.current)

    def __bool__(self):
        return self.done or bool(self.updates)

    def __iter__(self):
        return iter(self.updates)

    def merge(self, delta: "ThreatUpdatesDelta") -> "ThreatUpdatesDelta":
        """
        Merge the earlier delta (this object) with a later delta.

        If you have
          t1 ---> t2 ---> t3
        t1.merge(t2).merge(t3) is valid, and will give you a range from t1-t3

        Returns this object (now covering both ranges) so calls can be chained.

        Raises:
            ValueError: if this delta has not finished fetching, or `delta`
              does not start exactly where this one ends.
        """
        if not self.done or self.end != delta.start:
            raise ValueError("unchecked merge!")
        self.updates.extend(delta.updates)
        self.current = delta.current
        self.end = delta.end
        return self

    def one_fetch(self, api: "ThreatExchangeAPI"):
        """
        Do a single fetch from ThreatExchange and store the results.

        One fetch only, please.

        Returns the cursor's `data` after the fetch (presumably the page just
        fetched - confirm against _CursoredResponse), or an empty list if this
        delta had already fetched its whole range.
        """
        if self.done:
            return []
        now = time.time()
        if not self._cursor:
            self._cursor = api.get_threat_updates(
                self.privacy_group,
                page_size=500,
                start_time=self.start,
                stop_time=self.end,
                types=self.types,
                fields=ThreatUpdateJSON.te_threat_updates_fields(),
                decode_fn=ThreatUpdateJSON,
            )
        for update in self._cursor.next():
            self.updates.append(ThreatUpdateJSON(update.raw_json))
            # Is supposed to be strictly increasing
            self.current = max(update.time, self.current)
        if self._cursor.done:
            if not self.end:
                self.end = int(now)
            # Exhausted cursor: mark the whole assigned range as fetched
            self.current = self.end
        return self._cursor.data

    def split(
        self, n: int
    ) -> t.Tuple["ThreatUpdatesDelta", t.List["ThreatUpdatesDelta"]]:
        """
        Split this delta into n deltas of roughly even size.

        This delta is shrunk to the first slice of its remaining range
        [current, end), and n - 1 new deltas (same privacy group and types)
        cover the following contiguous slices. The last one inherits this
        delta's original end, so an open-ended delta stays open-ended. Once all
        are done they can be recombined in order with merge().

        Returns:
            (self, new_deltas) with new_deltas in time order; new_deltas is
            empty if n <= 1 or the remaining range is too small to split.
        """
        # NOTE(review): a cursor that is already open was created with the old
        # stop_time, so splitting mid-fetch can overshoot the new end.
        if n <= 1:
            return self, []
        target = self.end or time.time()
        # Measure from `current`: everything before it has already been fetched
        step = int((target - self.current) // n)
        if step <= 0:
            return self, []
        original_end = self.end
        prev = self
        new_deltas: t.List["ThreatUpdatesDelta"] = []
        for i in range(1, n):
            boundary = self.current + step * i
            prev.end = boundary
            prev = ThreatUpdatesDelta(self.privacy_group, boundary, types=self.types)
            new_deltas.append(prev)
        prev.end = original_end
        return self, new_deltas

    def incremental_sync_from_threatexchange(
        self,
        api: "ThreatExchangeAPI",
        *,
        limit: t.Optional[int] = None,
        progress_fn=lambda x: None,
    ) -> None:
        """
        Fetch from threat_updates to get a more up-to-date copy of the data.

        Args:
            api: client used for the fetches.
            limit: if set, stop once progress_fn has been called this many
              times (0 or less fetches nothing). Pages are stored whole, so
              self.updates may hold more records than `limit`.
            progress_fn: called with each element returned by one_fetch().
        """
        # TODO actually implement fancy threading logic
        # alternative - instead make the API give hints about where to start
        # fetches, which will mean that fancy threading logic will be simpler
        if limit is not None and limit <= 0:
            return
        while not self.done:
            for update in self.one_fetch(api):
                progress_fn(update)
                if limit is not None:
                    limit -= 1
                    if limit <= 0:
                        return
class ThreatUpdateCheckpoint(t.NamedTuple):
    """
    State about the progress of a /threat_updates-backed state.

    If a client does not resume tailing the threat_updates endpoint fast enough,
    deletion records will be removed, making it impossible to determine which
    records should be retained without refetching the entire dataset from scratch.

    The API implementation will retain for 90 days:
    https://developers.facebook.com/docs/threat-exchange/reference/apis/threat-updates/
    """

    # See docstring about tailing fast enough.
    # Deliberately unannotated: in a typing.NamedTuple every annotated class
    # attribute becomes a tuple field, which would make this constant the first
    # positional field and shift last_fetch_time/fetch_checkpoint.
    DEFAULT_REFETCH_SEC = 3600 * 24 * 85  # 85 days

    # When was the last time we started or the furthest we've seen,
    # to check against the store getting too stale
    last_fetch_time: int = 0
    # Where should we resume from?
    fetch_checkpoint: int = 0

    def get_updated(self, delta: "ThreatUpdatesDelta") -> "ThreatUpdateCheckpoint":
        """
        Return a new checkpoint advanced to delta.current.

        Neither timestamp ever moves backwards.
        """
        # If starting from 0, this is the first fetch, in which case the first update
        # means that we fetched now.
        last_fetch_time = self.last_fetch_time
        if last_fetch_time == 0:
            last_fetch_time = int(time.time())
        return ThreatUpdateCheckpoint(
            last_fetch_time=max(last_fetch_time, delta.current),
            fetch_checkpoint=max(self.fetch_checkpoint, delta.current),
        )

    @property
    def stale(self) -> bool:
        """Is this checkpoint so old as to be invalid?"""
        return self.last_fetch_time + self.DEFAULT_REFETCH_SEC < time.time()
class ThreatUpdatesStore:
    """
    A wrapper for ThreatIndicator records for a single privacy group.

    Tracks a ThreatUpdateCheckpoint recording how far the local copy has been
    synced from /threat_updates. Subclasses decide how records and the
    checkpoint are persisted by implementing _load_checkpoint,
    _store_checkpoint and _apply_updates_impl. Depending on the serialization,
    stored records may keep the full API response, so state can get large.
    """

    def __init__(
        self,
        privacy_group: int,
    ) -> None:
        self.privacy_group = privacy_group
        # None until load_checkpoint(), reset() or apply_updates() sets it
        self.checkpoint: t.Optional[ThreatUpdateCheckpoint] = None

    @property
    def fetch_checkpoint(self) -> int:
        """Where the next fetch should resume from (0 if no checkpoint is set)"""
        return self.checkpoint.fetch_checkpoint if self.checkpoint else 0

    def reset(self) -> None:
        """Toss old state and begin anew"""
        self.checkpoint = ThreatUpdateCheckpoint()

    @property
    def next_delta(self) -> ThreatUpdatesDelta:
        """Return the next delta that should be applied"""
        return ThreatUpdatesDelta(
            self.privacy_group,
            self.checkpoint.fetch_checkpoint if self.checkpoint else 0,
            None,
        )

    def load_checkpoint(self) -> None:
        self.checkpoint = self._load_checkpoint()

    def _load_checkpoint(self) -> ThreatUpdateCheckpoint:
        """Load the state of the threat_updates checkpoints"""
        raise NotImplementedError

    def _store_checkpoint(self, checkpoint: ThreatUpdateCheckpoint) -> None:
        """Save the state of the threat_updates checkpoints after a successful apply"""
        raise NotImplementedError

    def _apply_updates_impl(
        self,
        delta: ThreatUpdatesDelta,
        post_apply_fn=lambda x: None,
    ) -> None:
        """Apply delta to state and store it"""
        raise NotImplementedError

    @property
    def stale(self) -> bool:
        """Is this state so old that it might be invalid?"""
        return self.checkpoint.stale if self.checkpoint else False

    def apply_updates(
        self,
        delta: ThreatUpdatesDelta,
        post_apply_fn=lambda x: None,
    ) -> None:
        """
        Merge updates to the data store and persist the advanced checkpoint.

        A delta starting at 0 is a full fetch and may be applied to any state,
        e.g. right after reset(). Any other delta must start at or before the
        stored fetch_checkpoint, and the stored state must not be stale.

        Raises:
            AssertionError: on a gap between the stored checkpoint and the
              delta, or when applying an incremental delta to stale state.
        """
        if delta.start != 0:
            # Explicit raises instead of `assert` so the checks survive `python -O`
            if not (self.checkpoint and delta.start <= self.checkpoint.fetch_checkpoint):
                raise AssertionError("gap in delta record")
            # Only incremental deltas build on the existing state; a full fetch
            # from 0 is how stale state gets rebuilt, so it must stay allowed
            if self.stale:
                raise AssertionError("attempted to apply stale delta")
        # It's possible the fetch completed but has no records
        if delta.updates:
            self._apply_updates_impl(delta, post_apply_fn)
        # Without a loaded checkpoint (only possible for a full fetch here),
        # start from an empty one so the fetch's progress is still recorded
        checkpoint = self.checkpoint
        if checkpoint is None:
            checkpoint = ThreatUpdateCheckpoint()
        self.checkpoint = checkpoint.get_updated(delta)
        self._store_checkpoint(self.checkpoint)
class ThreatUpdateFileStore(ThreatUpdatesStore):
    """
    A simple file storage (in lieu of DB) with in-memory merge.

    Everything lives under `state_dir`: a JSON checkpoint file named
    `<privacy_group>.threat_updates.checkpoint`, plus whatever the
    serialization's store()/load() classmethods write and read.
    """

    def __init__(
        self,
        state_dir: pathlib.Path,
        privacy_group: int,
        app_id: int,
        *,
        serialization=ThreatUpdateJSON,
    ) -> None:
        super().__init__(privacy_group)
        self.path = state_dir
        self.app_id = app_id
        self._serialization = serialization
        # Last state loaded or written, keyed by each record's `key`
        self._cached_state: t.Optional[t.Dict[str, t.Any]] = None

    @property
    def checkpoint_file(self) -> pathlib.Path:
        return self.path / f"{self.privacy_group}.threat_updates.checkpoint"

    def reset(self) -> None:
        super().reset()
        if self._cached_state:
            self._cached_state.clear()

    def _load_checkpoint(self) -> ThreatUpdateCheckpoint:
        """Load the state of the threat_updates checkpoints from state directory"""
        if not self.checkpoint_file.exists():
            return ThreatUpdateCheckpoint()
        with self.checkpoint_file.open("r") as f:
            checkpoint_json = json.load(f)
        # Keyword arguments: match fields by name, not by declared field order
        return ThreatUpdateCheckpoint(
            last_fetch_time=checkpoint_json["last_fetch_time"],
            fetch_checkpoint=checkpoint_json["fetch_checkpoint"],
        )

    def _store_checkpoint(self, checkpoint: ThreatUpdateCheckpoint) -> None:
        """Write the checkpoint as JSON, creating the state directory if needed"""
        # The directory may not exist yet: apply_updates() skips
        # _apply_updates_impl (which creates it) when a delta has no records
        self.path.mkdir(parents=True, exist_ok=True)
        with self.checkpoint_file.open("w") as f:
            json.dump(
                {
                    "last_fetch_time": checkpoint.last_fetch_time,
                    "fetch_checkpoint": checkpoint.fetch_checkpoint,
                },
                f,
                indent=2,
            )

    def load_state(self, allow_cached=True):
        """
        Return the stored records as a dict keyed by each record's `key`.

        Reads via the serialization's load() unless a cached copy exists and
        allow_cached is set. Returns {} if the state directory doesn't exist.
        """
        if not self.path.exists():
            return {}
        if not allow_cached or self._cached_state is None:
            self._cached_state = {
                item.key: item for item in self._serialization.load(self.path)
            }
        return self._cached_state

    def _apply_updates_impl(
        self, delta: ThreatUpdatesDelta, post_apply_fn=lambda x: None
    ) -> None:
        """
        Apply delta's upserts/deletes to the stored state and write it back.

        A delta starting at 0 replaces the state entirely; otherwise it is
        applied on top of the previously stored state.
        """
        os.makedirs(self.path, exist_ok=True)
        # Work on a copy so a failure below can't leave the cache half-updated
        state = dict(self.load_state()) if delta.start > 0 else {}
        for update in delta:
            item = self._serialization.from_threat_updates_json(
                self.app_id, update.raw_json
            )
            if update.should_delete:
                state.pop(item.key, None)
            else:
                state[item.key] = item
        # Persist first; only refresh the cache once the write has succeeded
        self._serialization.store(self.path, state.values())
        self._cached_state = state
        # NOTE(review): post_apply_fn is accepted but never invoked here; its
        # expected argument isn't visible from this file - confirm intent.