EncoderResult EncodeNormalizedString()

in tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc [146:220]


// Segments an already-normalized string into piece codes by finding the
// highest-scoring path through a lattice with one node per byte boundary of
// `str`.
//
// Edges come from two sources:
//   * every prefix match of the piece trie starting at a reachable position,
//     scored with `config.pieces_scores()[piece_id]`;
//   * when `config.unknown_code() >= 0`, a one-byte "unknown" edge scored with
//     `config.unknown_penalty()`. Consecutive unknown bytes on the best path
//     are merged into a single unknown token, but each byte still adds the
//     penalty to the path score.
//
// Parameters:
//   str      - normalized input; positions are byte indices into it.
//   offsets  - indexed by normalized byte position; the value at a token's
//              start position becomes that token's output offset. Presumably
//              this maps back into the original (pre-normalization) input.
//              NOTE(review): confirm against the caller.
//   config   - encoder model (piece trie, scores, special codes).
//   add_bos  - append `config.start_code()` with offset 0.
//   add_eos  - append `config.end_code()` with offset `str.length()`.
//              NOTE(review): this is the normalized length, not a value looked
//              up in `offsets`. Confirm that this is the intended coordinate
//              space.
//   reverse  - when true, return codes/offsets last-to-first; otherwise
//              first-to-last.
//
// Returns the codes with `config.encoding_offset()` added to every code except
// the unknown code, plus one offset per emitted code. If the end of `str`
// cannot be reached (no unknown code and no full segmentation), only the
// optional BOS/EOS entries are emitted.
EncoderResult EncodeNormalizedString(const std::string& str,
                                     const std::vector<int>& offsets,
                                     const EncoderConfig& config, bool add_bos,
                                     bool add_eos, bool reverse) {
  // Best known path ending at a given byte position. A negative
  // `prev_position` marks a node that no path has reached yet.
  struct Node {
    float score = 0;
    int code = -1;
    int prev_position = -1;
  };

  const DoubleArrayTrie piece_matcher(config.pieces()->nodes());
  const flatbuffers::Vector<float>* const piece_scores = config.pieces_scores();
  const int unknown_code = config.unknown_code();
  const float unknown_penalty = config.unknown_penalty();

  const int length = str.length();
  std::vector<Node> lattice(length + 1);

  // Records the candidate path (score, code, prev_position) at `target` if
  // that node is still unreached or the candidate scores strictly higher.
  // Ties keep the path that was recorded first.
  const auto relax = [&lattice](int target, float score, int code,
                                int prev_position) {
    Node& node = lattice[target];
    if (node.prev_position >= 0 && !(node.score < score)) {
      return;
    }
    node.score = score;
    node.code = code;
    node.prev_position = prev_position;
  };

  for (int start = 0; start < length; ++start) {
    // Reference (not a copy) so every read sees the node's current state.
    const Node& origin = lattice[start];
    // Position 0 is always a valid origin. Any other position is expanded
    // only once some path has reached it.
    if (start != 0 && origin.prev_position < 0) continue;

    if (unknown_code >= 0) {
      // If the best path into `start` already ends with an unknown token,
      // extend that token (reuse its predecessor) instead of starting a new
      // one.
      const int unknown_prev =
          origin.code == unknown_code ? origin.prev_position : start;
      relax(start + 1, origin.score + unknown_penalty, unknown_code,
            unknown_prev);
    }

    auto on_match = [&relax, &origin, start,
                     piece_scores](const DoubleArrayTrie::Match& m) {
      relax(start + m.match_length, origin.score + (*piece_scores)[m.id],
            m.id, start);
    };
    piece_matcher.IteratePrefixMatches(
        utils::string_view(str.data() + start, length - start), on_match);
  }

  // Everything below is collected last-to-first and flipped at the end unless
  // the caller asked for reversed output.
  EncoderResult result;
  if (add_eos) {
    result.codes.push_back(config.end_code());
    result.offsets.push_back(length);
  }
  if (lattice[length].prev_position >= 0) {
    int pos = length;
    while (pos > 0) {
      const Node& node = lattice[pos];
      result.codes.push_back(node.code == unknown_code
                                 ? node.code
                                 : node.code + config.encoding_offset());
      pos = node.prev_position;
      // The token's offset is taken at its start position.
      result.offsets.push_back(offsets[pos]);
    }
  }
  if (add_bos) {
    result.codes.push_back(config.start_code());
    result.offsets.push_back(0);
  }
  if (!reverse) {
    std::reverse(result.codes.begin(), result.codes.end());
    std::reverse(result.offsets.begin(), result.offsets.end());
  }
  return result;
}