csrc/suffix_decoding/suffix_tree.cc (431 lines of code) (raw):
// Copyright 2025 Snowflake Inc.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <iostream>
#include <queue>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>
#include "suffix_tree.h"
#define CHECK_OR_RETURN(cond, msg) if (!(cond)) return msg;
SuffixTree::SuffixTree(int max_depth)
: _max_depth(max_depth), _root(new Node()) {
}
// Append a new element to a new or existing sequence.
void SuffixTree::append(int seq_id, int token) {
// Initialize the sequence if it doesn't exist.
_seqs.try_emplace(seq_id);
_active_nodes.try_emplace(seq_id);
// Insert a new active node at the root.
_active_nodes[seq_id].push_back(_root.get());
_root->endpoints[seq_id] = static_cast<int>(_seqs[seq_id].size());
_root->count += 1;
// Ensure the number of active nodes doesn't exceed max_depth.
if (_active_nodes[seq_id].size() > static_cast<size_t>(_max_depth)) {
_active_nodes[seq_id].pop_front();
}
_seqs[seq_id].push_back(token);
// Iterate over all active nodes for this sequence.
for (size_t i = 0; i < _active_nodes[seq_id].size(); ++i) {
Node* node = _active_nodes[seq_id][i];
Node* child = nullptr;
if (node->children.contains(token)) {
child = node->children[token].get();
}
assert(node->endpoints.contains(seq_id));
assert(node->endpoints[seq_id] == _seqs[seq_id].size() - 1);
if (child == nullptr) {
// No existing child node for the new token.
if (node->count == 1 && node != _root.get()) {
// The active node has count = 1, which means the only suffix that ends here is the
// one that's being extended right now. Then this node should be a leaf node, and
// we can simply extend the length of this node.
assert(node->children.empty());
assert(node->ref_seq == seq_id);
node->length += 1;
node->endpoints[seq_id] += 1;
} else {
// Either this is the root node, or the current suffix is not the only one that
// ends here. Either case, we need to extend the current suffix into a new child.
Node* new_child = new Node();
new_child->token = token;
new_child->parent = node;
new_child->count = 1;
new_child->endpoints[seq_id] = static_cast<int>(_seqs[seq_id].size());
new_child->ref_seq = seq_id;
new_child->ref_idx = static_cast<int>(_seqs[seq_id].size()) - 1;
new_child->length = 1;
node->children.emplace(token, new_child);
node->endpoints.erase(seq_id);
_active_nodes[seq_id][i] = new_child;
}
}
else if (node->count == child->count + 1 && node != _root.get()) {
// The active node has a child for the new token, and the child's count is exactly one
// fewer than the active node's count. Since the suffix for the active node ends here,
// that means all other suffixes that pass through this node must go to that child.
assert(node->children.size() == 1); // The active node should have only one child.
assert(node->endpoints.size() == 1); // Only the current suffix should end here.
if (child->length == 1) {
// The child only has length 1. If we append the new token to the current suffix,
// then it will perfectly overlap with the child. In this case, we should just fuse
// the current suffix into the child and eliminate the current node.
Node* parent = node->parent;
// Update child to take the place of the current node.
child->token = node->token;
child->count += 1; // Current suffix extends into the child
child->length = node->length + 1;
child->endpoints[seq_id] = static_cast<int>(_seqs[seq_id].size());
child->ref_seq = seq_id;
child->ref_idx = static_cast<int>(_seqs[seq_id].size()) - child->length;
child->parent = parent;
// Give ownership of child pointer to parent and should also free the current node.
assert(parent->children.contains(child->token));
assert(parent->children[child->token].get() == node);
Node* tmp = node->children[token].release();
parent->children[child->token].reset(tmp);
// Replace active node with child node.
_active_nodes[seq_id][i] = child;
} else {
// The child has length > 1. If we append the new token to the current suffix, then
// it still does not reach the child node. In this case, we keep both nodes but
// extend the length of the current node by 1 into the child node.
node->length += 1;
node->endpoints[seq_id] += 1;
node->ref_seq = seq_id;
node->ref_idx = static_cast<int>(_seqs[seq_id].size()) - node->length;
child->length -= 1;
child->ref_idx += 1;
// The child node's first token should be updated to its second token.
child->token = _seqs[child->ref_seq][child->ref_idx];
if (child->token != token) {
Node* tmp = node->children[token].release();
node->children.emplace(child->token, tmp);
node->children.erase(token);
}
}
}
else {
// There is a child for the new token, and should move the active node into that child.
if (child->length == 1) {
// The child node has length 1, just update the active node pointer to it.
node->endpoints.erase(seq_id);
child->count += 1;
child->endpoints[seq_id] = static_cast<int>(_seqs[seq_id].size());
child->ref_seq = seq_id;
child->ref_idx = static_cast<int>(_seqs[seq_id].size()) - 1;
_active_nodes[seq_id][i] = child;
} else {
// The child node has length > 1. If we extend the current suffix into it, then it
// must be split into a segment of length 1 and another segment with the remainder.
Node* new_node = new Node();
new_node->token = token;
new_node->count = child->count + 1;
new_node->parent = node;
new_node->length = 1;
new_node->endpoints[seq_id] = static_cast<int>(_seqs[seq_id].size());
new_node->ref_seq = seq_id;
new_node->ref_idx = static_cast<int>(_seqs[seq_id].size()) - new_node->length;
// The child node's first token should be updated to its second token.
child->token = _seqs[child->ref_seq][child->ref_idx + 1];
Node* tmp = node->children[token].release();
new_node->children.emplace(child->token, tmp);
node->children[token].reset(new_node);
node->endpoints.erase(seq_id);
child->parent = new_node;
child->length -= 1;
child->ref_idx += 1;
_active_nodes[seq_id][i] = new_node;
}
}
}
}
// Extend a new or existing sequence.
void SuffixTree::extend(int seq_id, const std::vector<int>& tokens) {
for (int token : tokens) {
append(seq_id, token);
}
}
// Remove an existing sequence.
void SuffixTree::remove(int seq_id) {
const std::vector<int>& seq = _seqs[seq_id];
std::vector<Node*> path; // Declare here to avoid repeated allocations.
// Loop through all suffix starting indices.
for (int start = 0; start < seq.size(); start++) {
Node *node = _root.get();
node->count--;
int idx = start;
path.clear();
// Loop through the nodes for this suffix.
while (idx < seq.size()) {
int token = seq[idx];
if (!node->children.contains(token)) {
break;
}
Node* child = node->children[token].get();
assert(child->count > 0);
child->count--;
if (child->count == 0) {
node->children.erase(token);
break;
}
if (child->endpoints.contains(seq_id)) {
child->endpoints.erase(seq_id);
}
idx += child->length;
node = child;
path.push_back(node);
}
// The last visited node may be mergeable with its child.
if (node != _root.get() && node->children.size() == 1) {
const auto& it = *node->children.begin();
std::unique_ptr<Node>& child_uptr = node->children[it.first];
if (node->count == child_uptr->count) {
// Merge node into child.
child_uptr->token = node->token;
child_uptr->length += node->length;
child_uptr->ref_idx -= node->length;
child_uptr->parent = node->parent;
path.back() = node = child_uptr.release();
node->parent->children[node->token].reset(node);
}
}
// ref_seq and ref_idx of all nodes in the path may need to be updated.
// 1. Go to an arbitrary leaf to get its endpoints.
Node* leaf = node;
int distance = 0; // Distance from node to leaf.
while (!leaf->children.empty()) {
leaf = (*leaf->children.begin()).second.get();
distance += leaf->length;
}
// 2. Pick an arbitrary endpoint for the reference sequence and index.
if (leaf->endpoints.empty() || leaf->endpoints.contains(seq_id)) {
// Still need to visit this leaf later when removing this sequence.
// We can skip updating the refs until the next time it's visited.
continue;
}
const auto& ref = *leaf->endpoints.begin();
// 3. Go back up the path to update all nodes' refs.
int32_t ref_seq = ref.first;
int32_t ref_idx = ref.second - distance;
while (!path.empty()) {
Node* n = path.back();
path.pop_back();
ref_idx -= n->length;
if (n->ref_seq == seq_id) {
n->ref_seq = ref_seq;
n->ref_idx = ref_idx;
}
}
}
_seqs.erase(seq_id);
_active_nodes.erase(seq_id);
}
Candidate SuffixTree::speculate(const std::vector<int>& pattern,
int max_spec_tokens,
float max_spec_factor,
float max_spec_offset,
float min_token_prob,
bool use_tree_spec) {
Candidate result;
int start_idx = std::max(static_cast<int>(pattern.size()) - _max_depth, 0);
for ( ; start_idx < pattern.size(); start_idx++) {
auto[node, idx] = _match_pattern(pattern, start_idx);
if (node == nullptr) {
continue;
}
int match_len = static_cast<int>(pattern.size()) - start_idx;
int max_tokens = std::min(max_spec_tokens,
static_cast<int>(match_len * max_spec_factor
+ max_spec_offset + 1e-6));
max_tokens = std::max(max_tokens, 0);
Candidate candidate;
if (use_tree_spec) {
candidate = _speculate_tree(node, idx, max_tokens, min_token_prob);
} else {
candidate = _speculate_path(node, idx, max_tokens, min_token_prob);
}
if (candidate.score > result.score) {
result = std::move(candidate);
result.match_len = match_len;
}
}
return result;
}
std::string SuffixTree::check_integrity() {
// 1. Check structural integrity of all nodes.
std::queue<Node*> queue;
queue.push(_root.get());
while (!queue.empty()) {
Node* node = queue.front();
queue.pop();
std::string ret = _check_node_integrity(node);
if (!ret.empty()) {
return ret;
}
for (const auto& [token, child] : node->children) {
queue.push(child.get());
}
}
// 2. Check all sequences are represented in the tree.
std::unordered_map<Node*, int64_t> visit_count;
for (int seq_id = 0; seq_id < _seqs.size(); seq_id++) {
const std::vector<int>& seq = _seqs[seq_id];
// Loop through all suffix starting indices.
for (int start = 0; start < seq.size(); start++) {
int idx = start;
// Traverse the tree along this suffix.
Node* node = _root.get();
visit_count[node]++;
while (idx < seq.size() && idx - start < _max_depth) {
CHECK_OR_RETURN(node->children.contains(seq[idx]),
"missing child node for sequence");
node = node->children[seq[idx]].get();
visit_count[node]++;
CHECK_OR_RETURN(idx + node->length <= seq.size(),
"path exceeds sequence length");
for (int i = 0; i < node->length; ++i) {
int ref_seq = node->ref_seq;
int ref_idx = node->ref_idx + i;
CHECK_OR_RETURN(seq[idx + i] == _seqs[ref_seq][ref_idx],
"path does not match sequence tokens");
}
idx += node->length;
}
// The last node on this path should have an endpoint.
CHECK_OR_RETURN(node->endpoints.contains(seq_id),
"missing endpoint for sequence");
}
}
// 3. Check all nodes were visited the correct number of times.
assert(queue.empty());
queue.push(_root.get());
while (!queue.empty()) {
Node* node = queue.front();
queue.pop();
CHECK_OR_RETURN(node->count == visit_count[node],
"node count does not match visit count");
for (const auto& [token, child] : node->children) {
queue.push(child.get());
}
}
return "";
}
std::string SuffixTree::_check_node_integrity(Node* node) {
int64_t children_count = 0;
for (const auto& [token, child] : node->children) {
// Do all my children have me as their parent?
CHECK_OR_RETURN(child->parent == node, "child node has incorrect parent pointer");
children_count++;
}
// Is my counter at least the sum of my childrens' counters?
CHECK_OR_RETURN(children_count <= node->count, "node count is less than sum children counts");
if (node == _root.get()) {
// Root node can stop here after some simple checks.
CHECK_OR_RETURN(node->count >= 0, "root node has negative count");
CHECK_OR_RETURN(node->parent == nullptr, "root node has non-null parent pointer");
CHECK_OR_RETURN(node->length == 0, "root node has non-zero length");
CHECK_OR_RETURN(node->endpoints.empty(), "root node has non-empty endpoints");
CHECK_OR_RETURN(node->ref_idx == -1, "root node has invalid ref_idx");
return "";
}
// Is my length positive? Otherwise, I shouldn't exist.
CHECK_OR_RETURN(node->length > 0, "internal node has non-positive length");
// Is my count positive? Otherwise, I shouldn't exist.
CHECK_OR_RETURN(node->count > 0, "internal node has non-positive count");
// Are all my children's counts less than mine? If equal, then we should have been merged.
for (const auto& [token, child] : node->children) {
CHECK_OR_RETURN(
child->count < node->count, "internal node count is not greater than child count");
}
// Check my reference sequence and index.
CHECK_OR_RETURN(_seqs.count(node->ref_seq), "internal node has invalid ref_seq");
CHECK_OR_RETURN(node->ref_idx >= 0, "internal node has invalid ref_idx");
CHECK_OR_RETURN(node->ref_idx + node->length <= _seqs[node->ref_seq].size(),
"internal node has invalid token range");
// Check my first token is correct.
CHECK_OR_RETURN(node->token == _seqs[node->ref_seq][node->ref_idx],
"internal node has incorrect first token");
// Check I am my parent's child.
CHECK_OR_RETURN(node->parent->children.contains(node->token),
"internal node is not a child of parent node");
CHECK_OR_RETURN(node->parent->children[node->token].get() == node,
"parent node has incorrect child pointer");
// Check all my endpoint references are correct.
for (auto [seq_id, end_idx] : node->endpoints) {
CHECK_OR_RETURN(_seqs.count(seq_id), "node endpoint refers to nonexistent sequence");
CHECK_OR_RETURN(end_idx > 0 && end_idx <= _seqs[seq_id].size(), "invalid endpoint index");
// Check all tokens from the start of the suffix to the endpoint.
Node* n = node;
int idx = end_idx;
do {
CHECK_OR_RETURN(n->length <= idx, "invalid endpoint length");
idx -= n->length;
for (int i = 0; i < n->length; ++i) {
int tok = _seqs[n->ref_seq][n->ref_idx + i];
CHECK_OR_RETURN(_seqs[seq_id][idx + i] == tok, "invalid endpoint token");
}
n = n->parent;
} while (n != nullptr);
}
return "";
}
std::pair<Node*, int> SuffixTree::_match_pattern(
const std::vector<int>& pattern, int start_idx) {
Node* node = _root.get();
int idx = 0;
for (int i = start_idx; i < pattern.size(); i++) {
int c = pattern[i];
if (idx >= node->length) {
if (!node->children.contains(c)) {
return {nullptr, -1};
}
node = node->children[c].get();
idx = 0;
}
assert(idx < node->length);
if (_seqs[node->ref_seq][node->ref_idx + idx] != c) {
return {nullptr, -1};
}
idx++;
}
return {node, idx};
}
Candidate SuffixTree::_speculate_path(Node* node, int idx,
int max_spec_tokens,
float min_token_prob) {
Candidate ret;
float prob = 1.0f;
while (ret.token_ids.size() < max_spec_tokens && prob >= min_token_prob) {
if (idx < node->length) {
// Use previous token index as parent; if none, mark as -1.
ret.parents.push_back(static_cast<int>(ret.token_ids.size()) - 1);
int token = _seqs[node->ref_seq][node->ref_idx + idx];
ret.token_ids.push_back(token);
ret.probs.push_back(prob);
ret.score += prob;
idx++;
} else {
Node* child = nullptr;
int64_t count = 0;
// Choose the child with the maximum count.
for (const auto& kv : node->children) {
Node* ch = kv.second.get();
if (ch->count > count) {
child = ch;
count = ch->count;
}
}
if (child == nullptr) {
break;
}
prob *= static_cast<float>(count) / node->count;
node = child;
idx = 0;
}
}
return ret;
}
struct HeapItem {
float prob;
Node* node;
int idx;
int parent; // index in the candidate token list; -1 if none.
HeapItem(float p, Node* n, int i, int par)
: prob(p), node(n), idx(i), parent(par) {}
};
struct HeapItemCompare {
bool operator()(const HeapItem& a, const HeapItem& b) const {
// In C++ priority_queue by default returns the largest element.
// Thus, we compare probabilities so that the highest prob is returned.
return a.prob < b.prob;
}
};
// Get a candidate token tree using a priority queue.
Candidate SuffixTree::_speculate_tree(Node* node, int idx,
int max_spec_tokens,
float min_token_prob) {
Candidate ret;
std::priority_queue<HeapItem, std::vector<HeapItem>, HeapItemCompare> queue;
queue.emplace(1.0, node, idx, -1);
while (ret.token_ids.size() < max_spec_tokens && !queue.empty()) {
HeapItem item = queue.top();
queue.pop();
if (item.idx < item.node->length) {
int token = _seqs[item.node->ref_seq][item.node->ref_idx + item.idx];
ret.token_ids.push_back(token);
ret.parents.push_back(item.parent);
ret.probs.push_back(item.prob);
ret.score += item.prob;
queue.emplace(item.prob, item.node, item.idx + 1,
static_cast<int>(ret.token_ids.size()) - 1);
} else {
for (const auto& kv : item.node->children) {
Node* child = kv.second.get();
float prob = item.prob * child->count /
static_cast<float>(item.node->count);
if (prob >= min_token_prob) {
queue.emplace(prob, child, 0, item.parent);
}
}
}
}
return ret;
}
size_t SuffixTree::estimate_memory() const {
size_t total = sizeof(*this);
std::vector<Node*> stack;
stack.push_back(_root.get());
while (!stack.empty()) {
Node* node = stack.back();
stack.pop_back();
total += node->memory_usage();
for (const auto& [token, child] : node->children) {
stack.push_back(child.get());
}
}
for (const auto& [seq_id, seq] : _seqs) {
total += sizeof(decltype(seq)::value_type) * seq.capacity();
}
for (const auto& [seq_id, active_nodes] : _active_nodes) {
total += sizeof(decltype(active_nodes)::value_type) * active_nodes.size();
}
return total;
}