absl::Status FastWordpieceBuilder::BuildFailureStructure()

in tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_builder.cc [556:781]


// Computes the failure link f(v) and failure pops F(v) for the trie nodes and,
// for the end-to-end tokenizer, rewires one trie link along the suffix
// indicator.
//
// Steps:
//  1. Collects the outgoing edge labels of every trie node from
//     `tokens_to_build_trie` (via BuildOutgoingEdgeLabelsForTrie()).
//  2. Resizes `failure_struct_array_` to the size of `trie_array_`.
//  3. Traverses the trie in BFS order starting from the root (and from
//     `trie_suffix_root_` when it differs from the root) and, for every child
//     node reached, assigns its failure link and failure pops through
//     AssignFailureLinkAndPops().
//  4. If pre-tokenization is enabled (`!no_pretokenization_`) and
//     `suffix_indicator_` is non-empty, breaks one trie link at the first
//     punctuation char along the suffix indicator.
//
// Args:
//  * tokens_to_build_trie: the vocab tokens that are forwarded to
//    BuildOutgoingEdgeLabelsForTrie() to enumerate the edges of each node.
//
// Returns:
//  * absl::OkStatus() on success.
//  * absl::FailedPreconditionError if an expected trie edge or punctuation map
//    entry cannot be found.
//  * Any error propagated from BuildOutgoingEdgeLabelsForTrie() or
//    AssignFailureLinkAndPops().
absl::Status FastWordpieceBuilder::BuildFailureStructure(
    const std::vector<TrieVocabToken>& tokens_to_build_trie) {
  // Build the set of outgoing edge labels for each trie node (node_id ->
  // set<char>). This is needed by BFS because darts-clone does not provide an
  // API to enumerate the outgoing links for a node.
  SH_ASSIGN_OR_RETURN(
      std::vector<absl::flat_hash_set<char>> node_outgoing_edge_labels,
      BuildOutgoingEdgeLabelsForTrie(tokens_to_build_trie));

  failure_struct_array_.resize(trie_array_.size());
  // Initialize the BFS queue.
  std::queue<uint32_t> bfs_queue({trie_->kRootNodeId});
  if (trie_suffix_root_ != trie_->kRootNodeId) {
    // When `suffix_indicator_` is empty, `trie_suffix_root_` will collapse
    // with root. In this case, we don't visit it twice.
    //
    // In addition, we have ensured that `trie_suffix_root_` will never be null.
    // See PrepareVocabTokensToBuildTrie().
    bfs_queue.push(trie_suffix_root_);
  }

  // The BFS loop.
  while (!bfs_queue.empty()) {
    uint32_t parent_id = bfs_queue.front();
    bfs_queue.pop();

    // Explore the children of the parent node.
    //
    // Fix the iteration order of the outgoing edges to ensure that the model is
    // always built in the same way (i.e., visiting nodes in the same order).
    //
    // The labels are compared explicitly as signed bytes: whether plain `char`
    // is signed is implementation-defined, so a bare std::sort would place
    // labels >= 0x80 (non-ASCII UTF-8 bytes) before ASCII on signed-char
    // platforms but after ASCII on unsigned-char platforms, making the visiting
    // order platform-dependent. Using the signed order keeps the result on
    // signed-char platforms unchanged.
    std::vector<char> outgoing_labels_sorted(
        node_outgoing_edge_labels[parent_id].begin(),
        node_outgoing_edge_labels[parent_id].end());
    std::sort(outgoing_labels_sorted.begin(), outgoing_labels_sorted.end(),
              [](const char lhs, const char rhs) {
                return static_cast<signed char>(lhs) <
                       static_cast<signed char>(rhs);
              });
    for (const char edge_label : outgoing_labels_sorted) {
      auto child_node = trie_->CreateTraversalCursor(parent_id);
      if (!trie_->TryTraverseOneStep(child_node, edge_label)) {
        // Should never happen, due to how we built `node_outgoing_edge_labels`;
        // see BuildOutgoingEdgeLabelsAlongVocabToken().
        return absl::FailedPreconditionError(absl::StrCat(
            "Failed to traverse to child following edge ",
            absl::string_view(&edge_label, 1), " at parent ", parent_id, "."));
      }
      if (child_node.node_id == trie_suffix_root_) {
        // Avoid visiting `trie_suffix_root_` twice.
        continue;
      }

      // For the child node v, compute failure link f(v) and failure pops F(v).
      //
      // In the comments below, str(v) is the string on the path from the trie
      // root to the node v, and V is the vocabulary used to build the trie.

      int child_data_value = -1;
      if (trie_->TryGetData(child_node, child_data_value)) {
        uint32_t failure_link = trie_suffix_root_;
        // Check whether the current node represents a punctuation char.
        // Since the current node has data and thus corresponds to some token,
        // it must be in the map `node_id_is_punc_map_`.
        if (!node_id_is_punc_map_.contains(child_node.node_id)) {
          return absl::FailedPreconditionError(
              "Failed to find if an end node in the trie is a punctuation char "
              "in node_id_is_punc_map_. It should never happen.");
        }
        if (!no_pretokenization_ &&
            node_id_is_punc_map_.at(child_node.node_id)) {
          // For end-to-end tokenizer, we set the failure link node of every
          // punctuation char as a special node trie_punct_failure_link_node_
          // which is a dummy node (no parent, no descendants, failure link is
          // null). Hence, by detecting the landing node, we know we just
          // matched a punctuation char. We then split it as a single word.
          failure_link = trie_punct_failure_link_node_;
        }
        // Case 1 (easy): str(v) is in V. Assume that during tokenization of a
        // word, we reached node v, but can't continue further, because the
        // current char from the input word does not match any of the edges
        // outgoing from v. In that case, str(v) is already the max match, so
        // it's the only wordpiece we add to the list of wordpieces we committed
        // to. Hence, F(v) = [str(v)]. The next wordpiece from the current word
        // is a suffix, so we move to node f(v) = trie_suffix_root_, which
        // represents the suffix indicator (e.g., "##"), from where we continue
        // the match process. In summary, we have:
        //  * f(v) = trie_suffix_root_.
        //  * F(v) = [str(v)].
        SH_RETURN_IF_ERROR(AssignFailureLinkAndPops(
            /*cur_node=*/child_node.node_id, /*failure_link=*/failure_link,
            /*one_step_pops=*/{child_data_value},
            /*parent_failure_pops_offset_length=*/
            fast_wordpiece_tokenizer_utils::kNullFailurePopsList));
        bfs_queue.push(child_node.node_id);
        continue;
      }

      // Case 2 (complex): str(v) is not in V.
      //
      // Consider the same scenario as in Case 1, where we can't continue
      // further from v, but now, str(v) is not a valid wordpiece. Instead,
      // we need to consider the wordpieces that the MaxMatch algorithm would
      // generate for the beginning of str(v) (these wordpieces are stored in
      // F(v)). f(v) (the state we transit to) should correspond to the trie
      // node for the remaining suffix of str(v).
      //
      // We could compute F(v) and f(v) by running the original WordPiece
      // algorithm. Instead, we do it even faster, by using F(u) and f(u) (the
      // similar info for the parent node u). Intuitively F(v) consists of (1)
      // the tokens from F(u) and (2) the possible tokens that the MaxMatch
      // algorithm would generate for str(f(u)).c, where str(f(u)) is the suffix
      // of str(u) not covered by the concatenation of the tokens from F(u), "."
      // means concatenation, and c is the edge label character from u to v.
      //
      //
      // Let u be the parent node, and c be the edge label from u to v. To
      // compute f(v) and F(v), the loop below uses a node variable z (called
      // `itr_node`) and a list G (called `one_step_pops`). Initially, z is set
      // to be f(u), and G is empty.
      //  1. If z is null, f(v) will be null, too (see Note 2 below for what
      //  this means). We're done.
      //  2. Check if there is a trie edge out of node z, for label c, leading
      //    to node goto(z, c). If so, set f(v) = goto(z,c) and F(v) = F(u) + G.
      //    We're done and break.
      //  3. Otherwise, collect the pop tokens (by G = G + F(z)) and
      //    follow the failure link (by z = f(z)).
      //  4. Go to Step 1 and continue the loop.
      //
      // Note 1: processing node v depends on the info for nodes z that are
      // closer to the root than v. Due to our use of the BFS traversal, that
      // info is guaranteed to exist when we examine node v.
      //
      // Note 2: f(v) is null means that during the tokenization process of some
      // input word, if the trie matching cannot continue at node v, there are
      // no failure links that we can follow, and (it can be proved that in such
      // a case) the input word can't be tokenized with the current vocab.
      //
      // For formal discussions and proofs, please refer to the academic paper
      // https://arxiv.org/abs/2012.15524
      const FailureStruct& parent_fs = failure_struct_array_[parent_id];
      if (parent_fs.failure_link != fast_wordpiece_tokenizer_utils::kNullNode) {
        std::vector<int> one_step_pops;
        auto itr_node = trie_->CreateTraversalCursor(parent_fs.failure_link);
        while (true) {
          if (trie_->TryTraverseOneStep(itr_node, edge_label)) {
            // Set the failure link and failure pops for `child_node`.
            SH_RETURN_IF_ERROR(AssignFailureLinkAndPops(
                /*cur_node=*/child_node.node_id,
                /*failure_link=*/itr_node.node_id, one_step_pops,
                parent_fs.failure_pops_offset_length));
            break;
          }
          const FailureStruct& itr_node_fs =
              failure_struct_array_[itr_node.node_id];
          if (itr_node_fs.failure_link ==
              fast_wordpiece_tokenizer_utils::kNullNode) {
            // Cannot follow anymore: failure link of `child_node` will be null.
            break;
          }
          // Append the failure pops of `itr_node` to `one_step_pops`.
          GetFailurePopsAndAppendToOut(itr_node_fs.failure_pops_offset_length,
                                       one_step_pops);
          // Follow the failure link.
          trie_->SetTraversalCursor(itr_node, itr_node_fs.failure_link);
        }
      }

      // The child is enqueued even when no failure link was assigned above, so
      // that its own descendants are still visited.
      bfs_queue.push(child_node.node_id);
    }
  }

  if (!no_pretokenization_ && !suffix_indicator_.empty()) {
    // Rewire trie links along suffix_indicator_.
    // If the suffix indicator contains a punctuation char, let `u`--(`c`)-->`v`
    // be the first trie edge along the suffix indicator such that the edge
    // label (i.e. `c`) is a punctuation char. Note that `u`, `v` are trie
    // nodes. `c` is the edge label. We make the following change:
    //
    // Case 1: if `u` is the root, we remove the trie edge from `v` to its child
    // along the suffix indicator.
    // Case 2: if `u` is not the root, we remove the trie edge from `u` to `v`.
    //
    // Example 1: if suffix_indicator_ is "##" (as in BERT), we remove the trie
    // link from "#" to "##".  The goal here is to make sure we match the
    // punctuation character "#" as a token by itself, without matching "##"
    // (as we split by punctuation, "##" is not a valid token).
    // Example 2: if suffix_indicator is "foo#", we remove the trie link from
    // "foo" to "foo#".
    int cur_pos = 0;
    int next_pos = 0;
    bool prev_node_id_is_root = false;
    auto node = trie_->CreateTraversalCursorPointToRoot();
    UChar32 c = 0;
    const int suffix_indicator_length =
        static_cast<int>(suffix_indicator_.size());
    while (cur_pos < suffix_indicator_length) {
      // Decode the Unicode char that starts at `cur_pos`; afterwards `next_pos`
      // points just past its UTF-8 bytes.
      next_pos = cur_pos;
      U8_NEXT(suffix_indicator_, next_pos, suffix_indicator_length, c);
      // Remember whether we are about to leave the root (i.e., whether `u` in
      // the comment above is the root) before moving `node` forward.
      prev_node_id_is_root = (node.node_id == trie_->kRootNodeId);
      absl::string_view cur_unicode_char(suffix_indicator_.data() + cur_pos,
                                         next_pos - cur_pos);
      if (!trie_->TryTraverseSeveralSteps(node, cur_unicode_char)) {
        return absl::FailedPreconditionError(
            "Cannot locate a character in suffix_indicator_. It should never "
            "happen.");
      }
      if (fast_wordpiece_tokenizer_utils::IsPunctuationOrChineseChar(c)) {
        // If the previous node is a root node, read the next char to break the
        //  link from the current punctuation char to its next child node.
        if (prev_node_id_is_root) {
          // NOTE(review): this assumes the suffix indicator has another char
          // after the punctuation char; if `next_pos` already equals
          // `suffix_indicator_length` here there is no such char to step over.
          // Presumably callers never configure that -- TODO confirm.
          cur_pos = next_pos;
          U8_FWD_1(suffix_indicator_, next_pos, suffix_indicator_length);
          const absl::string_view next_unicode_char(
              suffix_indicator_.data() + cur_pos, next_pos - cur_pos);
          // Walk a copy of the cursor so that `node` itself stays at `v`.
          auto child_node = node;
          if (!trie_->TryTraverseSeveralSteps(child_node, next_unicode_char)) {
            return absl::FailedPreconditionError(
                "Cannot locate a character in suffix_indicator_. It should "
                "never happen.");
          }
          BreakTrieLinkFromParentToChild(child_node.node_id);
        } else {
          BreakTrieLinkFromParentToChild(node.node_id);
        }
        // Only the first punctuation char along the suffix indicator matters.
        break;
      }
      cur_pos = next_pos;
    }
  }
  return absl::OkStatus();
}