python-threatexchange/threatexchange/hashing/pdq_faiss_matcher.py (140 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import typing as t
import faiss # type: ignore
import binascii
import numpy # type: ignore
from abc import ABC, abstractmethod
from .pdq_utils import BITS_IN_PDQ
PDQ_HASH_TYPE = t.Union[str, bytes]
def uint64_to_int64(as_uint64: int):
"""
Returns the int64 number represented by the same byte representation as the the provided integer if it was understood to
be a uint64 value.
"""
return numpy.uint64(as_uint64).astype(numpy.int64).item()
def int64_to_uint64(as_int64: int):
"""
Returns the uint64 number represented by the same byte representation as the the provided integer if it was understood to
be a int64 value.
"""
return numpy.int64(as_int64).astype(numpy.uint64).item()
class PDQHashIndex(ABC):
@abstractmethod
def __init__(self, faiss_index: faiss.IndexBinary) -> None:
self.faiss_index = faiss_index
super().__init__()
@abstractmethod
def hash_at(self, idx: int):
"""
Returns the hash located at the given index. The index order is determined by the initial order of hashes used to
create this index.
"""
pass
@abstractmethod
def add(self, hashes: t.Iterable[PDQ_HASH_TYPE], custom_ids: t.Iterable[int]):
"""
Adds hashes and their custom ids to the PDQ index.
"""
pass
def search(
self,
queries: t.Sequence[PDQ_HASH_TYPE],
threshhold: int,
return_as_ids: bool = False,
):
"""
Searches this index for PDQ hashes within the index that are no more than the threshold away from the query hashes by
hamming distance.
Parameters
----------
queries: sequence of PDQ Hashes
The PDQ hashes to query against the index
threshold: int
Threshold value to use for this search. The hamming distance between the result hashes and the related query will
be no more than the threshold value. i.e., hamming_dist(q_i,r_i_j) <= threshold.
return_as_ids: boolean
whether the return values should be the index ids for the matching items. Defaults to false.
Returns
-------
sequence of matches per query
For each query provided in queries, the returned sequence will contain a sequence of matches within the index
that were within threshold hamming distance of that query. These matches will either be a hexstring of the hash
by default, or the index ids of the matches if `return_as_ids` is True. The inner sequences may be empty in the
case of no hashes within the index. The same PDQ hash may also appear in more than one inner sequence if it
matches multiple query hashes.
For example the hash "000000000000000000000000000000000000000000000000000000000000FFFF" would match both
"00000000000000000000000000000000000000000000000000000000FFFFFFFF" and
"0000000000000000000000000000000000000000000000000000000000000000" for a threshold of 16. Thus it would appear in
the entry for both the hashes if they were both in the queries list.
"""
query_vectors = [
numpy.frombuffer(binascii.unhexlify(q), dtype=numpy.uint8) for q in queries
]
qs = numpy.array(query_vectors)
limits, _, I = self.faiss_index.range_search(qs, threshhold + 1)
if return_as_ids:
# for custom ids, we understood them initially as uint64 numbers and then coerced them internally to be signed
# int64s, so we need to reverse this before returning them back to the caller. For non custom ids, this will
# effectively return the same result
output_fn: t.Callable[[int], t.Any] = int64_to_uint64
else:
output_fn = self.hash_at
return [
[output_fn(idx.item()) for idx in I[limits[i] : limits[i + 1]]]
for i in range(len(query_vectors))
]
def search_with_distance_in_result(
self,
queries: t.Sequence[str],
threshhold: int,
):
"""
Search method that return a mapping from query_str => (id, hash, distance)
This implementation is the same as `search` above however instead of returning just the sequence of matches
per query it returns a mapping from query strings to a list of matched hashes (or ids) and distances
e.g.
result = {
"000000000000000000000000000000000000000000000000000000000000FFFF": [
(12345678901, "00000000000000000000000000000000000000000000000000000000FFFFFFFF", 16.0)
]
}
"""
query_vectors = [
numpy.frombuffer(binascii.unhexlify(q), dtype=numpy.uint8) for q in queries
]
qs = numpy.array(query_vectors)
limits, similarities, I = self.faiss_index.range_search(qs, threshhold + 1)
# for custom ids, we understood them initially as uint64 numbers and then coerced them internally to be signed
# int64s, so we need to reverse this before returning them back to the caller. For non custom ids, this will
# effectively return the same result
output_fn: t.Callable[[int], t.Any] = int64_to_uint64
result = {}
for i, query in enumerate(queries):
match_tuples = []
matches = [idx.item() for idx in I[limits[i] : limits[i + 1]]]
distances = [idx for idx in similarities[limits[i] : limits[i + 1]]]
for match, distance in zip(matches, distances):
# (Id, Hash, Distance)
match_tuples.append((output_fn(match), self.hash_at(match), distance))
result[query] = match_tuples
return result
def __getstate__(self):
data = faiss.serialize_index_binary(self.faiss_index)
return data
def __setstate__(self, data):
self.faiss_index = faiss.deserialize_index_binary(data)
class PDQFlatHashIndex(PDQHashIndex):
"""
Wrapper around an faiss binary index for use with searching for similar PDQ hashes
The "flat" variant uses an exhaustive search approach that may use less memory than other approaches and may be more
performant when using larger thresholds for PDQ similarity.
"""
def __init__(self):
faiss_index = faiss.IndexBinaryIDMap2(
faiss.index_binary_factory(BITS_IN_PDQ, "BFlat")
)
super().__init__(faiss_index)
def add(self, hashes: t.Iterable[PDQ_HASH_TYPE], custom_ids: t.Iterable[int]):
"""
Parameters
----------
hashes: sequence of PDQ Hashes
The PDQ hashes to create the index with
custom_ids: sequence of custom ids for the PDQ Hashes
Sequence of custom id values to use for the PDQ hashes for any
method relating to indexes (e.g., hash_at). If provided, the nth item in
custom_ids will be used as the id for the nth hash in hashes. If not provided
then the ids for the hashes will be assumed to be their respective index
in hashes (i.e., the nth hash would have id n, starting from 0).
"""
hash_bytes = [binascii.unhexlify(hash) for hash in hashes]
vectors = list(
map(lambda h: numpy.frombuffer(h, dtype=numpy.uint8), hash_bytes)
)
i64_ids = list(map(uint64_to_int64, custom_ids))
self.faiss_index.add_with_ids(numpy.array(vectors), numpy.array(i64_ids))
def hash_at(self, idx: int):
i64_id = uint64_to_int64(idx)
vector = self.faiss_index.reconstruct(i64_id)
return binascii.hexlify(vector.tobytes()).decode()
class PDQMultiHashIndex(PDQHashIndex):
"""
Wrapper around an faiss binary index for use with searching for similar PDQ hashes
The "multi" variant uses an the Multi-Index Hashing searching technique employed by faiss's
IndexBinaryMultiHash binary index.
Properties:
nhash: int (optional)
Optional number of hashmaps for the underlaying faiss index to use for
the Multi-Index Hashing lookups.
"""
def __init__(self, nhash: int = 16):
bits_per_hashmap = BITS_IN_PDQ // nhash
faiss_index = faiss.IndexBinaryIDMap2(
faiss.IndexBinaryMultiHash(BITS_IN_PDQ, nhash, bits_per_hashmap)
)
super().__init__(faiss_index)
self.__construct_index_rev_map()
def add(
self,
hashes: t.Iterable[PDQ_HASH_TYPE],
custom_ids: t.Iterable[int],
):
"""
Parameters
----------
hashes: sequence of PDQ Hashes
The PDQ hashes to create the index with
custom_ids: sequence of custom ids for the PDQ Hashes
Sequence of custom id values to use for the PDQ hashes for any
method relating to indexes (e.g., hash_at). If provided, the nth item in
custom_ids will be used as the id for the nth hash in hashes. If not provided
then the ids for the hashes will be assumed to be their respective index
in hashes (i.e., the nth hash would have id n, starting from 0).
Returns
-------
a PDQMultiHashIndex of these hashes
"""
hash_bytes = [binascii.unhexlify(hash) for hash in hashes]
vectors = list(
map(lambda h: numpy.frombuffer(h, dtype=numpy.uint8), hash_bytes)
)
i64_ids = list(map(uint64_to_int64, custom_ids))
self.faiss_index.add_with_ids(numpy.array(vectors), numpy.array(i64_ids))
self.__construct_index_rev_map()
@property
def mih_index(self):
"""
Convenience accessor for the underlaying faiss.IndexBinaryMultiHash index regardless of if it is wrapped in an ID
map or not.
"""
if hasattr(self.faiss_index, "index"):
return faiss.downcast_IndexBinary(self.faiss_index.index)
return self.faiss_index
def search(
self,
queries: t.Sequence[PDQ_HASH_TYPE],
threshhold: int,
return_as_ids: bool = False,
):
self.mih_index.nflip = threshhold // self.mih_index.nhash
return super().search(queries, threshhold, return_as_ids)
def search_with_distance_in_result(
self,
queries: t.Sequence[str],
threshhold: int,
):
self.mih_index.nflip = threshhold // self.mih_index.nhash
return super().search_with_distance_in_result(queries, threshhold)
def hash_at(self, idx: int):
i64_id = uint64_to_int64(idx)
if self.index_rev_map:
index_id = self.index_rev_map[i64_id]
else:
index_id = i64_id
vector = self.mih_index.storage.reconstruct(index_id)
return binascii.hexlify(vector.tobytes()).decode()
def __construct_index_rev_map(self):
"""
Workaround method for creating an in-memory lookup mapping custom ids to internal index id representations. The
rev_map property provided in faiss.IndexBinaryIDMap2 has no accessible `at` or other index lookup methods in swig
and the implementation of `reconstruct` in faiss.IndexBinaryIDMap2 requires the underlaying index to directly
support `reconstruct`, which faiss.IndexBinaryMultiHash does not. Thus this workaround is needed until either the
values in the faiss.IndexBinaryIDMap2 rev_map can be accessed directly or faiss.IndexBinaryMultiHash is directly
supports `reconstruct` calls.
"""
if hasattr(self.faiss_index, "id_map"):
id_map = self.faiss_index.id_map
self.index_rev_map = {id_map.at(i): i for i in range(id_map.size())}
else:
self.index_rev_map = None
def __setstate__(self, data):
super().__setstate__(data)
self.__construct_index_rev_map()