sql_utils/base/general_trie.h (443 lines of code) (raw):

//
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// A trie is a 256-ary tree that represents a set of strings. GeneralTrie is a
// class that uses a trie to map each string in some set to a piece of data
// of type T. GeneralTrie comes in two flavors:
//
// * GeneralTrie<T, T NULL_VALUE>, where NULL_VALUE is a value of type
// T used as a placeholder at intermediate nodes in the tree, and T must be an
// integral type.
//
// * ClassGeneralTrie<T>, where a default constructed T is used as the null
// value, and T must have copy and default constructor, assignment and equality
// operators (used only to check against the null value). Hint: Wrapping your
// class with linked_ptr is enough to satisfy the requirements, but remember
// that a linked_ptr copy is a read-write operation (see linked_ptr.h), so you
// should prefer manipulating the returned references to avoid concurrency
// problems, in the presence of multiple reading threads.
//
// Both classes offer exactly the same interface as the GeneralTrieImpl class
// seen in the beginning of this file.
//
// Please note that GeneralTrie is not thread safe.
#ifndef THIRD_PARTY_PY_BIGQUERY_ML_UTILS_SQL_UTILS_BASE_GENERAL_TRIE_H_ #define THIRD_PARTY_PY_BIGQUERY_ML_UTILS_SQL_UTILS_BASE_GENERAL_TRIE_H_ #include <assert.h> #include <stdio.h> #include <string.h> #include <algorithm> #include <string> #include <utility> #include <vector> #include "absl/base/attributes.h" #include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "sql_utils/base/logging.h" namespace bigquery_ml_utils_base { // The GeneralTrieImpl receives two template parameters to be able to model // both the GeneralTrie<class T, T NULL_VALUE> for integral types and // ClassGeneralTrie<T>. This is an implementation trick to add the // ClassGeneralTrie functionality without breaking existing uses of // GeneralTrie<class T, T NULL_VALUE>. You probably don't want to instantiate // GeneralTrieImpl directly, but rather one of these other two classes, defined // in the end of the file. template <class T, class NullValuePolicy> // NullValuePolicy implements Null() class GeneralTrieImpl { public: // Abstract class for objects that can be passed to PreorderTraverse() and // PostorderTraverse(). class Traverser { public: virtual ~Traverser(); virtual void Process(const std::string& s, const T& data) = 0; }; typedef T value_type; typedef std::pair<std::string, T> TrieData; GeneralTrieImpl(); ~GeneralTrieImpl(); // Inserts the given string into the trie (or finds it if it's already // there) and associates a copy of the given data with it. void Insert(absl::string_view key, const T& data); // Returns the data associated with key in the trie, or // NullValuePolicy::Null() if key is not in the trie. const T& GetData(absl::string_view key) const; // Returns a reference to the data associated with key in the trie, or // NullValuePolicy::Null() if key is not in the trie. T& GetData(absl::string_view key); // If `key` is present in the trie, updates the data associated with key and // returns true; else, it makes no change in the trie and returns false. 
bool SetData(absl::string_view key, const T& data); // Finds the greatest number n such that: // 1. the first n characters of key form a string in the trie; // 2. (n == key.length()) or is_terminator[key[n]] is true. // If there is no n that satisfies these conditions, the method returns // NullValuePolicy::Null(). Otherwise, it returns the data associated with // the first n characters of key and sets *chars_matched to n. is_terminator // should point to an array of 256 bools. You can also pass a null pointer // for is_terminator, in which case the function treats every character as a // terminator, so condition 2 is always satisfied. const T& GetDataForMaximalPrefix(absl::string_view key, int* chars_matched, const bool* is_terminator) const; // Gets all strings (and associated data) matching the given // string. The given string must match in its entirety. Note: empty // input string matches everything in the trie void GetAllMatchingStrings(absl::string_view key, std::vector<TrieData>* outdata) const; // Calls traverser->Process() for each string in the trie. void PreorderTraverse(Traverser* traverser) const { PreorderTraverseDepth(traverser, -1); } // Like PreorderTraverse(), but you can specify the depth of // traversal. void PreorderTraverseDepth(Traverser* traverser, int depth) const { std::string s; Traverse(traverser, &s, depth, true); } // Calls traverser->Process() for each matching string in the trie. void PreorderTraverseAllMatchingStrings(absl::string_view key, Traverser* traverser) const { PreorderTraverseAllMatchingStringsDepth(key, traverser, -1); } // Like PreorderTraverseAllMatchingStrings(), but you can specify // the depth of traversal. Note the semantics of the depth of // traversal here. An exact match is depth 0. If there are no // exact matches, a "root" is still assumed at this exact match, and // the depth is counted from there. 
void PreorderTraverseAllMatchingStringsDepth(absl::string_view key, Traverser* traverser, int depth) const { TraverseAllMatchingStrings(key, traverser, depth, true); } // Postorder versions of the traversal functions. void PostorderTraverse(Traverser* traverser) const { PostorderTraverseDepth(traverser, -1); } void PostorderTraverseDepth(Traverser* traverser, int depth) const { std::string s; Traverse(traverser, &s, depth, false); } void PostorderTraverseAllMatchingStrings(const char* s, int len, Traverser* traverser) const { TraverseAllMatchingStrings(absl::string_view(s, len), traverser, -1, false); } // Calls traverser->Process() for each string in the trie that is a substring // of s. void TraverseAlongString(const absl::string_view key, Traverser* traverser) const; void Print(int indent) const; // An iterator with an interface identical to the iterators in CompactTrie. class TraverseIterator { public: bool Done() const { return stack_.empty(); } // Note: the referenced key is mutated by each call to Next(). const std::string& Key() const { SQL_CHECK(!Done()); return key_; } const T& Value() const { SQL_CHECK(!Done()); return stack_.back().first->data_; } void Next(); private: friend class GeneralTrieImpl; typedef GeneralTrieImpl<T, NullValuePolicy> NodeT; explicit TraverseIterator(const GeneralTrieImpl<T, NullValuePolicy>* trie); // The stack stores the current path's nodes, deepest node on top (back). // For each node, the index of the next branch to take is stored. std::vector<std::pair<const NodeT*, int> > stack_; std::string key_; }; // Returns an iterator over all keys in lexicographical order. 
TraverseIterator Traverse() const { return TraverseIterator(this); } private: typedef GeneralTrieImpl<T, NullValuePolicy> NodeT; std::string comppath_; // string compression: must match to continue T data_; const T null_value_instance_; // allows return by reference int min_next_; // next_ array goes from min to max-1 int max_next_; NodeT** next_; // array of "next level of the trie" inline NodeT* Next(int index) const; NodeT* SetNext(int index, NodeT* value); // Calls traverser->Process() for each string in either preorder or postorder. void Traverse(Traverser* traverser, std::string* s, int depth, bool preorder) const; // Calls traverser->Process() for each matching string in either preorder or // postorder. void TraverseAllMatchingStrings(absl::string_view key, Traverser* traverser, int depth, bool preorder) const; // get a ptr to the data corresponding to a given string. // returns NullValuePolicy::Null() if the string is not present in the trie. // Used as a helper method to get and set data const T* GetDataPtr(absl::string_view key) const; // Disallow copy constructor and operator= GeneralTrieImpl(const GeneralTrieImpl<T, NullValuePolicy>&); void operator=(const GeneralTrieImpl<T, NullValuePolicy>&); }; template <class T, class NullValuePolicy> GeneralTrieImpl<T, NullValuePolicy>::Traverser::~Traverser() {} // NullValuePolicy for integral types template <typename T, T NULL_VALUE> struct IntegralNullValuePolicy { static T Null() { return NULL_VALUE; } }; // NullValuePolicy for default constructed classes. 
template <typename T> struct DefaultConstructedNullValuePolicy { static T Null() { return T(); } }; // // Actual classes that matter to the API users // // GeneralTrie version for integral types (and pointers, which aren't integral // types but are similar in some ways) template <class T, T NULL_VALUE> class GeneralTrie : public GeneralTrieImpl<T, IntegralNullValuePolicy<T, NULL_VALUE> > { public: // A static const member variable can only be initialized in the class // definition if it is an integral type, and we also want to support the case // where T is a pointer. Please see standard, clause 9.4.2, paragraph 4. static const T kNullValue; typedef IntegralNullValuePolicy<T, NULL_VALUE> null_value_policy; }; // See comments on member declaration template <class T, T NULL_VALUE> const T GeneralTrie<T, NULL_VALUE>::kNullValue = NULL_VALUE; // GeneralTrie version that supports classes, such as string, // and linked_ptr<YourClass> (but see hint on top of header). template <class T> class ClassGeneralTrie : public GeneralTrieImpl<T, DefaultConstructedNullValuePolicy<T> > { public: typedef DefaultConstructedNullValuePolicy<T> null_value_policy; }; // ---------------------------------------------------------------------- // GeneralTrieImpl<T, NullValuePolicy>::GeneralTrieImpl() // GeneralTrieImpl<T, NullValuePolicy>::~GeneralTrieImpl() // We just make sure everything is 0, and delete things that might // have been new-ed when we're done. 
// ---------------------------------------------------------------------- template <class T, class NullValuePolicy> GeneralTrieImpl<T, NullValuePolicy>::GeneralTrieImpl() : data_(NullValuePolicy::Null()), null_value_instance_(NullValuePolicy::Null()), min_next_(0), max_next_(0), next_(nullptr) {} template <class T, class NullValuePolicy> GeneralTrieImpl<T, NullValuePolicy>::~GeneralTrieImpl() { for (int i = min_next_; i < max_next_; i++) delete next_[i - min_next_]; // recursively calls the destructor delete[] next_; } // ---------------------------------------------------------------------- // GeneralTrieImpl<T, NullValuePolicy>::Next() // GeneralTrieImpl<T, NullValuePolicy>::SetNext() // Indexes into the next_ array. This is slightly non-trivial // because next_ isn't 0-based (it's min_next_-based). In // particular, we may have to move data around if index < min_next_. // We can also have to reallocate. // ---------------------------------------------------------------------- template <class T, class NullValuePolicy> typename GeneralTrieImpl<T, NullValuePolicy>::NodeT* GeneralTrieImpl<T, NullValuePolicy>::Next(int index) const { if (index < min_next_ || index >= max_next_) return nullptr; assert(next_); return next_[index - min_next_]; } template <class T, class NullValuePolicy> typename GeneralTrieImpl<T, NullValuePolicy>::NodeT* GeneralTrieImpl<T, NullValuePolicy>::SetNext(int index, NodeT* value) { assert(index >= 0 && index < 256); // index should be a char if (min_next_ >= max_next_) { // inserting first element assert(next_ == nullptr); // or at least, it *should* be next_ = new NodeT*[1]; next_[0] = value; min_next_ = index; max_next_ = index + 1; } else if (index < min_next_) { // need to move array over NodeT** newnext = new NodeT*[max_next_ - index]; for (int i = index; i < max_next_; i++) // range of new next_ newnext[i - index] = Next(i); newnext[0] = value; // do the setting delete[] next_; // replace it with newnext next_ = newnext; min_next_ = 
index; } else if (index >= max_next_) { // just need to grow array NodeT** newnext = new NodeT*[index + 1 - min_next_]; for (int i = min_next_; i < index; i++) // range of new next_ newnext[i - min_next_] = Next(i); newnext[index - min_next_] = value; // do the setting delete[] next_; next_ = newnext; max_next_ = index + 1; // it's actually 1+max } else { // happy case: we're in range next_[index - min_next_] = value; } return Next(index); // should be the same as "value" } // ---------------------------------------------------------------------- // GeneralTrieImpl<T, NullValuePolicy>::Insert() // Adds a string key to a trie. Once we've gotten to the leaf, we set // its data to the specified data. // The only complication is comppath. Intuitively, comppath is // prepended to the "next" array before trying to descend: if comppath // is "arc", you can't follow next['h'] unless your string begins with // "arch". If your string begins "bath" instead, we need to break up // the compression to insert. For example: suppose n has comppath // = "stro" and key = "strong". This works fine: we just follow // n->next['n']. But what if key = "state"? Then the "st" part of // comppath is ok, but not the "ro". We change n->comppath to // "st", and then we introduce a new node c between n and its current // children, and set n->next['r'] = c, and c->comppath = "o", to finish // off the "stro". 
// ---------------------------------------------------------------------- template <class T, class NullValuePolicy> void GeneralTrieImpl<T, NullValuePolicy>::Insert(absl::string_view key, const T& data) { int diff; if (key.empty()) { // we're at our leaf data_ = data; return; } int slen = key.length(); // Break up compression if we have to if (comppath_.size() >= slen || // compression too long !absl::StartsWith(key, comppath_)) { // or doesn't match for (diff = 0; diff < key.length(); diff++) // pos of mismatch if (comppath_[diff] != key[diff]) break; if (diff == slen) { diff--; // because we don't use '\0' as a child index } NodeT* child = new NodeT(); for (int i = min_next_; i < max_next_; i++) { if (Next(i)) { child->SetNext(i, Next(i)); SetNext(i, nullptr); // not my child anymore } } SetNext(comppath_[diff], child); child->comppath_.assign(comppath_, diff + 1, comppath_.size() - diff - 1); // Remove the end of the comppath comppath_.erase(diff, comppath_.size() - diff); } key.remove_prefix(comppath_.size()); slen -= comppath_.size(); // At this point we know compression matches. // If root has no children, we can just modify comppath (it must have // been empty), and insert the rest of key as a child. Otherwise we follow // the path based on the first char of key (creating a child node first if // need be). int i; for (i = min_next_; i < max_next_; i++) // does root have any kids? if (Next(i)) break; if (i == max_next_) { // no? 
take over comppath comppath_.assign(key.data(), slen - 1); // w/o kids, comppath was empty NodeT* next = SetNext(key[slen - 1], new NodeT()); // one kid now next->Insert(key.substr(slen), data); } else { // has a kid, and a comppath NodeT* nextnode = Next(key[0]); if (!nextnode) // but not the right kid nextnode = SetNext(key[0], new NodeT()); nextnode->Insert(key.substr(1), data); // continue down the trie } } // ---------------------------------------------------------------------- // GeneralTrieImpl<T, NullValuePolicy>::TraverseAlongString() // Calls traverser->Process() for each string in the trie that is a substring // of key. // ---------------------------------------------------------------------- template <class T, class NullValuePolicy> void GeneralTrieImpl<T, NullValuePolicy>::TraverseAlongString( const absl::string_view key, Traverser* traverser) const { if (key.empty()) return; const NodeT* node = this; int next_pos = 0; std::string buf; buf.reserve(key.length()); while (node) { if (node->data_ != null_value_instance_) { traverser->Process(buf, node->data_); } if (next_pos == key.size()) break; // Return if we don't match comppath. if (node->comppath_.size() >= (key.length() - next_pos) || !absl::StartsWith(key.substr(next_pos), node->comppath_)) break; // Advance next_pos beyond the comppath. next_pos += node->comppath_.size(); // Update the buf with the traversed portion of s. buf.append(node->comppath_); buf.append(1, key[next_pos]); // Move onto the next node. node = node->Next(key[next_pos]); next_pos++; } } // ---------------------------------------------------------------------- // GeneralTrieImpl<T, NullValuePolicy>::Print() // Tries to print the trie in a happy format. Used for debugging. 
// ---------------------------------------------------------------------- template <class T, class NullValuePolicy> void GeneralTrieImpl<T, NullValuePolicy>::Print(int indent) const { if (!comppath_.empty()) { printf("%*s%s\n", indent, "", comppath_.c_str()); indent += comppath_.size(); } for (int i = min_next_; i < max_next_; i++) { if (Next(i)) { printf("%*s%c\n", indent, "", i); Next(i)->Print(indent + 1); } } } // ---------------------------------------------------------------------- // GeneralTrieImpl<T, NullValuePolicy>::GetData() // Returns the data associated with key in the trie, or // NullValuePolicy::Null() if key is not in the trie. // ---------------------------------------------------------------------- template <class T, class NullValuePolicy> const T& GeneralTrieImpl<T, NullValuePolicy>::GetData( absl::string_view key) const { return *GetDataPtr(key); } // ---------------------------------------------------------------------- // GeneralTrieImpl<T, NullValuePolicy>::GetData() // Returns a reference to the data associated with key in the trie, or // NullValuePolicy::Null() if key is not in the trie. // ---------------------------------------------------------------------- template <class T, class NullValuePolicy> T& GeneralTrieImpl<T, NullValuePolicy>::GetData(absl::string_view key) { return *const_cast<T*>(GetDataPtr(key)); } template <class T, class NullValuePolicy> const T* GeneralTrieImpl<T, NullValuePolicy>::GetDataPtr( absl::string_view key) const { const NodeT* node = this; int slen = key.length(); int next_pos = 0; while (node) { // Return if we've reached the end of `key`. Note that node->data_ // may be NullValuePolicy::Null(). if (next_pos >= slen) { return &(node->data_); } // Return if we don't match comppath. 
if (node->comppath_.size() >= key.length() - next_pos || !absl::StartsWith(key.substr(next_pos), node->comppath_)) return &null_value_instance_; // we're done // follow first char after comppath next_pos += node->comppath_.size(); // std::cerr << "comp: " << node->comppath_ << std::endl; node = node->Next(key[next_pos]); // std::cerr << next_pos << ": " << s[next_pos] << std::endl; next_pos++; } return &null_value_instance_; // node is a null pointer; we're done } template <class T, class NullValuePolicy> bool GeneralTrieImpl<T, NullValuePolicy>::SetData(absl::string_view key, const T& data) { const T* temp_ptr = GetDataPtr(key); if (temp_ptr == &null_value_instance_) return false; T* non_const_temp_ptr = const_cast<T*>(temp_ptr); *non_const_temp_ptr = data; return true; } // ---------------------------------------------------------------------- // GeneralTrieImpl<T, NullValuePolicy>::GetDataForMaximalPrefix() // Finds the greatest number n such that: // 1. the first n characters of key form a string in the trie; // 2. (n == key.length()) or is_terminator[key[n]] is true. // If there is no n that satisfies these conditions, the method returns // NullValuePolicy::Null(). Otherwise, it returns the data associated // with the first n characters of key and sets *chars_matched to n. // is_terminator should point to an array of 256 bools. You can also pass // a null pointer for is_terminator, in which case the function treats every // character as a terminator, so condition 2 is always satisfied. 
// ---------------------------------------------------------------------- template <class T, class NullValuePolicy> const T& GeneralTrieImpl<T, NullValuePolicy>::GetDataForMaximalPrefix( absl::string_view key, int* chars_matched, const bool* is_terminator) const { const NodeT* node = this; int next_pos = 0; const T* matched_data = &null_value_instance_; while (node) { // See whether we have a match here if ((node->data_ != null_value_instance_) && ((next_pos >= key.length()) || (is_terminator == nullptr) || (is_terminator[key[next_pos]]))) { *chars_matched = next_pos; matched_data = &(node->data_); } // Return if we've reached the end of s if (next_pos >= key.length()) { return *matched_data; } // Return if we don't match comppath. if (node->comppath_.size() >= (key.length() - next_pos) || !absl::StartsWith(key.substr(next_pos), node->comppath_)) return *matched_data; // we're done // follow first char after comppath next_pos += node->comppath_.size(); node = node->Next(key[next_pos]); next_pos++; } return *matched_data; // reached a node that is a null pointer; we're done } // --------------------------------------------------------------------- // GeneralTrieImpl<T, NullValuePolicy>::GetAllMatchingStrings() // Returns all matching strings and the associated data. If 's' does // not match in its entirety, nothing is returned. 
// --------------------------------------------------------------------- template <class T, class NullValuePolicy> class TrieExtractor : public GeneralTrieImpl<T, NullValuePolicy>::Traverser { public: typedef typename GeneralTrieImpl<T, NullValuePolicy>::TrieData TData; explicit TrieExtractor(std::vector<TData>* outdata) : outdata_(outdata) {} void Process(const std::string& s, const T& data) override { outdata_->push_back(std::make_pair(s, data)); } private: std::vector<TData>* outdata_; }; template <class T, class NullValuePolicy> void GeneralTrieImpl<T, NullValuePolicy>::GetAllMatchingStrings( absl::string_view key, std::vector<TrieData>* outdata) const { // cleanup before we start outdata->clear(); TrieExtractor<T, NullValuePolicy> traverser(outdata); PreorderTraverseAllMatchingStringsDepth(key, &traverser, -1); } template <class T, class NullValuePolicy> void GeneralTrieImpl<T, NullValuePolicy>::TraverseAllMatchingStrings( absl::string_view key, Traverser* traverser, int depth, bool preorder) const { // first try to match the input string in its entirety const NodeT* node = this; int next_pos = 0; // next position in s int brkpt = 0; // next position in "node" // if we find a mismatch, we return emptyhanded. // if we find a match, we break out of the loop with // node set to the portion of the tree that has all // the matches for 's'. while (node) { if (next_pos >= key.length()) { // done with input string. Note: empty input string // matches everything in the trie. A break here means // brkpt == 0 which is what we want (the entire // comppath_ at this node is a suffix). 
break; } const int len_to_compare = std::min(node->comppath_.size(), (key.length() - next_pos)); if (memcmp(node->comppath_.data(), key.data() + next_pos, len_to_compare) != 0) { // mismatch found return; } if ((key.length() - next_pos) <= node->comppath_.size()) { // found a match (prefix of comppath_) brkpt = len_to_compare; break; } // follow first char after comppath next_pos += len_to_compare; node = node->Next(key[next_pos]); ++next_pos; } // if we got here with node == nullptr, we have no matches. if (node == nullptr) return; // we got here => we have one or more matches std::string buf(key); // all of the input string if (node->data_ != null_value_instance_ && brkpt == 0 && preorder) { // this node is a full match by itself traverser->Process(buf, node->data_); } buf.append(node->comppath_.data() + brkpt, node->comppath_.size() - brkpt); for (int i = node->min_next_; i < node->max_next_; i++) { NodeT* child = node->Next(i); if (child != nullptr) { buf.append(1, static_cast<char>(i)); child->Traverse(traverser, &buf, depth, preorder); buf.erase(buf.size() - 1); } } if (node->data_ != null_value_instance_ && brkpt == 0 && !preorder) { // this node is a full match by itself traverser->Process(buf, node->data_); } } template <class T, class NullValuePolicy> void GeneralTrieImpl<T, NullValuePolicy>::Traverse(Traverser* traverser, std::string* s, int depth, bool preorder) const { if (data_ != null_value_instance_ && preorder) { traverser->Process(*s, data_); } if (depth == 0) return; if (depth > 0) --depth; s->append(comppath_); for (int i = min_next_; i < max_next_; i++) { NodeT* child = Next(i); if (child != nullptr) { s->append(1, static_cast<char>(i)); child->Traverse(traverser, s, depth, preorder); s->erase(s->size() - 1); } } s->erase(s->size() - comppath_.size()); if (data_ != null_value_instance_ && !preorder) { traverser->Process(*s, data_); } } template <class T, class NullValuePolicy> GeneralTrieImpl<T, 
NullValuePolicy>::TraverseIterator::TraverseIterator( const GeneralTrieImpl<T, NullValuePolicy>* trie) { stack_.push_back(std::make_pair(trie, trie->min_next_)); if (trie->data_ == trie->null_value_instance_) { Next(); } } template <class T, class NullValuePolicy> void GeneralTrieImpl<T, NullValuePolicy>::TraverseIterator::Next() { while (!stack_.empty()) { const NodeT* node = stack_.back().first; int c = stack_.back().second; if (c == node->min_next_) { key_.append(node->comppath_); } // Traverse the next branch if there is one. for (; c < node->max_next_; ++c) { const NodeT* child = node->Next(c); if (child) { key_.append(1, static_cast<char>(c)); stack_.back().second = c + 1; stack_.push_back(std::make_pair(child, child->min_next_)); if (child->data_ != child->null_value_instance_) { return; } // Continue with the top-level loop and process the child node. break; } } if (c == node->max_next_) { // Leaving the node, so unwind key and stack. key_.erase(key_.size() - node->comppath_.size()); stack_.pop_back(); if (!stack_.empty()) { key_.erase(key_.size() - 1); } } } } } // namespace bigquery_ml_utils_base #endif // THIRD_PARTY_PY_BIGQUERY_ML_UTILS_SQL_UTILS_BASE_GENERAL_TRIE_H_