python-threatexchange/threatexchange/cli/dataset/simple_serialization.py (132 lines of code) (raw):

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Compact CSV-based serializations of ThreatExchange threat_updates data.

State is stored as one CSV file per indicator type, named
``simple.<indicator_type>.te`` inside the state directory.
"""

import collections
import csv
import pathlib
import re
import typing as t

from threatexchange.fb_threatexchange import threat_updates
from threatexchange.fb_threatexchange.descriptor import SimpleDescriptorRollup

_EXTENSION = ".te"

# Matches the *full* name of a state file and captures the indicator type
# (which may not contain a dot).
_STATE_FILE_RE = re.compile(r"simple\.([^.]+)" + re.escape(_EXTENSION))


def _state_file_path(state_dir: pathlib.Path, indicator_type: str) -> pathlib.Path:
    """Return the path of the state file holding rows for `indicator_type`."""
    return state_dir / f"simple.{indicator_type}{_EXTENSION}"


def _read_state_rows(state_dir: pathlib.Path) -> t.List[t.Tuple[str, t.List[str]]]:
    """
    Read every row of every state file in `state_dir`.

    Returns a list of (indicator_type, row) pairs, where indicator_type is taken
    from the file name. Entries whose name does not fully match the state file
    pattern, entries that are not regular files, and blank CSV lines are skipped.
    """
    ret: t.List[t.Tuple[str, t.List[str]]] = []
    for path in state_dir.glob(f"simple.*{_EXTENSION}"):
        match = _STATE_FILE_RE.fullmatch(path.name)
        if not match or not path.is_file():
            continue
        indicator_type = match.group(1)
        # csv.field_size_limit() is process-wide state (not threadsafe!). No
        # single field can be larger than the whole file, so raise the limit to
        # at least the file size to dodge field size problems - but never lower
        # it, and restore the previous value afterwards so other csv users in
        # this process are unaffected.
        previous_limit = csv.field_size_limit()
        csv.field_size_limit(max(previous_limit, path.stat().st_size))
        try:
            with path.open("r", encoding="utf-8", newline="") as f:
                # csv.reader yields [] for blank lines; skip them rather than
                # crash on row[0] later.
                ret.extend((indicator_type, row) for row in csv.reader(f) if row)
        finally:
            csv.field_size_limit(previous_limit)
    return ret


# TODO - merge SimpleDescriptorRollup here
class CliIndicatorSerialization(threat_updates.ThreatUpdateSerialization):
    """A short compact serialization optimized for the CLI"""

    def __init__(
        self,
        indicator_type: str,
        indicator: str,
        rollup: SimpleDescriptorRollup,
    ):
        self.indicator_type = indicator_type
        self.indicator = indicator
        self.rollup = rollup

    @property
    def key(self) -> str:
        """Key of the form `<indicator_type>.<indicator>`."""
        return f"{self.indicator_type}.{self.indicator}"

    def as_csv_row(self) -> t.Tuple:
        """
        As a simple record type for the threatexchange CLI cache.

        Layout: (indicator, *rollup.as_row()). The indicator type is not part
        of the row; `store` encodes it in the file name instead.
        """
        return (self.indicator,) + self.rollup.as_row()

    @classmethod
    def from_threat_updates_json(cls, app_id, te_json):
        """Build from a threat_updates JSON record (reads "type" and "indicator")."""
        return cls(
            te_json["type"],
            te_json["indicator"],
            SimpleDescriptorRollup.from_threat_updates_json(app_id, te_json),
        )

    @classmethod
    def te_threat_updates_fields(cls):
        """Delegates to SimpleDescriptorRollup.te_threat_updates_fields()."""
        return SimpleDescriptorRollup.te_threat_updates_fields()

    # ToDo this violates Liskov but is already used in Prod and will require a larger refactor
    @classmethod
    def store(  # type: ignore
        cls, state_dir: pathlib.Path, contents: t.Iterable["CliIndicatorSerialization"]
    ) -> t.List[pathlib.Path]:
        """
        Write `contents` into `state_dir`, one CSV file per indicator type.

        Each written file is truncated and rewritten. Returns the written paths.
        """
        # Stores in multiple files split by indicator type
        items_by_type: t.DefaultDict[
            str, t.List["CliIndicatorSerialization"]
        ] = collections.defaultdict(list)
        for item in contents:
            items_by_type[item.indicator_type].append(item)

        ret = []
        for indicator_type, items in items_by_type.items():
            path = _state_file_path(state_dir, indicator_type)
            ret.append(path)
            with path.open("w", encoding="utf-8", newline="") as f:
                csv.writer(f).writerows(item.as_csv_row() for item in items)
        return ret

    @classmethod
    def load(cls, state_dir: pathlib.Path) -> t.Iterable["CliIndicatorSerialization"]:
        """Load this serialization from the state directory"""
        return [
            cls(indicator_type, row[0], SimpleDescriptorRollup.from_row(row[1:]))
            for indicator_type, row in _read_state_rows(state_dir)
        ]


class HMASerialization(CliIndicatorSerialization):
    """
    A Serialization for HMA

    Similar to CliIndicatorSerialization but with Indicator ID.
    We also include the First Descriptor ID. The logic to determine which ID this is
    can be found in the SimpleDescriptorRollup
    """

    def __init__(
        self,
        indicator: str,
        indicator_type: str,
        indicator_id: str,
        rollup: SimpleDescriptorRollup,
    ):
        # NOTE: positional order differs from the parent (indicator comes first).
        super().__init__(indicator_type, indicator, rollup)
        self.indicator_id = indicator_id

    def as_csv_row(self) -> t.Tuple:
        """
        Indicator details followed by the descriptor rollup.

        Layout: (indicator, indicator_id, *rollup.as_row()). The indicator type
        is not part of the row; `store` encodes it in the file name instead.
        """
        return (self.indicator, self.indicator_id) + self.rollup.as_row()

    @classmethod
    def from_threat_updates_json(cls, app_id, te_json):
        """Build from a threat_updates JSON record (reads "indicator", "type", "id")."""
        return cls(
            te_json["indicator"],
            te_json["type"],
            te_json["id"],
            SimpleDescriptorRollup.from_threat_updates_json(app_id, te_json),
        )

    @classmethod
    def from_csv_row(
        cls, row: t.List[t.Any], indicator_type: str
    ) -> "HMASerialization":
        """Inverse of `as_csv_row`; `indicator_type` comes from the file name."""
        return cls(
            str(row[0]),
            indicator_type,
            str(row[1]),
            SimpleDescriptorRollup.from_row(row[2:]),
        )

    @classmethod
    def load(cls, state_dir: pathlib.Path) -> t.Iterable["HMASerialization"]:
        """Load this serialization from the state directory"""
        # Same UTF-8 decoding as the inherited `store` uses for writing.
        return [
            cls.from_csv_row(row, indicator_type)
            for indicator_type, row in _read_state_rows(state_dir)
        ]


if __name__ == "__main__":
    # Test Serialize Deserialize
    indicator = "indicator"
    indicator_id = "12345"
    first_descriptor_id = 6789
    added_on = "today"
    labels = {"tag1", "tag2"}
    ser = HMASerialization(
        indicator,
        "HASH_PDQ",
        indicator_id,
        SimpleDescriptorRollup(first_descriptor_id, added_on, labels),
    )
    serdeser = HMASerialization.from_csv_row(list(ser.as_csv_row()), "HASH_PDQ")
    if ser.as_csv_row() == serdeser.as_csv_row():
        print("Serialization worked correctly")
    else:
        print("Serialization failed")