pdq/cpp/index/mih.h (200 lines of code) (raw):

// ================================================================ // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved // ================================================================ #ifndef MIH_H #define MIH_H #include <pdq/cpp/common/pdqhashtypes.h> #include <stdexcept> #include <map> #include <set> #include <vector> namespace facebook { namespace pdq { namespace index { // ================================================================ // MUTUALLY-INDEXED HASHING FOR 256-BIT HASHES // // ---------------------------------------------------------------- // References: // Mutually-indexed hashing by Norouzi et al. 2014: // * https://www.cs.toronto.edu/~norouzi/research/papers/multi_index_hashing.pdf // * https://norouzi.github.io/research/posters/mih_poster.pdf // * This is a from-scratch source-code implementation based on the paper. // // ---------------------------------------------------------------- // Size constraints: // // 'Slots' are 16-bit words. Maximum distance we support for non-brute-force // search is MIH_MAX_SLOTWISE_D. This corresponds to max hashwise distance of // MIH_MAX_D since that's the largest d such that floor(d/16) <= // MIH_MAX_SLOTWISE_D. // // The reason, in turn, for this is the expense of finding all hamming-distance // nearest neighbors. For more information please see hashing/pdq/README-MIH.md // in this repo. // ================================================================ const int MIH_MAX_D = 63; const int MIH_MAX_SLOTWISE_D = 3; // Implemented entirely within the header file since this is a template class. template <typename Metadata> class MIH256 { private: // ---------------------------------------------------------------- // MIH data: // 1. Array of all hashes+metadata in the index. std::vector<std::pair<facebook::pdq::hashing::Hash256, Metadata>> _allHashes; // 2. For each slot index i=0..15: // For each of up to 65,536 possible slot values v at that index: // Hashset of indices within the _allHashes array of all hashes // having slot value v at slot index i. std::vector<std::map<facebook::pdq::hashing::Hash16, std::vector<int>>> _slotValuesToIndices; public: // ---------------------------------------------------------------- MIH256() : _slotValuesToIndices(facebook::pdq::hashing::HASH256_NUM_WORDS) {} // ---------------------------------------------------------------- // Let STL do the work of freeing its containers. ~MIH256() {} private: // Disallow copying MIH256(const MIH256& /*that*/) {} void operator=(const facebook::pdq::hashing::Hash256& /*that*/) {} public: // ---------------------------------------------------------------- int size() { return _allHashes.size(); } std::vector<std::pair<facebook::pdq::hashing::Hash256,Metadata>> get() { return _allHashes; } // --------------------------------------------------------------- // BULK INSERTION void insertAll( const std::vector<std::pair<facebook::pdq::hashing::Hash256, Metadata>>& pairs) { for (auto it : pairs) { insert(it.first, it.second); } } // --------------------------------------------------------------- // HASH INSERTION void insert(const facebook::pdq::hashing::Hash256& hash, Metadata metadata) { int sizeBeforeInsert = _allHashes.size(); for (int i = 0; i < facebook::pdq::hashing::HASH256_NUM_WORDS; i++) { _slotValuesToIndices[i][hash.w[i]].push_back(sizeBeforeInsert); } _allHashes.push_back(std::make_pair(hash, metadata)); } // ---------------------------------------------------------------- void queryAllNeighborAux( facebook::pdq::hashing::Hash16 neighbor, const std::map<facebook::pdq::hashing::Hash16, std::vector<int>>& indicesForSlotValue, std::set<int>& indices) const { const auto found = indicesForSlotValue.find(neighbor); if (found != indicesForSlotValue.end()) { indices.insert(found->second.begin(), found->second.end()); } } void queryAll0( facebook::pdq::hashing::Hash16 neighbor0, const std::map<facebook::pdq::hashing::Hash16, std::vector<int>>& indicesForSlotValue, std::set<int>& indices) const { queryAllNeighborAux(neighbor0, indicesForSlotValue, indices); } void queryAll1( facebook::pdq::hashing::Hash16 neighbor0, const std::map<facebook::pdq::hashing::Hash16, std::vector<int>>& indicesForSlotValue, std::set<int>& indices) const { queryAllNeighborAux(neighbor0, indicesForSlotValue, indices); for (int i1 = 0; i1 < 16; i1++) { facebook::pdq::hashing::Hash16 neighbor1 = neighbor0 ^ (1 << i1); queryAllNeighborAux(neighbor1, indicesForSlotValue, indices); } } void queryAll2( facebook::pdq::hashing::Hash16 neighbor0, const std::map<facebook::pdq::hashing::Hash16, std::vector<int>>& indicesForSlotValue, std::set<int>& indices) const { queryAllNeighborAux(neighbor0, indicesForSlotValue, indices); for (int i1 = 0; i1 < 16; i1++) { facebook::pdq::hashing::Hash16 neighbor1 = neighbor0 ^ (1 << i1); queryAllNeighborAux(neighbor1, indicesForSlotValue, indices); for (int i2 = i1 + 1; i2 < 16; i2++) { facebook::pdq::hashing::Hash16 neighbor2 = neighbor1 ^ (1 << i2); queryAllNeighborAux(neighbor2, indicesForSlotValue, indices); } } } void queryAll3( facebook::pdq::hashing::Hash16 neighbor0, const std::map<facebook::pdq::hashing::Hash16, std::vector<int>>& indicesForSlotValue, std::set<int>& indices) const { queryAllNeighborAux(neighbor0, indicesForSlotValue, indices); for (int i1 = 0; i1 < 16; i1++) { facebook::pdq::hashing::Hash16 neighbor1 = neighbor0 ^ (1 << i1); queryAllNeighborAux(neighbor1, indicesForSlotValue, indices); for (int i2 = i1 + 1; i2 < 16; i2++) { facebook::pdq::hashing::Hash16 neighbor2 = neighbor1 ^ (1 << i2); queryAllNeighborAux(neighbor2, indicesForSlotValue, indices); for (int i3 = i2 + 1; i3 < 16; i3++) { facebook::pdq::hashing::Hash16 neighbor3 = neighbor2 ^ (1 << i3); queryAllNeighborAux(neighbor3, indicesForSlotValue, indices); } } } } // ---------------------------------------------------------------- // HASH QUERY // // MIH query algorithm: // Given needle hash n // For each slot index i: // Get slot value v of n at index i // Find the array indices of hashes in the MIH whose i'th slot value // is within slotwise distance of v. Do this by finding all the // nearest-neighbor values w of v and finding the indices of all // hashes having value w at slot index i. void queryAll( const facebook::pdq::hashing::Hash256& needle, int d, std::vector<std::pair<facebook::pdq::hashing::Hash256, Metadata>>& matches) const { std::set<int> indices; // Floor of d/16; see comments at top of file: const int slotwise_d = d / 16; // Find candidates for (int i = 0; i < facebook::pdq::hashing::HASH256_NUM_WORDS; i++) { facebook::pdq::hashing::Hash16 slotValue = needle.w[i]; const auto& indicesForSlotValue = _slotValuesToIndices[i]; switch (slotwise_d) { case 0: queryAll0(slotValue, indicesForSlotValue, indices); break; case 1: queryAll1(slotValue, indicesForSlotValue, indices); break; case 2: queryAll2(slotValue, indicesForSlotValue, indices); break; case 3: queryAll3(slotValue, indicesForSlotValue, indices); break; default: throw std::runtime_error( "PDQ MIH queryAll: distance threshold out of bounds. " "Please use linear search."); break; } } // Prune candidates for (auto idx : indices) { const facebook::pdq::hashing::Hash256& hash = _allHashes[idx].first; const Metadata& metadata = _allHashes[idx].second; if (hash.hammingDistance(needle) <= d) { matches.push_back(std::make_pair(hash, metadata)); } } } // ---------------------------------------------------------------- // LINEAR SEARCH void bruteForceQueryAll( const facebook::pdq::hashing::Hash256& needle, int d, std::vector<std::pair<facebook::pdq::hashing::Hash256, Metadata>>& matches) const { for (auto it : _allHashes) { auto& hash = it.first; if (hash.hammingDistance(needle) <= d) { Metadata metadata = it.second; matches.push_back(std::make_pair(hash, metadata)); } } } bool bruteForceQueryAny( const facebook::pdq::hashing::Hash256& needle, int d, facebook::pdq::hashing::Hash256& match) const { for (auto it : _allHashes) { auto& hash = it.first; if (hash.hammingDistanceLE(needle, d)) { match = hash; return true; } } return false; } // ---------------------------------------------------------------- // OPS/REGRESSION ROUTINE void dump() { printf("ALL HASHES:\n"); for (auto it : _allHashes) { facebook::pdq::hashing::Hash256& hash = it.first; printf("%s\n", hash.format().c_str()); fflush(stdout); } printf("MULTI-INDICES:\n"); for (int i = 0; i < facebook::pdq::hashing::HASH256_NUM_WORDS; i++) { printf("\n"); printf("--------------- slot_index=%d\n", i); for (auto it1 : _slotValuesToIndices[i]) { facebook::pdq::hashing::Hash16 slotValue = it1.first; std::vector<int> indices = it1.second; printf("slot_value=%04hx\n", slotValue); for (auto it2 : indices) { printf(" %d\n", it2); fflush(stdout); } fflush(stdout); } fflush(stdout); } } }; // end class MIH256 } // namespace index } // namespace pdq } // namespace facebook #endif // MIH_H