# hasher-matcher-actioner/hmalib/common/models/count.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Refer to hmalib.lambdas.ddb_stream_counter.lambda_handler's doc string to
understand how these models are used.
"""

import typing as t
from collections import defaultdict

from boto3.dynamodb.conditions import Key
from mypy_boto3_dynamodb.service_resource import Table


class BaseCount:
    """
    Defines a single count value, stored as one dynamodb item whose
    "CurrentCount" attribute holds the current value.

    Subclasses define the partition/sort key layout via get_pkey()/get_skey().
    """

    def get_pkey(self) -> str:
        """
        Get partition key for this count.
        """
        raise NotImplementedError

    def get_skey(self) -> str:
        """
        Get sort key for this count.
        """
        raise NotImplementedError

    def get_value(self, table: Table) -> int:
        """
        Get current value for the counter. Returns 0 if the counter item does
        not exist yet.
        """
        return t.cast(
            int,
            table.get_item(Key={"PK": self.get_pkey(), "SK": self.get_skey()})
            .get("Item", {})
            .get("CurrentCount", 0),
        )

    def inc(self, table: Table, by: int = 1):
        """
        Increment count. Default by 1, unless specified.
        """
        # if_not_exists makes the first increment on a missing item start at 0.
        table.update_item(
            Key={"PK": self.get_pkey(), "SK": self.get_skey()},
            UpdateExpression="SET CurrentCount = if_not_exists(CurrentCount, :zero) + :by",
            ExpressionAttributeValues={":by": by, ":zero": 0},
        )

    def dec(self, table: Table, by: int = 1):
        """
        Decrement count. Default by 1, unless specified.
        """
        table.update_item(
            Key={"PK": self.get_pkey(), "SK": self.get_skey()},
            UpdateExpression="SET CurrentCount = if_not_exists(CurrentCount, :zero) - :by",
            ExpressionAttributeValues={":by": by, ":zero": 0},
        )


class AggregateCount(BaseCount):
    """
    A "total" count. It is possible for some entities to have TBD hourly as
    well as aggregate counts. eg. Give me all matches today, but also keep
    track of the total number of matches we have ever done.
    """

    class PipelineNames:
        # How many pieces of content were submitted?
        submits = "hma.pipeline.submits"

        # How many pieces of content created a hash record?
        hashes = "hma.pipeline.hashes"

        # How many match object recorded?
        matches = "hma.pipeline.matches"

    def __init__(self, of: str):
        self.of = of

    @staticmethod
    def _get_pkey_for_aggregate(of: str) -> str:
        return f"aggregate#{of}"

    @staticmethod
    def _get_skey_for_aggregate() -> str:
        return "aggregate_count"

    def get_pkey(self) -> str:
        return self._get_pkey_for_aggregate(self.of)

    def get_skey(self) -> str:
        return self._get_skey_for_aggregate()


class ParameterizedCount(BaseCount):
    """
    Allows you to do aggregate counts, with a parameter. An example would be
    matches per privacy group. At this point, only supports a single parameter
    value.

    Think about it as
    ParameterizedCount(of="hma.pipeline.matches", by="privacy_group", value="4567896456789000976")
    or ParameterizedCount(of="hma.pipeline.hashes", by="content_type", value="photo")
    or ParameterizedCount(of="hma.pipeline.hashes", by="signal_type", value="pdq")
    """

    SKEY_PREFIX = "val#"
    SKEY_PREFIX_LENGTH = len(SKEY_PREFIX)

    def __init__(
        self, of: str, by: str, value: str, cached_value: t.Optional[int] = None
    ):
        """
        You may provide a cached value if this object is getting retrieved
        from the database. Note, this does not in any way change the actual
        value in the database. It only saves a database call if you are using
        get_value() immediately after.
        """
        self.of = of
        self.by = by
        self.value = value
        self._cached_value = cached_value

    def get_value(self, table: Table) -> int:
        """
        If cached_value is set to a non-None value, return it, else make a
        database call to get the answer.

        This is useful when you are getting a list of parameterized counts
        using `ParameterizedCount.get_all()`
        """
        # Compare against None explicitly: a legitimately cached count of 0
        # must be returned as-is, not trigger a redundant database call.
        if self._cached_value is not None:
            return self._cached_value

        return super().get_value(table)

    @classmethod
    def get_all(cls, of: str, by: str, table: Table) -> t.List["ParameterizedCount"]:
        """
        Get all counts for parameter `by` under counter `of`, with their
        current values pre-cached from the query results.
        """
        return [
            cls(
                of,
                by,
                value=t.cast(str, item.get("SK"))[
                    cls.SKEY_PREFIX_LENGTH :
                ],  # strip the "val#" portion
                cached_value=t.cast(int, item.get("CurrentCount", 0)),
            )
            for item in table.query(
                ScanIndexForward=True,
                KeyConditionExpression=Key("PK").eq(
                    cls._get_pkey_for_parameterized(of, by)
                ),
            )["Items"]
        ]

    @staticmethod
    def _get_pkey_for_parameterized(of: str, by: str) -> str:
        return f"parameterized#{of}#by#{by}"

    @classmethod
    def _get_skey_for_parameterized(cls, by: str, value: str) -> str:
        return f"{cls.SKEY_PREFIX}{value}"

    def get_pkey(self) -> str:
        return self._get_pkey_for_parameterized(self.of, self.by)

    def get_skey(self) -> str:
        return self._get_skey_for_parameterized(self.by, self.value)


class CountBuffer:
    """
    A buffer for increments to the variety of count types. Must call
    buffer.flush() at the end to flush everything to ddb.

    buffer = CountBuffer(ddb_table)
    buffer.inc_aggregate("hma.pipeline.matches")
    buffer.inc_parameterized("hma.pipeline.submit", by="content_type", value="photo")
    """

    def __init__(self, table: Table):
        self.table = table
        # Maps counter name -> net delta accumulated since last flush.
        self.aggregate_deltas: t.DefaultDict = defaultdict(int)
        # Maps (of, by, value) -> net delta accumulated since last flush.
        self.parameterized_deltas: t.DefaultDict = defaultdict(int)

    def inc_aggregate(self, of: str):
        """
        Increment an aggregate counter.
        """
        self.aggregate_deltas[of] += 1

    def dec_aggregate(self, of: str):
        """
        Decrement an aggregate counter.
        """
        self.aggregate_deltas[of] -= 1

    def inc_parameterized(self, of: str, by: str, value: str):
        """
        Increment a parameterized counter. eg.
        buffer.inc_parameterized("hma.pipeline.submit", by="content_type", value="photo")
        """
        self.parameterized_deltas[(of, by, value)] += 1

    def dec_parameterized(self, of: str, by: str, value: str):
        """
        Decrement a parameterized counter. eg.
        buffer.dec_parameterized("hma.pipeline.submit", by="content_type", value="photo")
        """
        self.parameterized_deltas[(of, by, value)] -= 1

    def flush(self):
        """
        Write all counters remaining in the buffer. Since we do not autoflush
        yet, this may take some time.

        TODO: Make this into batch calls to dynamodb so it is performant.
        Right now, we iterate through all increments and make individual
        calls to dynamodb. This is partially because BaseCount defines inc()
        method. Can this be extracted out such that instead of doing one ddb
        write per BaseCount, we can batch the DDB writes and do a single
        BatchWriteItem call?
        https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_BatchWriteItem.html
        """
        for name, increment_by in self.aggregate_deltas.items():
            if increment_by > 0:
                AggregateCount(str(name)).inc(self.table, increment_by)
            elif increment_by < 0:
                AggregateCount(str(name)).dec(self.table, abs(increment_by))

        # reset flushed buffer
        self.aggregate_deltas = defaultdict(int)

        for delta_tuple, increment_by in self.parameterized_deltas.items():
            of, by, value = delta_tuple
            if increment_by > 0:
                ParameterizedCount(of, by, value).inc(self.table, increment_by)
            elif increment_by < 0:
                # dec() subtracts its argument, so a negative delta must be
                # converted to its magnitude; passing the raw negative delta
                # would *increment* the stored count.
                ParameterizedCount(of, by, value).dec(self.table, abs(increment_by))

        # reset flushed buffer
        self.parameterized_deltas = defaultdict(int)