in tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_builder.cc [556:781]
// Builds the failure structure — the failure link f(v) and the failure pops
// F(v) — for every trie node, via a BFS from the trie root (and from
// `trie_suffix_root_`, when distinct). This is an Aho-Corasick-style
// construction adapted to the MaxMatch (WordPiece) algorithm; see the
// LinMaxMatch paper at https://arxiv.org/abs/2012.15524 for the formal
// treatment. After the BFS, if pre-tokenization is enabled, trie links along
// `suffix_indicator_` are rewired so that a punctuation character inside the
// suffix indicator is still matched as a standalone token.
//
// Returns OkStatus on success, or FailedPreconditionError if an internal
// invariant is violated (which indicates a bug in the builder, not bad input).
absl::Status FastWordpieceBuilder::BuildFailureStructure(
    const std::vector<TrieVocabToken>& tokens_to_build_trie) {
  // Build the set of outgoing edge labels for each trie node (node_id ->
  // set<char>). This is needed by BFS because darts-clone does not provide an
  // API to enumerate the outgoing links for a node.
  SH_ASSIGN_OR_RETURN(
      std::vector<absl::flat_hash_set<char>> node_outgoing_edge_labels,
      BuildOutgoingEdgeLabelsForTrie(tokens_to_build_trie));
  // One FailureStruct slot per trie node; entries not assigned below keep
  // their default (null failure link) value.
  failure_struct_array_.resize(trie_array_.size());
  // Initialize the BFS queue.
  std::queue<uint32_t> bfs_queue({trie_->kRootNodeId});
  if (trie_suffix_root_ != trie_->kRootNodeId) {
    // When `suffix_indicator_` is empty, `trie_suffix_root_` will collapse
    // with root. In this case, we don't visit it twice.
    //
    // In addition, we have ensured that `trie_suffix_root_` will never be null.
    // See PrepareVocabTokensToBuildTrie().
    bfs_queue.push(trie_suffix_root_);
  }
  // The BFS loop.
  while (!bfs_queue.empty()) {
    uint32_t parent_id = bfs_queue.front();
    bfs_queue.pop();
    // Explore the children of the parent node.
    //
    // Fix the iteration order of the outgoing edges to ensure that the model is
    // always built in the same way (i.e., visiting nodes in the same order).
    // (flat_hash_set iteration order is unspecified, hence the sort.)
    std::vector<char> outgoing_labels_sorted(
        node_outgoing_edge_labels[parent_id].begin(),
        node_outgoing_edge_labels[parent_id].end());
    std::sort(outgoing_labels_sorted.begin(), outgoing_labels_sorted.end());
    for (const char edge_label : outgoing_labels_sorted) {
      auto child_node = trie_->CreateTraversalCursor(parent_id);
      if (!trie_->TryTraverseOneStep(child_node, edge_label)) {
        // Should never happen, due to how we built `node_outgoing_edge_labels`;
        // see BuildOutgoingEdgeLabelsAlongVocabToken().
        return absl::FailedPreconditionError(absl::StrCat(
            "Failed to traverse to child following edge ",
            absl::string_view(&edge_label, 1), " at parent ", parent_id, "."));
      }
      if (child_node.node_id == trie_suffix_root_) {
        // Avoid visiting `trie_suffix_root_` twice (it was already seeded into
        // the BFS queue above).
        continue;
      }
      // For the child node v, compute failure link f(v) and failure pops F(v).
      //
      // In the comments below, str(v) is the string on the path from the trie
      // root to the node v, and V is the vocabulary used to build the trie.
      int child_data_value = -1;
      if (trie_->TryGetData(child_node, child_data_value)) {
        // `child_data_value` is the trie payload stored for str(v); it is used
        // below as the single failure-pop entry for this node (presumably the
        // encoded token id — see AssignFailureLinkAndPops).
        uint32_t failure_link = trie_suffix_root_;
        // Check whether the current node represents a punctuation char.
        // Since the current node has data and thus corresponds to some token,
        // it must be in the map `node_id_is_punc_map_`.
        if (!node_id_is_punc_map_.contains(child_node.node_id)) {
          return absl::FailedPreconditionError(
              "Failed to find if an end node in the trie is a punctuation char "
              "in node_id_is_punc_map_. It should never happen.");
        }
        if (!no_pretokenization_ &&
            node_id_is_punc_map_.at(child_node.node_id)) {
          // For end-to-end tokenizer, we set the failure link node of every
          // punctuation char as a special node trie_punct_failure_link_node_
          // which is a dummy node (no parent, no descendants, failure link is
          // null). Hence, by detecting the landing node, we know we just
          // matched a punctuation char. We then split it as a single word.
          failure_link = trie_punct_failure_link_node_;
        }
        // Case 1 (easy): str(v) is in V. Assume that during tokenization of a
        // word, we reached node v, but can't continue further, because the
        // current char from the input word does not match any of the edges
        // outgoing from v. In that case, str(v) is already the max match, so
        // it's the only wordpiece we add to the list of wordpieces we committed
        // to. Hence, F(v) = [str(v)]. The next wordpiece from the current word
        // is a suffix, so we move to node f(v) = trie_suffix_root_, which
        // represents the suffix indicator (e.g., "##"), from where we continue
        // the match process. In summary, we have:
        //  * f(v) = trie_suffix_root_.
        //  * F(v) = [str(v)].
        SH_RETURN_IF_ERROR(AssignFailureLinkAndPops(
            /*cur_node=*/child_node.node_id, /*failure_link=*/failure_link,
            /*one_step_pops=*/{child_data_value},
            /*parent_failure_pops_offset_length=*/
            fast_wordpiece_tokenizer_utils::kNullFailurePopsList));
        bfs_queue.push(child_node.node_id);
        continue;
      }
      // Case 2 (complex): str(v) is not in V.
      //
      // Consider the same scenario as in Case 1, where we can't continue
      // further from v, but now, str(v) is not a valid wordpiece. Instead,
      // we need to consider the wordpieces that the MaxMatch algorithm would
      // generate for the beginning of str(v) (these wordpieces are stored in
      // F(v)). f(v) (the state we transit to) should correspond to the trie
      // node for the remaining suffix of str(v).
      //
      // We could compute F(v) and f(v) by running the original WordPiece
      // algorithm. Instead, we do it even faster, by using F(u) and f(u) (the
      // similar info for the parent node u). Intuitively F(v) consists of (1)
      // the tokens from F(u) and (2) the possible tokens that the MaxMatch
      // algorithm would generate for str(f(u)).c, where str(f(u)) is the suffix
      // of str(u) not covered by the concatenation of the tokens from F(u), "."
      // means concatenation, and c is the edge label character from u to v.
      //
      // Let u be the parent node, and c be the edge label from u to v. To
      // compute f(v) and F(v), the loop below uses a node variable z (called
      // `itr_node`) and a list G (called `one_step_pops`). Initially, z is set
      // to be f(u), and G is empty.
      //  1. If z is null, f(v) will be null, too (see Note 2 below for what
      //     this means). We're done.
      //  2. Check if there is a trie edge out of node z, for label c, leading
      //     to node goto(z, c). If so, set f(v) = goto(z,c) and F(v) = F(u) + G.
      //     We're done and break.
      //  3. Otherwise, collect the pop tokens (by G = G + F(z)) and
      //     follows the failure link (by z = f(z)).
      //  4. Goes to Step 1 and continue the loop.
      //
      // Note 1: processing node v depends on the info for nodes z that are
      // closer to the root than v. Due to our use of the BFS traversal, that
      // info is guaranteed to exist when we examine node v.
      //
      // Note 2: f(v) is null means that during the tokenization process of some
      // input word, if the trie matching cannot continue at node v, there are
      // no failure links that we can follow, and (it can be proved that in such
      // a case) the input word can't be tokenized with the current vocab.
      //
      // For formal discussions and proofs, please refer to the academic paper
      // https://arxiv.org/abs/2012.15524
      const FailureStruct& parent_fs = failure_struct_array_[parent_id];
      if (parent_fs.failure_link != fast_wordpiece_tokenizer_utils::kNullNode) {
        std::vector<int> one_step_pops;
        auto itr_node = trie_->CreateTraversalCursor(parent_fs.failure_link);
        while (true) {
          if (trie_->TryTraverseOneStep(itr_node, edge_label)) {
            // Set the failure link and failure pops for `child_node`.
            SH_RETURN_IF_ERROR(AssignFailureLinkAndPops(
                /*cur_node=*/child_node.node_id,
                /*failure_link=*/itr_node.node_id, one_step_pops,
                parent_fs.failure_pops_offset_length));
            break;
          }
          const FailureStruct& itr_node_fs =
              failure_struct_array_[itr_node.node_id];
          if (itr_node_fs.failure_link ==
              fast_wordpiece_tokenizer_utils::kNullNode) {
            // Cannot follow anymore: failure link of `child_node` will be null
            // (the default-initialized FailureStruct entry is kept).
            break;
          }
          // Append the failure pops of `itr_node` to `one_step_pops`.
          GetFailurePopsAndAppendToOut(itr_node_fs.failure_pops_offset_length,
                                       one_step_pops);
          // Follow the failure link.
          trie_->SetTraversalCursor(itr_node, itr_node_fs.failure_link);
        }
      }
      // Child is enqueued regardless of whether a failure link was assigned,
      // so its own descendants still get processed.
      bfs_queue.push(child_node.node_id);
    }
  }
  if (!no_pretokenization_ && !suffix_indicator_.empty()) {
    // Rewire trie links along suffix_indicator_.
    // If the suffix indicator contains a punctuation char, let `u`--(`c`)-->`v`
    // be the first trie edge along the suffix indicator such that the edge
    // label (i.e. `c`) is a punctuation char. Note that `u`, `v` are trie
    // nodes. `c` is the edge label. We make the following change:
    //
    // Case 1: if `u` is the root, we remove the trie edge from `v` to its child
    // along the suffix indicator.
    // Case 2: if `u` is not the root, we remove the trie edge from `u` to `v`.
    //
    // Example 1: if suffix_indicator_ is "##" (as in BERT), we remove the trie
    // link from "#" to "##". The goal here is to make sure we match the
    // punctuation character "#" as a token by itself, without matching "##"
    // (as we split by punctuation, "##" is not a valid token).
    // Example 2: if suffix_indicator is "foo#", we remove the trie link from
    // "foo" to "foo#".
    int cur_pos = 0;
    int next_pos = 0;
    bool prev_node_id_is_root = false;
    auto node = trie_->CreateTraversalCursorPointToRoot();
    UChar32 c;
    int suffix_indicator_length = suffix_indicator_.size();
    // Walk the suffix indicator one Unicode code point at a time (U8_NEXT
    // advances `next_pos` past the full UTF-8 sequence and decodes it into
    // `c`), traversing the trie in lock-step byte-wise.
    while (cur_pos < suffix_indicator_length) {
      next_pos = cur_pos;
      U8_NEXT(suffix_indicator_, next_pos, suffix_indicator_length, c);
      prev_node_id_is_root = (node.node_id == trie_->kRootNodeId);
      absl::string_view cur_unicode_char(suffix_indicator_.data() + cur_pos,
                                         next_pos - cur_pos);
      if (!trie_->TryTraverseSeveralSteps(node, cur_unicode_char)) {
        return absl::FailedPreconditionError(
            "Cannot locate a character in suffix_indicator_. It should never "
            "happen.");
      }
      if (fast_wordpiece_tokenizer_utils::IsPunctuationOrChineseChar(c)) {
        // If the previous node is a root node, read the next char to break the
        // link from the current punctuation char to its next child node.
        if (prev_node_id_is_root) {
          cur_pos = next_pos;
          // U8_FWD_1 only advances `next_pos` over the next code point; the
          // bytes in between form the next character's UTF-8 encoding.
          U8_FWD_1(suffix_indicator_, next_pos, suffix_indicator_length);
          const absl::string_view next_unicode_char(
              suffix_indicator_.data() + cur_pos, next_pos - cur_pos);
          auto child_node = node;
          if (!trie_->TryTraverseSeveralSteps(child_node, next_unicode_char)) {
            return absl::FailedPreconditionError(
                "Cannot locate a character in suffix_indicator_. It should "
                "never happen.");
          }
          BreakTrieLinkFromParentToChild(child_node.node_id);
        } else {
          BreakTrieLinkFromParentToChild(node.node_id);
        }
        // Only the first punctuation edge along the indicator is rewired.
        break;
      }
      cur_pos = next_pos;
    }
  }
  return absl::OkStatus();
}